Commit ca8692c7 authored by InCerry
Browse files

Merge remote-tracking branch 'upstream/main'

# Conflicts:
#	backend/internal/service/openai_gateway_messages.go
parents b6d46fd5 318aa5e0
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// --- mock: Internal500CounterCache ---

// mockInternal500Cache is a hand-written test double exposing the two
// Internal500CounterCache methods. It hands back canned results and records
// the account ID of every call so tests can assert on the interaction.
type mockInternal500Cache struct {
	// Canned results returned to the caller.
	incrementCount int64
	incrementErr   error
	resetErr       error

	// Recorded interactions, in call order.
	incrementCalls []int64 // account IDs passed to IncrementInternal500Count
	resetCalls     []int64 // account IDs passed to ResetInternal500Count
}

// IncrementInternal500Count records accountID and returns the canned count and error.
func (c *mockInternal500Cache) IncrementInternal500Count(_ context.Context, accountID int64) (int64, error) {
	c.incrementCalls = append(c.incrementCalls, accountID)
	count, err := c.incrementCount, c.incrementErr
	return count, err
}

// ResetInternal500Count records accountID and returns the canned reset error.
func (c *mockInternal500Cache) ResetInternal500Count(_ context.Context, accountID int64) error {
	c.resetCalls = append(c.resetCalls, accountID)
	err := c.resetErr
	return err
}
// --- mock: AccountRepository stub dedicated to the internal500 penalty tests ---

// tempUnschedCall captures the arguments of one SetTempUnschedulable call.
type tempUnschedCall struct {
	accountID int64
	until     time.Time
	reason    string
}

// setErrorCall captures the arguments of one SetError call.
type setErrorCall struct {
	accountID int64
	reason    string
}

// internal500AccountRepoStub overrides only the two repository methods the
// penalty logic is expected to use and records every call made to them.
type internal500AccountRepoStub struct {
	// The embedded interface is left nil, so any method that is not overridden
	// below panics when invoked (such methods should never be called).
	AccountRepository

	tempUnschedCalls []tempUnschedCall
	setErrorCalls    []setErrorCall
}

// SetTempUnschedulable records the call and always succeeds.
func (stub *internal500AccountRepoStub) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error {
	recorded := tempUnschedCall{accountID: id, until: until, reason: reason}
	stub.tempUnschedCalls = append(stub.tempUnschedCalls, recorded)
	return nil
}

// SetError records the call (errorMsg is stored as the reason) and always succeeds.
func (stub *internal500AccountRepoStub) SetError(_ context.Context, id int64, errorMsg string) error {
	recorded := setErrorCall{accountID: id, reason: errorMsg}
	stub.setErrorCalls = append(stub.setErrorCalls, recorded)
	return nil
}
// =============================================================================
// TestIsAntigravityInternalServerError
// =============================================================================
func TestIsAntigravityInternalServerError(t *testing.T) {
t.Run("匹配完整的 INTERNAL 500 body", func(t *testing.T) {
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`)
require.True(t, isAntigravityInternalServerError(500, body))
})
t.Run("statusCode 不是 500", func(t *testing.T) {
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`)
require.False(t, isAntigravityInternalServerError(429, body))
require.False(t, isAntigravityInternalServerError(503, body))
require.False(t, isAntigravityInternalServerError(200, body))
})
t.Run("body 中 message 不匹配", func(t *testing.T) {
body := []byte(`{"error":{"code":500,"message":"Some other error","status":"INTERNAL"}}`)
require.False(t, isAntigravityInternalServerError(500, body))
})
t.Run("body 中 status 不匹配", func(t *testing.T) {
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"UNAVAILABLE"}}`)
require.False(t, isAntigravityInternalServerError(500, body))
})
t.Run("body 中 code 不匹配", func(t *testing.T) {
body := []byte(`{"error":{"code":503,"message":"Internal error encountered.","status":"INTERNAL"}}`)
require.False(t, isAntigravityInternalServerError(500, body))
})
t.Run("空 body", func(t *testing.T) {
require.False(t, isAntigravityInternalServerError(500, []byte{}))
require.False(t, isAntigravityInternalServerError(500, nil))
})
t.Run("其他 500 错误格式(纯文本)", func(t *testing.T) {
body := []byte(`Internal Server Error`)
require.False(t, isAntigravityInternalServerError(500, body))
})
t.Run("其他 500 错误格式(不同 JSON 结构)", func(t *testing.T) {
body := []byte(`{"message":"Internal Server Error","statusCode":500}`)
require.False(t, isAntigravityInternalServerError(500, body))
})
}
// =============================================================================
// TestApplyInternal500Penalty
// =============================================================================
func TestApplyInternal500Penalty(t *testing.T) {
t.Run("count=1 → SetTempUnschedulable 10 分钟", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 1, Name: "acc-1"}
before := time.Now()
svc.applyInternal500Penalty(context.Background(), "[test]", account, 1)
after := time.Now()
require.Len(t, repo.tempUnschedCalls, 1)
require.Empty(t, repo.setErrorCalls)
call := repo.tempUnschedCalls[0]
require.Equal(t, int64(1), call.accountID)
require.Contains(t, call.reason, "INTERNAL 500")
// until 应在 [before+10m, after+10m] 范围内
require.True(t, call.until.After(before.Add(internal500PenaltyTier1Duration).Add(-time.Second)))
require.True(t, call.until.Before(after.Add(internal500PenaltyTier1Duration).Add(time.Second)))
})
t.Run("count=2 → SetTempUnschedulable 10 小时", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 2, Name: "acc-2"}
before := time.Now()
svc.applyInternal500Penalty(context.Background(), "[test]", account, 2)
after := time.Now()
require.Len(t, repo.tempUnschedCalls, 1)
require.Empty(t, repo.setErrorCalls)
call := repo.tempUnschedCalls[0]
require.Equal(t, int64(2), call.accountID)
require.Contains(t, call.reason, "INTERNAL 500")
require.True(t, call.until.After(before.Add(internal500PenaltyTier2Duration).Add(-time.Second)))
require.True(t, call.until.Before(after.Add(internal500PenaltyTier2Duration).Add(time.Second)))
})
t.Run("count=3 → SetError 永久禁用", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 3, Name: "acc-3"}
svc.applyInternal500Penalty(context.Background(), "[test]", account, 3)
require.Empty(t, repo.tempUnschedCalls)
require.Len(t, repo.setErrorCalls, 1)
call := repo.setErrorCalls[0]
require.Equal(t, int64(3), call.accountID)
require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 3")
})
t.Run("count=5 → SetError 永久禁用(>=3 都走永久禁用)", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 5, Name: "acc-5"}
svc.applyInternal500Penalty(context.Background(), "[test]", account, 5)
require.Empty(t, repo.tempUnschedCalls)
require.Len(t, repo.setErrorCalls, 1)
call := repo.setErrorCalls[0]
require.Equal(t, int64(5), call.accountID)
require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 5")
})
t.Run("count=0 → 不调用任何方法", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 10, Name: "acc-10"}
svc.applyInternal500Penalty(context.Background(), "[test]", account, 0)
require.Empty(t, repo.tempUnschedCalls)
require.Empty(t, repo.setErrorCalls)
})
}
// =============================================================================
// TestHandleInternal500RetryExhausted
// =============================================================================
func TestHandleInternal500RetryExhausted(t *testing.T) {
t.Run("internal500Cache 为 nil → 不 panic,不调用任何方法", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{
accountRepo: repo,
internal500Cache: nil,
}
account := &Account{ID: 1, Name: "acc-1"}
// 不应 panic
require.NotPanics(t, func() {
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
})
require.Empty(t, repo.tempUnschedCalls)
require.Empty(t, repo.setErrorCalls)
})
t.Run("IncrementInternal500Count 返回 error → 不调用惩罚方法", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
cache := &mockInternal500Cache{
incrementErr: errors.New("redis connection error"),
}
svc := &AntigravityGatewayService{
accountRepo: repo,
internal500Cache: cache,
}
account := &Account{ID: 2, Name: "acc-2"}
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
require.Len(t, cache.incrementCalls, 1)
require.Equal(t, int64(2), cache.incrementCalls[0])
require.Empty(t, repo.tempUnschedCalls)
require.Empty(t, repo.setErrorCalls)
})
t.Run("IncrementInternal500Count 返回 count=1 → 触发 tier1 惩罚", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
cache := &mockInternal500Cache{
incrementCount: 1,
}
svc := &AntigravityGatewayService{
accountRepo: repo,
internal500Cache: cache,
}
account := &Account{ID: 3, Name: "acc-3"}
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
require.Len(t, cache.incrementCalls, 1)
require.Equal(t, int64(3), cache.incrementCalls[0])
// tier1: SetTempUnschedulable
require.Len(t, repo.tempUnschedCalls, 1)
require.Equal(t, int64(3), repo.tempUnschedCalls[0].accountID)
require.Empty(t, repo.setErrorCalls)
})
t.Run("IncrementInternal500Count 返回 count=3 → 触发 tier3 永久禁用", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
cache := &mockInternal500Cache{
incrementCount: 3,
}
svc := &AntigravityGatewayService{
accountRepo: repo,
internal500Cache: cache,
}
account := &Account{ID: 4, Name: "acc-4"}
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
require.Len(t, cache.incrementCalls, 1)
require.Empty(t, repo.tempUnschedCalls)
require.Len(t, repo.setErrorCalls, 1)
require.Equal(t, int64(4), repo.setErrorCalls[0].accountID)
})
}
// =============================================================================
// TestResetInternal500Counter
// =============================================================================
func TestResetInternal500Counter(t *testing.T) {
t.Run("internal500Cache 为 nil → 不 panic", func(t *testing.T) {
svc := &AntigravityGatewayService{
internal500Cache: nil,
}
require.NotPanics(t, func() {
svc.resetInternal500Counter(context.Background(), "[test]", 1)
})
})
t.Run("ResetInternal500Count 返回 error → 不 panic(仅日志)", func(t *testing.T) {
cache := &mockInternal500Cache{
resetErr: errors.New("redis timeout"),
}
svc := &AntigravityGatewayService{
internal500Cache: cache,
}
require.NotPanics(t, func() {
svc.resetInternal500Counter(context.Background(), "[test]", 42)
})
require.Len(t, cache.resetCalls, 1)
require.Equal(t, int64(42), cache.resetCalls[0])
})
t.Run("正常调用 → 调用 ResetInternal500Count", func(t *testing.T) {
cache := &mockInternal500Cache{}
svc := &AntigravityGatewayService{
internal500Cache: cache,
}
svc.resetInternal500Counter(context.Background(), "[test]", 99)
require.Len(t, cache.resetCalls, 1)
require.Equal(t, int64(99), cache.resetCalls[0])
})
}
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"log/slog" "log/slog"
mathrand "math/rand" mathrand "math/rand"
"net/http" "net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
...@@ -368,6 +369,8 @@ var allowedHeaders = map[string]bool{ ...@@ -368,6 +369,8 @@ var allowedHeaders = map[string]bool{
"user-agent": true, "user-agent": true,
"content-type": true, "content-type": true,
"accept-encoding": true, "accept-encoding": true,
"x-claude-code-session-id": true,
"x-client-request-id": true,
} }
// GatewayCache 定义网关服务的缓存操作接口。 // GatewayCache 定义网关服务的缓存操作接口。
...@@ -4150,10 +4153,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4150,10 +4153,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, err return nil, err
} }
// 获取代理URL // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
proxyURL := "" proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL() if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
proxyURL = account.Proxy.URL()
}
} }
// 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析) // 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析)
...@@ -5628,6 +5633,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -5628,6 +5633,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
} }
targetURL = validatedURL + "/v1/messages?beta=true" targetURL = validatedURL + "/v1/messages?beta=true"
} }
} else if account.IsCustomBaseURLEnabled() {
customURL := account.GetCustomBaseURL()
if customURL == "" {
return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
}
validatedURL, err := s.validateUpstreamBaseURL(customURL)
if err != nil {
return nil, err
}
targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account)
} }
clientHeaders := http.Header{} clientHeaders := http.Header{}
...@@ -5743,6 +5758,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -5743,6 +5758,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
} }
} }
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
}
}
}
// === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 === // === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{ s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{
"url": req.URL.String(), "url": req.URL.String(),
...@@ -8063,10 +8087,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -8063,10 +8087,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
return err return err
} }
// 获取代理URL // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
proxyURL := "" proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL() if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
proxyURL = account.Proxy.URL()
}
} }
// 发送请求 // 发送请求
...@@ -8345,6 +8371,16 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -8345,6 +8371,16 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
} }
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
} }
} else if account.IsCustomBaseURLEnabled() {
customURL := account.GetCustomBaseURL()
if customURL == "" {
return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
}
validatedURL, err := s.validateUpstreamBaseURL(customURL)
if err != nil {
return nil, err
}
targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account)
} }
clientHeaders := http.Header{} clientHeaders := http.Header{}
...@@ -8450,6 +8486,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -8450,6 +8486,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
} }
} }
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
}
}
}
if c != nil && tokenType == "oauth" { if c != nil && tokenType == "oauth" {
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
} }
...@@ -8471,6 +8516,19 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m ...@@ -8471,6 +8516,19 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
}) })
} }
// buildCustomRelayURL builds the forwarding URL for a custom relay.
//
// The result is baseURL (trailing slashes removed) followed by path and the
// query "?beta=true". When the account has both a ProxyID and a Proxy whose
// URL is non-empty, that URL is appended query-escaped as "&proxy=<url>".
func (s *GatewayService) buildCustomRelayURL(baseURL, path string, account *Account) string {
	relay := strings.TrimRight(baseURL, "/") + path + "?beta=true"

	// No proxy bound to the account: nothing more to append.
	if account.ProxyID == nil || account.Proxy == nil {
		return relay
	}
	if proxy := account.Proxy.URL(); proxy != "" {
		return relay + "&proxy=" + url.QueryEscape(proxy)
	}
	return relay
}
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
......
...@@ -36,6 +36,11 @@ var headerWireCasing = map[string]string{ ...@@ -36,6 +36,11 @@ var headerWireCasing = map[string]string{
"sec-fetch-mode": "sec-fetch-mode", "sec-fetch-mode": "sec-fetch-mode",
"accept-encoding": "accept-encoding", "accept-encoding": "accept-encoding",
"authorization": "authorization", "authorization": "authorization",
// Claude Code 2.1.87+ 新增 header
"x-claude-code-session-id": "X-Claude-Code-Session-Id",
"x-client-request-id": "x-client-request-id",
"content-length": "content-length",
} }
// headerWireOrder 定义真实 Claude CLI 发送 header 的顺序(基于抓包)。 // headerWireOrder 定义真实 Claude CLI 发送 header 的顺序(基于抓包)。
...@@ -55,11 +60,14 @@ var headerWireOrder = []string{ ...@@ -55,11 +60,14 @@ var headerWireOrder = []string{
"authorization", "authorization",
"x-app", "x-app",
"User-Agent", "User-Agent",
"X-Claude-Code-Session-Id",
"content-type", "content-type",
"anthropic-beta", "anthropic-beta",
"x-client-request-id",
"accept-language", "accept-language",
"sec-fetch-mode", "sec-fetch-mode",
"accept-encoding", "accept-encoding",
"content-length",
"x-stainless-helper-method", "x-stainless-helper-method",
} }
......
package service
import "context"
// Internal500CounterCache tracks, per Antigravity account, the number of
// consecutive rounds that failed with INTERNAL 500.
type Internal500CounterCache interface {
// IncrementInternal500Count atomically increments the counter for accountID
// and returns its current value.
IncrementInternal500Count(ctx context.Context, accountID int64) (int64, error)
// ResetInternal500Count clears the counter for accountID (called on a
// successful response).
ResetInternal500Count(ctx context.Context, accountID int64) error
}
package service
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
)
// NormalizeOpenAICompatRequestedModel returns the model name with a recognized
// reasoning-effort suffix resolved by splitOpenAICompatReasoningModel.
//
// Blank input yields "". If no suffix is recognized, or the normalized name
// comes back empty, the whitespace-trimmed input is returned unchanged.
func NormalizeOpenAICompatRequestedModel(model string) string {
	requested := strings.TrimSpace(model)
	if requested == "" {
		return ""
	}
	if base, _, matched := splitOpenAICompatReasoningModel(requested); matched && base != "" {
		return base
	}
	return requested
}
// applyOpenAICompatModelNormalization rewrites req in place so that a
// reasoning-effort suffix on the model name is turned into an explicit
// output-config effort.
//
// A nil request or a blank model is left untouched. When the model carries a
// recognized suffix, req.Model is replaced with the normalized name. The
// derived effort is applied only if the request does not already specify a
// non-blank OutputConfig.Effort, and only if it maps to a Claude effort value.
func applyOpenAICompatModelNormalization(req *apicompat.AnthropicRequest) {
	if req == nil {
		return
	}
	requested := strings.TrimSpace(req.Model)
	if requested == "" {
		return
	}

	model, effort, matched := splitOpenAICompatReasoningModel(requested)
	if matched && model != "" {
		req.Model = model
	}

	// An effort set explicitly by the caller wins over the model suffix.
	explicitEffort := req.OutputConfig != nil && strings.TrimSpace(req.OutputConfig.Effort) != ""
	if explicitEffort {
		return
	}

	mapped := openAIReasoningEffortToClaudeOutputEffort(effort)
	switch {
	case mapped == "":
		return
	case req.OutputConfig == nil:
		req.OutputConfig = &apicompat.AnthropicOutputConfig{Effort: mapped}
	default:
		req.OutputConfig.Effort = mapped
	}
}
// splitOpenAICompatReasoningModel separates a trailing reasoning-effort suffix
// from a "gpt-" model name.
//
// Only the text after the last "/" is inspected, and it must start with "gpt-"
// (case-insensitive). The lower-cased name is split on '-', '_' and ' ', and
// the final segment is matched against the known suffixes:
//
//	none, minimal    -> ok, empty effort
//	low, medium, high -> ok, effort is the suffix itself
//	xhigh, extrahigh -> ok, effort "xhigh"
//
// On a match the model is returned as normalizeCodexModel of the name after
// the last "/". Otherwise ok is false and the trimmed input is returned as the
// model ("" for blank input).
func splitOpenAICompatReasoningModel(model string) (normalizedModel string, reasoningEffort string, ok bool) {
	requested := strings.TrimSpace(model)
	if requested == "" {
		return "", "", false
	}

	// Keep only what follows the last "/".
	id := requested
	if slash := strings.LastIndex(id, "/"); slash >= 0 {
		id = id[slash+1:]
	}
	id = strings.TrimSpace(id)

	lowered := strings.ToLower(id)
	if !strings.HasPrefix(lowered, "gpt-") {
		return requested, "", false
	}

	segments := strings.FieldsFunc(lowered, func(r rune) bool {
		return r == '-' || r == '_' || r == ' '
	})
	if len(segments) == 0 {
		return requested, "", false
	}

	// Segments cannot contain a separator, so the last one is compared as is.
	suffix := segments[len(segments)-1]
	switch suffix {
	case "low", "medium", "high":
		reasoningEffort = suffix
	case "xhigh", "extrahigh":
		reasoningEffort = "xhigh"
	case "none", "minimal":
		// Recognized suffix that carries no effort value.
	default:
		return requested, "", false
	}
	return normalizeCodexModel(id), reasoningEffort, true
}
// openAIReasoningEffortToClaudeOutputEffort maps an OpenAI reasoning effort to
// the corresponding Claude output-config effort.
//
// Surrounding whitespace is ignored. "low", "medium" and "high" map to
// themselves, "xhigh" maps to "max", and every other value (including the
// empty string) yields "", meaning there is no effort to apply.
func openAIReasoningEffortToClaudeOutputEffort(effort string) string {
	switch trimmed := strings.TrimSpace(effort); trimmed {
	case "low", "medium", "high":
		// Return the value that was matched, not the raw argument, so padding
		// in the input never leaks into the result.
		return trimmed
	case "xhigh":
		return "max"
	default:
		return ""
	}
}
package service
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestNormalizeOpenAICompatRequestedModel(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want string
}{
{name: "gpt reasoning alias strips xhigh", input: "gpt-5.4-xhigh", want: "gpt-5.4"},
{name: "gpt reasoning alias strips none", input: "gpt-5.4-none", want: "gpt-5.4"},
{name: "codex max model stays intact", input: "gpt-5.1-codex-max", want: "gpt-5.1-codex-max"},
{name: "non openai model unchanged", input: "claude-opus-4-6", want: "claude-opus-4-6"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, NormalizeOpenAICompatRequestedModel(tt.input))
})
}
}
// TestApplyOpenAICompatModelNormalization checks the in-place request rewrite:
// the model suffix is stripped, an effort is derived from it only when the
// request has none of its own, and non-OpenAI models are left alone.
func TestApplyOpenAICompatModelNormalization(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name      string
		req       *apicompat.AnthropicRequest
		wantModel string
		// wantEffort is the expected OutputConfig.Effort; an empty value means
		// OutputConfig must remain nil.
		wantEffort string
	}{
		{
			name:       "derives xhigh from model suffix when output config missing",
			req:        &apicompat.AnthropicRequest{Model: "gpt-5.4-xhigh"},
			wantModel:  "gpt-5.4",
			wantEffort: "max",
		},
		{
			name: "explicit output config wins over model suffix",
			req: &apicompat.AnthropicRequest{
				Model:        "gpt-5.4-xhigh",
				OutputConfig: &apicompat.AnthropicOutputConfig{Effort: "low"},
			},
			wantModel:  "gpt-5.4",
			wantEffort: "low",
		},
		{
			name:      "non openai model is untouched",
			req:       &apicompat.AnthropicRequest{Model: "claude-opus-4-6"},
			wantModel: "claude-opus-4-6",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			applyOpenAICompatModelNormalization(tc.req)

			require.Equal(t, tc.wantModel, tc.req.Model)
			if tc.wantEffort == "" {
				require.Nil(t, tc.req.OutputConfig)
				return
			}
			require.NotNil(t, tc.req.OutputConfig)
			require.Equal(t, tc.wantEffort, tc.req.OutputConfig.Effort)
		})
	}
}
// TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh sends an
// Anthropic-style request for model "gpt-5.4-xhigh" through ForwardAsAnthropic
// against a recorded upstream. It asserts that the upstream request uses model
// "gpt-5.4" with reasoning.effort "xhigh", while the result and the client
// response still report the requested model "gpt-5.4-xhigh".
func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
// Client side: a non-streaming /v1/messages request captured by a recorder.
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4-xhigh","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
// Upstream side: a canned event stream with one response.completed event
// followed by the [DONE] marker.
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
"",
"data: [DONE]",
"",
}, "\n")
// NOTE(review): httpUpstreamRecorder is defined elsewhere; it presumably
// returns resp and keeps the last request body in lastBody — confirm.
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_compat"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{httpUpstream: upstream}
// OAuth account whose model mapping contains only the normalized name.
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
"model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
}
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
// The result keeps the requested model but routes and bills on the normalized one.
require.Equal(t, "gpt-5.4-xhigh", result.Model)
require.Equal(t, "gpt-5.4", result.UpstreamModel)
require.Equal(t, "gpt-5.4", result.BillingModel)
require.NotNil(t, result.ReasoningEffort)
require.Equal(t, "xhigh", *result.ReasoningEffort)
// The body sent upstream carries the normalized model and the derived effort.
require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
require.Equal(t, "xhigh", gjson.GetBytes(upstream.lastBody, "reasoning.effort").String())
// The client response echoes the originally requested model and the upstream text.
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "gpt-5.4-xhigh", gjson.GetBytes(rec.Body.Bytes(), "model").String())
require.Equal(t, "ok", gjson.GetBytes(rec.Body.Bytes(), "content.0.text").String())
// Logged for debugging; shown with -v or when the test fails.
t.Logf("upstream body: %s", string(upstream.lastBody))
t.Logf("response body: %s", rec.Body.String())
}
...@@ -40,6 +40,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( ...@@ -40,6 +40,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, fmt.Errorf("parse anthropic request: %w", err) return nil, fmt.Errorf("parse anthropic request: %w", err)
} }
originalModel := anthropicReq.Model originalModel := anthropicReq.Model
applyOpenAICompatModelNormalization(&anthropicReq)
clientStream := anthropicReq.Stream // client's original stream preference clientStream := anthropicReq.Stream // client's original stream preference
// 2. Convert Anthropic → Responses // 2. Convert Anthropic → Responses
......
...@@ -895,14 +895,16 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad ...@@ -895,14 +895,16 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
require.Equal(t, 1, userRepo.deductCalls) require.Equal(t, 1, userRepo.deductCalls)
} }
func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingUpstreamModelFallback(t *testing.T) { func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingRequestedModel(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{} userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{} subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1-codex", UsageTokens{ // Billing should use the requested model ("gpt-5.1"), not the upstream mapped model ("gpt-5.1-codex").
// This ensures pricing is always based on the model the user requested.
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{
InputTokens: 20, InputTokens: 20,
OutputTokens: 10, OutputTokens: 10,
}, 1.1) }, 1.1)
......
...@@ -4153,9 +4153,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -4153,9 +4153,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
} }
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if result.BillingModel != "" {
billingModel = strings.TrimSpace(result.BillingModel)
}
serviceTier := "" serviceTier := ""
if result.ServiceTier != nil { if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier) serviceTier = strings.TrimSpace(*result.ServiceTier)
......
...@@ -502,6 +502,25 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A ...@@ -502,6 +502,25 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
refreshToken := account.GetCredential("refresh_token") refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" { if refreshToken == "" {
accessToken := account.GetCredential("access_token")
if accessToken != "" {
tokenInfo := &OpenAITokenInfo{
AccessToken: accessToken,
RefreshToken: "",
IDToken: account.GetCredential("id_token"),
ClientID: account.GetCredential("client_id"),
Email: account.GetCredential("email"),
ChatGPTAccountID: account.GetCredential("chatgpt_account_id"),
ChatGPTUserID: account.GetCredential("chatgpt_user_id"),
OrganizationID: account.GetCredential("organization_id"),
PlanType: account.GetCredential("plan_type"),
}
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
tokenInfo.ExpiresAt = expiresAt.Unix()
tokenInfo.ExpiresIn = int64(time.Until(*expiresAt).Seconds())
}
return tokenInfo, nil
}
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available") return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
} }
......
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
// openaiOAuthClientRefreshStub is an OAuth client test double whose methods
// all fail with "not implemented". It counts refresh attempts so a test can
// assert that no refresh was made.
type openaiOAuthClientRefreshStub struct {
	refreshCalls int32 // number of refresh attempts; updated with sync/atomic
}

// ExchangeCode always fails; it is not counted as a refresh.
func (stub *openaiOAuthClientRefreshStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
	err := errors.New("not implemented")
	return nil, err
}

// RefreshToken counts the attempt and always fails.
func (stub *openaiOAuthClientRefreshStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
	atomic.AddInt32(&stub.refreshCalls, 1)
	err := errors.New("not implemented")
	return nil, err
}

// RefreshTokenWithClientID counts the attempt and always fails.
func (stub *openaiOAuthClientRefreshStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
	atomic.AddInt32(&stub.refreshCalls, 1)
	err := errors.New("not implemented")
	return nil, err
}
// TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccessToken
// verifies that an account holding an access token but no refresh token gets
// its stored credentials back from RefreshAccountToken, and that the OAuth
// client is never asked to refresh.
func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccessToken(t *testing.T) {
	const (
		storedAccessToken = "existing-access-token"
		storedClientID    = "client-id-1"
	)

	oauthClient := &openaiOAuthClientRefreshStub{}
	service := NewOpenAIOAuthService(nil, oauthClient)

	// No "refresh_token" credential on purpose.
	account := &Account{
		ID:       77,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": storedAccessToken,
			"expires_at":   time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
			"client_id":    storedClientID,
		},
	}

	tokenInfo, err := service.RefreshAccountToken(context.Background(), account)

	require.NoError(t, err)
	require.NotNil(t, tokenInfo)
	require.Equal(t, storedAccessToken, tokenInfo.AccessToken)
	require.Equal(t, storedClientID, tokenInfo.ClientID)

	refreshAttempts := atomic.LoadInt32(&oauthClient.refreshCalls)
	require.Zero(t, refreshAttempts, "existing access token should be reused without calling refresh")
}
...@@ -189,10 +189,38 @@ func (s *PricingService) checkAndUpdatePricing() error { ...@@ -189,10 +189,38 @@ func (s *PricingService) checkAndUpdatePricing() error {
return s.downloadPricingData() return s.downloadPricingData()
} }
// 检查文件是否过期 // 先加载本地文件(确保服务可用),再检查是否需要更新
if err := s.loadPricingData(pricingFile); err != nil {
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to load local file, downloading: %v", err)
return s.downloadPricingData()
}
// 如果配置了哈希URL,通过远程哈希检查是否有更新
if s.cfg.Pricing.HashURL != "" {
remoteHash, err := s.fetchRemoteHash()
if err != nil {
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash on startup: %v", err)
return nil // 已加载本地文件,哈希获取失败不影响启动
}
s.mu.RLock()
localHash := s.localHash
s.mu.RUnlock()
if localHash == "" || remoteHash != localHash {
logger.LegacyPrintf("service.pricing", "[Pricing] Remote hash differs on startup (local=%s remote=%s), downloading...",
localHash[:min(8, len(localHash))], remoteHash[:min(8, len(remoteHash))])
if err := s.downloadPricingData(); err != nil {
logger.LegacyPrintf("service.pricing", "[Pricing] Download failed, using existing file: %v", err)
}
}
return nil
}
// 没有哈希URL时,基于文件年龄检查
info, err := os.Stat(pricingFile) info, err := os.Stat(pricingFile)
if err != nil { if err != nil {
return s.downloadPricingData() return nil // 已加载本地文件
} }
fileAge := time.Since(info.ModTime()) fileAge := time.Since(info.ModTime())
...@@ -205,21 +233,11 @@ func (s *PricingService) checkAndUpdatePricing() error { ...@@ -205,21 +233,11 @@ func (s *PricingService) checkAndUpdatePricing() error {
} }
} }
// 加载本地文件 return nil
return s.loadPricingData(pricingFile)
} }
// syncWithRemote 与远程同步(基于哈希校验) // syncWithRemote 与远程同步(基于哈希校验)
func (s *PricingService) syncWithRemote() error { func (s *PricingService) syncWithRemote() error {
pricingFile := s.getPricingFilePath()
// 计算本地文件哈希
localHash, err := s.computeFileHash(pricingFile)
if err != nil {
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to compute local hash: %v", err)
return s.downloadPricingData()
}
// 如果配置了哈希URL,从远程获取哈希进行比对 // 如果配置了哈希URL,从远程获取哈希进行比对
if s.cfg.Pricing.HashURL != "" { if s.cfg.Pricing.HashURL != "" {
remoteHash, err := s.fetchRemoteHash() remoteHash, err := s.fetchRemoteHash()
...@@ -228,8 +246,13 @@ func (s *PricingService) syncWithRemote() error { ...@@ -228,8 +246,13 @@ func (s *PricingService) syncWithRemote() error {
return nil // 哈希获取失败不影响正常使用 return nil // 哈希获取失败不影响正常使用
} }
if remoteHash != localHash { s.mu.RLock()
logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Remote hash differs, downloading new version...") localHash := s.localHash
s.mu.RUnlock()
if localHash == "" || remoteHash != localHash {
logger.LegacyPrintf("service.pricing", "[Pricing] Remote hash differs (local=%s remote=%s), downloading new version...",
localHash[:min(8, len(localHash))], remoteHash[:min(8, len(remoteHash))])
return s.downloadPricingData() return s.downloadPricingData()
} }
logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Hash check passed, no update needed") logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Hash check passed, no update needed")
...@@ -237,6 +260,7 @@ func (s *PricingService) syncWithRemote() error { ...@@ -237,6 +260,7 @@ func (s *PricingService) syncWithRemote() error {
} }
// 没有哈希URL时,基于时间检查 // 没有哈希URL时,基于时间检查
pricingFile := s.getPricingFilePath()
info, err := os.Stat(pricingFile) info, err := os.Stat(pricingFile)
if err != nil { if err != nil {
return s.downloadPricingData() return s.downloadPricingData()
...@@ -264,11 +288,12 @@ func (s *PricingService) downloadPricingData() error { ...@@ -264,11 +288,12 @@ func (s *PricingService) downloadPricingData() error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
var expectedHash string // 获取远程哈希(用于同步锚点,不作为完整性校验)
var remoteHash string
if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" { if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" {
expectedHash, err = s.fetchRemoteHash() remoteHash, err = s.fetchRemoteHash()
if err != nil { if err != nil {
return fmt.Errorf("fetch remote hash: %w", err) logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash (continuing): %v", err)
} }
} }
...@@ -277,11 +302,13 @@ func (s *PricingService) downloadPricingData() error { ...@@ -277,11 +302,13 @@ func (s *PricingService) downloadPricingData() error {
return fmt.Errorf("download failed: %w", err) return fmt.Errorf("download failed: %w", err)
} }
if expectedHash != "" { // 哈希校验:不匹配时仅告警,不阻止更新
actualHash := sha256.Sum256(body) // 远程哈希文件可能与数据文件不同步(如维护者更新了数据但未更新哈希文件)
if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) { dataHash := sha256.Sum256(body)
return fmt.Errorf("pricing hash mismatch") dataHashStr := hex.EncodeToString(dataHash[:])
} if remoteHash != "" && !strings.EqualFold(remoteHash, dataHashStr) {
logger.LegacyPrintf("service.pricing", "[Pricing] Hash mismatch warning: remote=%s data=%s (hash file may be out of sync)",
remoteHash[:min(8, len(remoteHash))], dataHashStr[:8])
} }
// 解析JSON数据(使用灵活的解析方式) // 解析JSON数据(使用灵活的解析方式)
...@@ -296,11 +323,14 @@ func (s *PricingService) downloadPricingData() error { ...@@ -296,11 +323,14 @@ func (s *PricingService) downloadPricingData() error {
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save file: %v", err) logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save file: %v", err)
} }
// 保存哈希 // 使用远程哈希作为同步锚点,防止重复下载
hash := sha256.Sum256(body) // 当远程哈希不可用时,回退到数据本身的哈希
hashStr := hex.EncodeToString(hash[:]) syncHash := dataHashStr
if remoteHash != "" {
syncHash = remoteHash
}
hashFile := s.getHashFilePath() hashFile := s.getHashFilePath()
if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil { if err := os.WriteFile(hashFile, []byte(syncHash+"\n"), 0644); err != nil {
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save hash: %v", err) logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save hash: %v", err)
} }
...@@ -308,7 +338,7 @@ func (s *PricingService) downloadPricingData() error { ...@@ -308,7 +338,7 @@ func (s *PricingService) downloadPricingData() error {
s.mu.Lock() s.mu.Lock()
s.pricingData = data s.pricingData = data
s.lastUpdated = time.Now() s.lastUpdated = time.Now()
s.localHash = hashStr s.localHash = syncHash
s.mu.Unlock() s.mu.Unlock()
logger.LegacyPrintf("service.pricing", "[Pricing] Downloaded %d models successfully", len(data)) logger.LegacyPrintf("service.pricing", "[Pricing] Downloaded %d models successfully", len(data))
...@@ -486,16 +516,6 @@ func (s *PricingService) validatePricingURL(raw string) (string, error) { ...@@ -486,16 +516,6 @@ func (s *PricingService) validatePricingURL(raw string) (string, error) {
return normalized, nil return normalized, nil
} }
// computeFileHash 计算文件哈希
func (s *PricingService) computeFileHash(filePath string) (string, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return "", err
}
hash := sha256.Sum256(data)
return hex.EncodeToString(hash[:]), nil
}
// GetModelPricing 获取模型价格(带模糊匹配) // GetModelPricing 获取模型价格(带模糊匹配)
func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing { func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing {
s.mu.RLock() s.mu.RLock()
......
...@@ -32,8 +32,9 @@ type TokenRefreshService struct { ...@@ -32,8 +32,9 @@ type TokenRefreshService struct {
privacyClientFactory PrivacyClientFactory privacyClientFactory PrivacyClientFactory
proxyRepo ProxyRepository proxyRepo ProxyRepository
stopCh chan struct{} stopCh chan struct{}
wg sync.WaitGroup stopOnce sync.Once
wg sync.WaitGroup
} }
// NewTokenRefreshService 创建token刷新服务 // NewTokenRefreshService 创建token刷新服务
...@@ -130,7 +131,9 @@ func (s *TokenRefreshService) Start() { ...@@ -130,7 +131,9 @@ func (s *TokenRefreshService) Start() {
// Stop 停止刷新服务(可安全多次调用) // Stop 停止刷新服务(可安全多次调用)
func (s *TokenRefreshService) Stop() { func (s *TokenRefreshService) Stop() {
close(s.stopCh) s.stopOnce.Do(func() {
close(s.stopCh)
})
s.wg.Wait() s.wg.Wait()
slog.Info("token_refresh.service_stopped") slog.Info("token_refresh.service_stopped")
} }
...@@ -430,6 +433,7 @@ func isNonRetryableRefreshError(err error) bool { ...@@ -430,6 +433,7 @@ func isNonRetryableRefreshError(err error) bool {
"unauthorized_client", // 客户端未授权 "unauthorized_client", // 客户端未授权
"access_denied", // 访问被拒绝 "access_denied", // 访问被拒绝
"missing_project_id", // 缺少 project_id "missing_project_id", // 缺少 project_id
"no refresh token available",
} }
for _, needle := range nonRetryable { for _, needle := range nonRetryable {
if strings.Contains(msg, needle) { if strings.Contains(msg, needle) {
......
...@@ -19,6 +19,7 @@ type tokenRefreshAccountRepo struct { ...@@ -19,6 +19,7 @@ type tokenRefreshAccountRepo struct {
updateCredentialsCalls int updateCredentialsCalls int
setErrorCalls int setErrorCalls int
clearTempCalls int clearTempCalls int
setTempUnschedCalls int
lastAccount *Account lastAccount *Account
updateErr error updateErr error
} }
...@@ -58,6 +59,11 @@ func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id ...@@ -58,6 +59,11 @@ func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id
return nil return nil
} }
// SetTempUnschedulable is a test double: it only bumps the
// setTempUnschedCalls counter so tests can assert how often it was invoked.
// All arguments are ignored and the call always reports success.
func (r *tokenRefreshAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
	r.setTempUnschedCalls++
	return nil
}
type tokenCacheInvalidatorStub struct { type tokenCacheInvalidatorStub struct {
calls int calls int
err error err error
...@@ -490,6 +496,31 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t ...@@ -490,6 +496,31 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
} }
} }
// TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedule
// drives refreshWithRetry with a refresher stub that fails with
// "no refresh token available" and asserts that:
//   - the call returns an error,
//   - the repo's updateCalls and setTempUnschedCalls counters stay at zero,
//   - the repo's setErrorCalls counter is exactly one.
func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedule(t *testing.T) {
	accountRepo := &tokenRefreshAccountRepo{}

	// Two retries with zero backoff keep the test fast.
	retryCfg := config.TokenRefreshConfig{
		MaxRetries:          2,
		RetryBackoffSeconds: 0,
	}
	svc := NewTokenRefreshService(accountRepo, nil, nil, nil, nil, nil, nil, &config.Config{TokenRefresh: retryCfg}, nil)

	oauthAccount := &Account{ID: 18, Platform: PlatformOpenAI, Type: AccountTypeOAuth}

	// The same stub is passed for both refresher arguments of refreshWithRetry.
	failing := &tokenRefresherStub{err: errors.New("no refresh token available")}

	err := svc.refreshWithRetry(context.Background(), oauthAccount, failing, failing, time.Hour)

	require.Error(t, err)
	require.Equal(t, 0, accountRepo.updateCalls)
	require.Equal(t, 0, accountRepo.setTempUnschedCalls, "missing refresh token should not mark the account temp unschedulable")
	require.Equal(t, 1, accountRepo.setErrorCalls, "missing refresh token should be treated as a non-retryable credential state")
}
// TestIsNonRetryableRefreshError 测试不可重试错误判断 // TestIsNonRetryableRefreshError 测试不可重试错误判断
func TestIsNonRetryableRefreshError(t *testing.T) { func TestIsNonRetryableRefreshError(t *testing.T) {
tests := []struct { tests := []struct {
...@@ -503,6 +534,7 @@ func TestIsNonRetryableRefreshError(t *testing.T) { ...@@ -503,6 +534,7 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
{name: "invalid_client", err: errors.New("invalid_client"), expected: true}, {name: "invalid_client", err: errors.New("invalid_client"), expected: true},
{name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true}, {name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
{name: "access_denied", err: errors.New("access_denied"), expected: true}, {name: "access_denied", err: errors.New("access_denied"), expected: true},
{name: "no_refresh_token", err: errors.New("no refresh token available"), expected: true},
{name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true}, {name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true},
{name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true}, {name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true},
} }
......
...@@ -21,8 +21,8 @@ func optionalNonEqualStringPtr(value, compare string) *string { ...@@ -21,8 +21,8 @@ func optionalNonEqualStringPtr(value, compare string) *string {
} }
func forwardResultBillingModel(requestedModel, upstreamModel string) string { func forwardResultBillingModel(requestedModel, upstreamModel string) string {
if trimmedUpstream := strings.TrimSpace(upstreamModel); trimmedUpstream != "" { if trimmed := strings.TrimSpace(requestedModel); trimmed != "" {
return trimmedUpstream return trimmed
} }
return strings.TrimSpace(requestedModel) return strings.TrimSpace(upstreamModel)
} }
...@@ -865,10 +865,10 @@ rate_limit: ...@@ -865,10 +865,10 @@ rate_limit:
pricing: pricing:
# URL to fetch model pricing data (default: pinned model-price-repo commit) # URL to fetch model pricing data (default: pinned model-price-repo commit)
# 获取模型定价数据的 URL(默认:固定 commit 的 model-price-repo) # 获取模型定价数据的 URL(默认:固定 commit 的 model-price-repo)
remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json" remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/refs/heads/main//model_prices_and_context_window.json"
# Hash verification URL (optional) # Hash verification URL (optional)
# 哈希校验 URL(可选) # 哈希校验 URL(可选)
hash_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256" hash_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/refs/heads/main//model_prices_and_context_window.sha256"
# Local data directory for caching # Local data directory for caching
# 本地数据缓存目录 # 本地数据缓存目录
data_dir: "./data" data_dir: "./data"
......
...@@ -2245,6 +2245,41 @@ ...@@ -2245,6 +2245,41 @@
</p> </p>
</div> </div>
</div> </div>
<!-- Custom Base URL Relay -->
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.customBaseUrl.label') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.quotaControl.customBaseUrl.hint') }}
</p>
</div>
<button
type="button"
@click="customBaseUrlEnabled = !customBaseUrlEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
customBaseUrlEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
customBaseUrlEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
<div v-if="customBaseUrlEnabled" class="mt-3">
<input
v-model="customBaseUrl"
type="text"
class="input"
:placeholder="t('admin.accounts.quotaControl.customBaseUrl.urlHint')"
/>
</div>
</div>
</div> </div>
<div> <div>
...@@ -3095,6 +3130,8 @@ const tlsFingerprintProfiles = ref<{ id: number; name: string }[]>([]) ...@@ -3095,6 +3130,8 @@ const tlsFingerprintProfiles = ref<{ id: number; name: string }[]>([])
const sessionIdMaskingEnabled = ref(false) const sessionIdMaskingEnabled = ref(false)
const cacheTTLOverrideEnabled = ref(false) const cacheTTLOverrideEnabled = ref(false)
const cacheTTLOverrideTarget = ref<string>('5m') const cacheTTLOverrideTarget = ref<string>('5m')
const customBaseUrlEnabled = ref(false)
const customBaseUrl = ref('')
// Gemini tier selection (used as fallback when auto-detection is unavailable/fails) // Gemini tier selection (used as fallback when auto-detection is unavailable/fails)
const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free') const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free')
...@@ -3765,6 +3802,8 @@ const resetForm = () => { ...@@ -3765,6 +3802,8 @@ const resetForm = () => {
sessionIdMaskingEnabled.value = false sessionIdMaskingEnabled.value = false
cacheTTLOverrideEnabled.value = false cacheTTLOverrideEnabled.value = false
cacheTTLOverrideTarget.value = '5m' cacheTTLOverrideTarget.value = '5m'
customBaseUrlEnabled.value = false
customBaseUrl.value = ''
allowOverages.value = false allowOverages.value = false
antigravityAccountType.value = 'oauth' antigravityAccountType.value = 'oauth'
upstreamBaseUrl.value = '' upstreamBaseUrl.value = ''
...@@ -4856,6 +4895,12 @@ const handleAnthropicExchange = async (authCode: string) => { ...@@ -4856,6 +4895,12 @@ const handleAnthropicExchange = async (authCode: string) => {
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
} }
// Add custom base URL settings
if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) {
extra.custom_base_url_enabled = true
extra.custom_base_url = customBaseUrl.value.trim()
}
const credentials: Record<string, unknown> = { ...tokenInfo } const credentials: Record<string, unknown> = { ...tokenInfo }
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
await createAccountAndFinish(form.platform, addMethod.value as AccountType, credentials, extra) await createAccountAndFinish(form.platform, addMethod.value as AccountType, credentials, extra)
...@@ -4974,6 +5019,12 @@ const handleCookieAuth = async (sessionKey: string) => { ...@@ -4974,6 +5019,12 @@ const handleCookieAuth = async (sessionKey: string) => {
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
} }
// Add custom base URL settings
if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) {
extra.custom_base_url_enabled = true
extra.custom_base_url = customBaseUrl.value.trim()
}
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
const credentials: Record<string, unknown> = { ...tokenInfo } const credentials: Record<string, unknown> = { ...tokenInfo }
......
...@@ -1580,6 +1580,41 @@ ...@@ -1580,6 +1580,41 @@
</p> </p>
</div> </div>
</div> </div>
<!-- Custom Base URL Relay -->
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.customBaseUrl.label') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.quotaControl.customBaseUrl.hint') }}
</p>
</div>
<button
type="button"
@click="customBaseUrlEnabled = !customBaseUrlEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
customBaseUrlEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
customBaseUrlEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
<div v-if="customBaseUrlEnabled" class="mt-3">
<input
v-model="customBaseUrl"
type="text"
class="input"
:placeholder="t('admin.accounts.quotaControl.customBaseUrl.urlHint')"
/>
</div>
</div>
</div> </div>
<div class="border-t border-gray-200 pt-4 dark:border-dark-600"> <div class="border-t border-gray-200 pt-4 dark:border-dark-600">
...@@ -1854,6 +1889,8 @@ const tlsFingerprintProfiles = ref<{ id: number; name: string }[]>([]) ...@@ -1854,6 +1889,8 @@ const tlsFingerprintProfiles = ref<{ id: number; name: string }[]>([])
const sessionIdMaskingEnabled = ref(false) const sessionIdMaskingEnabled = ref(false)
const cacheTTLOverrideEnabled = ref(false) const cacheTTLOverrideEnabled = ref(false)
const cacheTTLOverrideTarget = ref<string>('5m') const cacheTTLOverrideTarget = ref<string>('5m')
const customBaseUrlEnabled = ref(false)
const customBaseUrl = ref('')
// OpenAI 自动透传开关(OAuth/API Key) // OpenAI 自动透传开关(OAuth/API Key)
const openaiPassthroughEnabled = ref(false) const openaiPassthroughEnabled = ref(false)
...@@ -2482,6 +2519,8 @@ function loadQuotaControlSettings(account: Account) { ...@@ -2482,6 +2519,8 @@ function loadQuotaControlSettings(account: Account) {
sessionIdMaskingEnabled.value = false sessionIdMaskingEnabled.value = false
cacheTTLOverrideEnabled.value = false cacheTTLOverrideEnabled.value = false
cacheTTLOverrideTarget.value = '5m' cacheTTLOverrideTarget.value = '5m'
customBaseUrlEnabled.value = false
customBaseUrl.value = ''
// Only applies to Anthropic OAuth/SetupToken accounts // Only applies to Anthropic OAuth/SetupToken accounts
if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) { if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) {
...@@ -2528,6 +2567,12 @@ function loadQuotaControlSettings(account: Account) { ...@@ -2528,6 +2567,12 @@ function loadQuotaControlSettings(account: Account) {
cacheTTLOverrideEnabled.value = true cacheTTLOverrideEnabled.value = true
cacheTTLOverrideTarget.value = account.cache_ttl_override_target || '5m' cacheTTLOverrideTarget.value = account.cache_ttl_override_target || '5m'
} }
// Load custom base URL setting
if (account.custom_base_url_enabled === true) {
customBaseUrlEnabled.value = true
customBaseUrl.value = account.custom_base_url || ''
}
} }
function formatTempUnschedKeywords(value: unknown) { function formatTempUnschedKeywords(value: unknown) {
...@@ -2980,6 +3025,15 @@ const handleSubmit = async () => { ...@@ -2980,6 +3025,15 @@ const handleSubmit = async () => {
delete newExtra.cache_ttl_override_target delete newExtra.cache_ttl_override_target
} }
// Custom base URL relay setting
if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) {
newExtra.custom_base_url_enabled = true
newExtra.custom_base_url = customBaseUrl.value.trim()
} else {
delete newExtra.custom_base_url_enabled
delete newExtra.custom_base_url
}
updatePayload.extra = newExtra updatePayload.extra = newExtra
} }
......
...@@ -64,7 +64,8 @@ const chartColors = computed(() => ({ ...@@ -64,7 +64,8 @@ const chartColors = computed(() => ({
input: '#3b82f6', input: '#3b82f6',
output: '#10b981', output: '#10b981',
cacheCreation: '#f59e0b', cacheCreation: '#f59e0b',
cacheRead: '#06b6d4' cacheRead: '#06b6d4',
cacheHitRate: '#8b5cf6'
})) }))
const chartData = computed(() => { const chartData = computed(() => {
...@@ -104,6 +105,19 @@ const chartData = computed(() => { ...@@ -104,6 +105,19 @@ const chartData = computed(() => {
backgroundColor: `${chartColors.value.cacheRead}20`, backgroundColor: `${chartColors.value.cacheRead}20`,
fill: true, fill: true,
tension: 0.3 tension: 0.3
},
{
label: 'Cache Hit Rate',
data: props.trendData.map((d) => {
const total = d.cache_read_tokens + d.cache_creation_tokens
return total > 0 ? (d.cache_read_tokens / total) * 100 : 0
}),
borderColor: chartColors.value.cacheHitRate,
backgroundColor: `${chartColors.value.cacheHitRate}20`,
borderDash: [5, 5],
fill: false,
tension: 0.3,
yAxisID: 'yPercent'
} }
] ]
} }
...@@ -132,6 +146,9 @@ const lineOptions = computed(() => ({ ...@@ -132,6 +146,9 @@ const lineOptions = computed(() => ({
tooltip: { tooltip: {
callbacks: { callbacks: {
label: (context: any) => { label: (context: any) => {
if (context.dataset.yAxisID === 'yPercent') {
return `${context.dataset.label}: ${context.raw.toFixed(1)}%`
}
return `${context.dataset.label}: ${formatTokens(context.raw)}` return `${context.dataset.label}: ${formatTokens(context.raw)}`
}, },
footer: (tooltipItems: any) => { footer: (tooltipItems: any) => {
...@@ -168,6 +185,21 @@ const lineOptions = computed(() => ({ ...@@ -168,6 +185,21 @@ const lineOptions = computed(() => ({
}, },
callback: (value: string | number) => formatTokens(Number(value)) callback: (value: string | number) => formatTokens(Number(value))
} }
},
yPercent: {
position: 'right' as const,
min: 0,
max: 100,
grid: {
drawOnChartArea: false
},
ticks: {
color: chartColors.value.cacheHitRate,
font: {
size: 10
},
callback: (value: string | number) => `${value}%`
}
} }
} }
})) }))
......
...@@ -2318,6 +2318,11 @@ export default { ...@@ -2318,6 +2318,11 @@ export default {
target: 'Target TTL', target: 'Target TTL',
targetHint: 'Select the TTL tier for billing' targetHint: 'Select the TTL tier for billing'
}, },
customBaseUrl: {
label: 'Custom Relay URL',
hint: 'Forward requests to a custom relay service. Proxy URL will be passed as a query parameter.',
urlHint: 'Relay service URL (e.g., https://relay.example.com)',
},
clientAffinity: { clientAffinity: {
label: 'Client Affinity Scheduling', label: 'Client Affinity Scheduling',
hint: 'When enabled, new sessions prefer accounts previously used by this client to reduce account switching' hint: 'When enabled, new sessions prefer accounts previously used by this client to reduce account switching'
...@@ -4378,6 +4383,7 @@ export default { ...@@ -4378,6 +4383,7 @@ export default {
provider: 'Type', provider: 'Type',
active: 'Active', active: 'Active',
endpoint: 'Endpoint', endpoint: 'Endpoint',
bucket: 'Bucket',
storagePath: 'Storage Path', storagePath: 'Storage Path',
capacityUsage: 'Capacity / Used', capacityUsage: 'Capacity / Used',
capacityUnlimited: 'Unlimited', capacityUnlimited: 'Unlimited',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment