Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
ca8692c7
Commit
ca8692c7
authored
Mar 31, 2026
by
InCerry
Browse files
Merge remote-tracking branch 'upstream/main'
# Conflicts: # backend/internal/service/openai_gateway_messages.go
parents
b6d46fd5
318aa5e0
Changes
43
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/antigravity_internal500_penalty_test.go
0 → 100644
View file @
ca8692c7
//go:build unit
package
service
import
(
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// --- mock: Internal500CounterCache ---
type
mockInternal500Cache
struct
{
incrementCount
int64
incrementErr
error
resetErr
error
incrementCalls
[]
int64
// 记录 IncrementInternal500Count 被调用时的 accountID
resetCalls
[]
int64
// 记录 ResetInternal500Count 被调用时的 accountID
}
func
(
m
*
mockInternal500Cache
)
IncrementInternal500Count
(
_
context
.
Context
,
accountID
int64
)
(
int64
,
error
)
{
m
.
incrementCalls
=
append
(
m
.
incrementCalls
,
accountID
)
return
m
.
incrementCount
,
m
.
incrementErr
}
func
(
m
*
mockInternal500Cache
)
ResetInternal500Count
(
_
context
.
Context
,
accountID
int64
)
error
{
m
.
resetCalls
=
append
(
m
.
resetCalls
,
accountID
)
return
m
.
resetErr
}
// --- mock: 专用于 internal500 惩罚测试的 AccountRepository ---
type
internal500AccountRepoStub
struct
{
AccountRepository
// 嵌入接口,未实现的方法会 panic(不应被调用)
tempUnschedCalls
[]
tempUnschedCall
setErrorCalls
[]
setErrorCall
}
type
tempUnschedCall
struct
{
accountID
int64
until
time
.
Time
reason
string
}
type
setErrorCall
struct
{
accountID
int64
reason
string
}
func
(
r
*
internal500AccountRepoStub
)
SetTempUnschedulable
(
_
context
.
Context
,
id
int64
,
until
time
.
Time
,
reason
string
)
error
{
r
.
tempUnschedCalls
=
append
(
r
.
tempUnschedCalls
,
tempUnschedCall
{
accountID
:
id
,
until
:
until
,
reason
:
reason
})
return
nil
}
func
(
r
*
internal500AccountRepoStub
)
SetError
(
_
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
r
.
setErrorCalls
=
append
(
r
.
setErrorCalls
,
setErrorCall
{
accountID
:
id
,
reason
:
errorMsg
})
return
nil
}
// =============================================================================
// TestIsAntigravityInternalServerError
// =============================================================================
func
TestIsAntigravityInternalServerError
(
t
*
testing
.
T
)
{
t
.
Run
(
"匹配完整的 INTERNAL 500 body"
,
func
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`
)
require
.
True
(
t
,
isAntigravityInternalServerError
(
500
,
body
))
})
t
.
Run
(
"statusCode 不是 500"
,
func
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`
)
require
.
False
(
t
,
isAntigravityInternalServerError
(
429
,
body
))
require
.
False
(
t
,
isAntigravityInternalServerError
(
503
,
body
))
require
.
False
(
t
,
isAntigravityInternalServerError
(
200
,
body
))
})
t
.
Run
(
"body 中 message 不匹配"
,
func
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"error":{"code":500,"message":"Some other error","status":"INTERNAL"}}`
)
require
.
False
(
t
,
isAntigravityInternalServerError
(
500
,
body
))
})
t
.
Run
(
"body 中 status 不匹配"
,
func
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"error":{"code":500,"message":"Internal error encountered.","status":"UNAVAILABLE"}}`
)
require
.
False
(
t
,
isAntigravityInternalServerError
(
500
,
body
))
})
t
.
Run
(
"body 中 code 不匹配"
,
func
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"error":{"code":503,"message":"Internal error encountered.","status":"INTERNAL"}}`
)
require
.
False
(
t
,
isAntigravityInternalServerError
(
500
,
body
))
})
t
.
Run
(
"空 body"
,
func
(
t
*
testing
.
T
)
{
require
.
False
(
t
,
isAntigravityInternalServerError
(
500
,
[]
byte
{}))
require
.
False
(
t
,
isAntigravityInternalServerError
(
500
,
nil
))
})
t
.
Run
(
"其他 500 错误格式(纯文本)"
,
func
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`Internal Server Error`
)
require
.
False
(
t
,
isAntigravityInternalServerError
(
500
,
body
))
})
t
.
Run
(
"其他 500 错误格式(不同 JSON 结构)"
,
func
(
t
*
testing
.
T
)
{
body
:=
[]
byte
(
`{"message":"Internal Server Error","statusCode":500}`
)
require
.
False
(
t
,
isAntigravityInternalServerError
(
500
,
body
))
})
}
// =============================================================================
// TestApplyInternal500Penalty
// =============================================================================
func
TestApplyInternal500Penalty
(
t
*
testing
.
T
)
{
t
.
Run
(
"count=1 → SetTempUnschedulable 10 分钟"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
internal500AccountRepoStub
{}
svc
:=
&
AntigravityGatewayService
{
accountRepo
:
repo
}
account
:=
&
Account
{
ID
:
1
,
Name
:
"acc-1"
}
before
:=
time
.
Now
()
svc
.
applyInternal500Penalty
(
context
.
Background
(),
"[test]"
,
account
,
1
)
after
:=
time
.
Now
()
require
.
Len
(
t
,
repo
.
tempUnschedCalls
,
1
)
require
.
Empty
(
t
,
repo
.
setErrorCalls
)
call
:=
repo
.
tempUnschedCalls
[
0
]
require
.
Equal
(
t
,
int64
(
1
),
call
.
accountID
)
require
.
Contains
(
t
,
call
.
reason
,
"INTERNAL 500"
)
// until 应在 [before+10m, after+10m] 范围内
require
.
True
(
t
,
call
.
until
.
After
(
before
.
Add
(
internal500PenaltyTier1Duration
)
.
Add
(
-
time
.
Second
)))
require
.
True
(
t
,
call
.
until
.
Before
(
after
.
Add
(
internal500PenaltyTier1Duration
)
.
Add
(
time
.
Second
)))
})
t
.
Run
(
"count=2 → SetTempUnschedulable 10 小时"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
internal500AccountRepoStub
{}
svc
:=
&
AntigravityGatewayService
{
accountRepo
:
repo
}
account
:=
&
Account
{
ID
:
2
,
Name
:
"acc-2"
}
before
:=
time
.
Now
()
svc
.
applyInternal500Penalty
(
context
.
Background
(),
"[test]"
,
account
,
2
)
after
:=
time
.
Now
()
require
.
Len
(
t
,
repo
.
tempUnschedCalls
,
1
)
require
.
Empty
(
t
,
repo
.
setErrorCalls
)
call
:=
repo
.
tempUnschedCalls
[
0
]
require
.
Equal
(
t
,
int64
(
2
),
call
.
accountID
)
require
.
Contains
(
t
,
call
.
reason
,
"INTERNAL 500"
)
require
.
True
(
t
,
call
.
until
.
After
(
before
.
Add
(
internal500PenaltyTier2Duration
)
.
Add
(
-
time
.
Second
)))
require
.
True
(
t
,
call
.
until
.
Before
(
after
.
Add
(
internal500PenaltyTier2Duration
)
.
Add
(
time
.
Second
)))
})
t
.
Run
(
"count=3 → SetError 永久禁用"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
internal500AccountRepoStub
{}
svc
:=
&
AntigravityGatewayService
{
accountRepo
:
repo
}
account
:=
&
Account
{
ID
:
3
,
Name
:
"acc-3"
}
svc
.
applyInternal500Penalty
(
context
.
Background
(),
"[test]"
,
account
,
3
)
require
.
Empty
(
t
,
repo
.
tempUnschedCalls
)
require
.
Len
(
t
,
repo
.
setErrorCalls
,
1
)
call
:=
repo
.
setErrorCalls
[
0
]
require
.
Equal
(
t
,
int64
(
3
),
call
.
accountID
)
require
.
Contains
(
t
,
call
.
reason
,
"INTERNAL 500 consecutive failures: 3"
)
})
t
.
Run
(
"count=5 → SetError 永久禁用(>=3 都走永久禁用)"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
internal500AccountRepoStub
{}
svc
:=
&
AntigravityGatewayService
{
accountRepo
:
repo
}
account
:=
&
Account
{
ID
:
5
,
Name
:
"acc-5"
}
svc
.
applyInternal500Penalty
(
context
.
Background
(),
"[test]"
,
account
,
5
)
require
.
Empty
(
t
,
repo
.
tempUnschedCalls
)
require
.
Len
(
t
,
repo
.
setErrorCalls
,
1
)
call
:=
repo
.
setErrorCalls
[
0
]
require
.
Equal
(
t
,
int64
(
5
),
call
.
accountID
)
require
.
Contains
(
t
,
call
.
reason
,
"INTERNAL 500 consecutive failures: 5"
)
})
t
.
Run
(
"count=0 → 不调用任何方法"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
internal500AccountRepoStub
{}
svc
:=
&
AntigravityGatewayService
{
accountRepo
:
repo
}
account
:=
&
Account
{
ID
:
10
,
Name
:
"acc-10"
}
svc
.
applyInternal500Penalty
(
context
.
Background
(),
"[test]"
,
account
,
0
)
require
.
Empty
(
t
,
repo
.
tempUnschedCalls
)
require
.
Empty
(
t
,
repo
.
setErrorCalls
)
})
}
// =============================================================================
// TestHandleInternal500RetryExhausted
// =============================================================================
func
TestHandleInternal500RetryExhausted
(
t
*
testing
.
T
)
{
t
.
Run
(
"internal500Cache 为 nil → 不 panic,不调用任何方法"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
internal500AccountRepoStub
{}
svc
:=
&
AntigravityGatewayService
{
accountRepo
:
repo
,
internal500Cache
:
nil
,
}
account
:=
&
Account
{
ID
:
1
,
Name
:
"acc-1"
}
// 不应 panic
require
.
NotPanics
(
t
,
func
()
{
svc
.
handleInternal500RetryExhausted
(
context
.
Background
(),
"[test]"
,
account
)
})
require
.
Empty
(
t
,
repo
.
tempUnschedCalls
)
require
.
Empty
(
t
,
repo
.
setErrorCalls
)
})
t
.
Run
(
"IncrementInternal500Count 返回 error → 不调用惩罚方法"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
internal500AccountRepoStub
{}
cache
:=
&
mockInternal500Cache
{
incrementErr
:
errors
.
New
(
"redis connection error"
),
}
svc
:=
&
AntigravityGatewayService
{
accountRepo
:
repo
,
internal500Cache
:
cache
,
}
account
:=
&
Account
{
ID
:
2
,
Name
:
"acc-2"
}
svc
.
handleInternal500RetryExhausted
(
context
.
Background
(),
"[test]"
,
account
)
require
.
Len
(
t
,
cache
.
incrementCalls
,
1
)
require
.
Equal
(
t
,
int64
(
2
),
cache
.
incrementCalls
[
0
])
require
.
Empty
(
t
,
repo
.
tempUnschedCalls
)
require
.
Empty
(
t
,
repo
.
setErrorCalls
)
})
t
.
Run
(
"IncrementInternal500Count 返回 count=1 → 触发 tier1 惩罚"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
internal500AccountRepoStub
{}
cache
:=
&
mockInternal500Cache
{
incrementCount
:
1
,
}
svc
:=
&
AntigravityGatewayService
{
accountRepo
:
repo
,
internal500Cache
:
cache
,
}
account
:=
&
Account
{
ID
:
3
,
Name
:
"acc-3"
}
svc
.
handleInternal500RetryExhausted
(
context
.
Background
(),
"[test]"
,
account
)
require
.
Len
(
t
,
cache
.
incrementCalls
,
1
)
require
.
Equal
(
t
,
int64
(
3
),
cache
.
incrementCalls
[
0
])
// tier1: SetTempUnschedulable
require
.
Len
(
t
,
repo
.
tempUnschedCalls
,
1
)
require
.
Equal
(
t
,
int64
(
3
),
repo
.
tempUnschedCalls
[
0
]
.
accountID
)
require
.
Empty
(
t
,
repo
.
setErrorCalls
)
})
t
.
Run
(
"IncrementInternal500Count 返回 count=3 → 触发 tier3 永久禁用"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
internal500AccountRepoStub
{}
cache
:=
&
mockInternal500Cache
{
incrementCount
:
3
,
}
svc
:=
&
AntigravityGatewayService
{
accountRepo
:
repo
,
internal500Cache
:
cache
,
}
account
:=
&
Account
{
ID
:
4
,
Name
:
"acc-4"
}
svc
.
handleInternal500RetryExhausted
(
context
.
Background
(),
"[test]"
,
account
)
require
.
Len
(
t
,
cache
.
incrementCalls
,
1
)
require
.
Empty
(
t
,
repo
.
tempUnschedCalls
)
require
.
Len
(
t
,
repo
.
setErrorCalls
,
1
)
require
.
Equal
(
t
,
int64
(
4
),
repo
.
setErrorCalls
[
0
]
.
accountID
)
})
}
// =============================================================================
// TestResetInternal500Counter
// =============================================================================
func
TestResetInternal500Counter
(
t
*
testing
.
T
)
{
t
.
Run
(
"internal500Cache 为 nil → 不 panic"
,
func
(
t
*
testing
.
T
)
{
svc
:=
&
AntigravityGatewayService
{
internal500Cache
:
nil
,
}
require
.
NotPanics
(
t
,
func
()
{
svc
.
resetInternal500Counter
(
context
.
Background
(),
"[test]"
,
1
)
})
})
t
.
Run
(
"ResetInternal500Count 返回 error → 不 panic(仅日志)"
,
func
(
t
*
testing
.
T
)
{
cache
:=
&
mockInternal500Cache
{
resetErr
:
errors
.
New
(
"redis timeout"
),
}
svc
:=
&
AntigravityGatewayService
{
internal500Cache
:
cache
,
}
require
.
NotPanics
(
t
,
func
()
{
svc
.
resetInternal500Counter
(
context
.
Background
(),
"[test]"
,
42
)
})
require
.
Len
(
t
,
cache
.
resetCalls
,
1
)
require
.
Equal
(
t
,
int64
(
42
),
cache
.
resetCalls
[
0
])
})
t
.
Run
(
"正常调用 → 调用 ResetInternal500Count"
,
func
(
t
*
testing
.
T
)
{
cache
:=
&
mockInternal500Cache
{}
svc
:=
&
AntigravityGatewayService
{
internal500Cache
:
cache
,
}
svc
.
resetInternal500Counter
(
context
.
Background
(),
"[test]"
,
99
)
require
.
Len
(
t
,
cache
.
resetCalls
,
1
)
require
.
Equal
(
t
,
int64
(
99
),
cache
.
resetCalls
[
0
])
})
}
backend/internal/service/gateway_service.go
View file @
ca8692c7
...
...
@@ -12,6 +12,7 @@ import (
"log/slog"
mathrand
"math/rand"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
...
...
@@ -368,6 +369,8 @@ var allowedHeaders = map[string]bool{
"user-agent"
:
true
,
"content-type"
:
true
,
"accept-encoding"
:
true
,
"x-claude-code-session-id"
:
true
,
"x-client-request-id"
:
true
,
}
// GatewayCache 定义网关服务的缓存操作接口。
...
...
@@ -4150,10 +4153,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return
nil
,
err
}
// 获取代理URL
// 获取代理URL
(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
if
!
account
.
IsCustomBaseURLEnabled
()
||
account
.
GetCustomBaseURL
()
==
""
{
proxyURL
=
account
.
Proxy
.
URL
()
}
}
// 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析)
...
...
@@ -5628,6 +5633,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
targetURL
=
validatedURL
+
"/v1/messages?beta=true"
}
}
else
if
account
.
IsCustomBaseURLEnabled
()
{
customURL
:=
account
.
GetCustomBaseURL
()
if
customURL
==
""
{
return
nil
,
fmt
.
Errorf
(
"custom_base_url is enabled but not configured for account %d"
,
account
.
ID
)
}
validatedURL
,
err
:=
s
.
validateUpstreamBaseURL
(
customURL
)
if
err
!=
nil
{
return
nil
,
err
}
targetURL
=
s
.
buildCustomRelayURL
(
validatedURL
,
"/v1/messages"
,
account
)
}
clientHeaders
:=
http
.
Header
{}
...
...
@@ -5743,6 +5758,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
if
sessionHeader
:=
getHeaderRaw
(
req
.
Header
,
"X-Claude-Code-Session-Id"
);
sessionHeader
!=
""
{
if
uid
:=
gjson
.
GetBytes
(
body
,
"metadata.user_id"
)
.
String
();
uid
!=
""
{
if
parsed
:=
ParseMetadataUserID
(
uid
);
parsed
!=
nil
{
setHeaderRaw
(
req
.
Header
,
"X-Claude-Code-Session-Id"
,
parsed
.
SessionID
)
}
}
}
// === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
s
.
debugLogGatewaySnapshot
(
"UPSTREAM_FORWARD"
,
req
.
Header
,
body
,
map
[
string
]
string
{
"url"
:
req
.
URL
.
String
(),
...
...
@@ -8063,10 +8087,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
return
err
}
// 获取代理URL
// 获取代理URL
(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
if
!
account
.
IsCustomBaseURLEnabled
()
||
account
.
GetCustomBaseURL
()
==
""
{
proxyURL
=
account
.
Proxy
.
URL
()
}
}
// 发送请求
...
...
@@ -8345,6 +8371,16 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
targetURL
=
validatedURL
+
"/v1/messages/count_tokens?beta=true"
}
}
else
if
account
.
IsCustomBaseURLEnabled
()
{
customURL
:=
account
.
GetCustomBaseURL
()
if
customURL
==
""
{
return
nil
,
fmt
.
Errorf
(
"custom_base_url is enabled but not configured for account %d"
,
account
.
ID
)
}
validatedURL
,
err
:=
s
.
validateUpstreamBaseURL
(
customURL
)
if
err
!=
nil
{
return
nil
,
err
}
targetURL
=
s
.
buildCustomRelayURL
(
validatedURL
,
"/v1/messages/count_tokens"
,
account
)
}
clientHeaders
:=
http
.
Header
{}
...
...
@@ -8450,6 +8486,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
if
sessionHeader
:=
getHeaderRaw
(
req
.
Header
,
"X-Claude-Code-Session-Id"
);
sessionHeader
!=
""
{
if
uid
:=
gjson
.
GetBytes
(
body
,
"metadata.user_id"
)
.
String
();
uid
!=
""
{
if
parsed
:=
ParseMetadataUserID
(
uid
);
parsed
!=
nil
{
setHeaderRaw
(
req
.
Header
,
"X-Claude-Code-Session-Id"
,
parsed
.
SessionID
)
}
}
}
if
c
!=
nil
&&
tokenType
==
"oauth"
{
c
.
Set
(
claudeMimicDebugInfoKey
,
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
))
}
...
...
@@ -8471,6 +8516,19 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
})
}
// buildCustomRelayURL 构建自定义中继转发 URL
// 在 path 后附加 beta=true 和可选的 proxy 查询参数
func
(
s
*
GatewayService
)
buildCustomRelayURL
(
baseURL
,
path
string
,
account
*
Account
)
string
{
u
:=
strings
.
TrimRight
(
baseURL
,
"/"
)
+
path
+
"?beta=true"
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
:=
account
.
Proxy
.
URL
()
if
proxyURL
!=
""
{
u
+=
"&proxy="
+
url
.
QueryEscape
(
proxyURL
)
}
}
return
u
}
func
(
s
*
GatewayService
)
validateUpstreamBaseURL
(
raw
string
)
(
string
,
error
)
{
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
{
normalized
,
err
:=
urlvalidator
.
ValidateURLFormat
(
raw
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
...
...
backend/internal/service/header_util.go
View file @
ca8692c7
...
...
@@ -36,6 +36,11 @@ var headerWireCasing = map[string]string{
"sec-fetch-mode"
:
"sec-fetch-mode"
,
"accept-encoding"
:
"accept-encoding"
,
"authorization"
:
"authorization"
,
// Claude Code 2.1.87+ 新增 header
"x-claude-code-session-id"
:
"X-Claude-Code-Session-Id"
,
"x-client-request-id"
:
"x-client-request-id"
,
"content-length"
:
"content-length"
,
}
// headerWireOrder 定义真实 Claude CLI 发送 header 的顺序(基于抓包)。
...
...
@@ -55,11 +60,14 @@ var headerWireOrder = []string{
"authorization"
,
"x-app"
,
"User-Agent"
,
"X-Claude-Code-Session-Id"
,
"content-type"
,
"anthropic-beta"
,
"x-client-request-id"
,
"accept-language"
,
"sec-fetch-mode"
,
"accept-encoding"
,
"content-length"
,
"x-stainless-helper-method"
,
}
...
...
backend/internal/service/internal500_counter.go
0 → 100644
View file @
ca8692c7
package
service
import
"context"
// Internal500CounterCache 追踪 Antigravity 账号连续 INTERNAL 500 失败轮数
type
Internal500CounterCache
interface
{
// IncrementInternal500Count 原子递增计数并返回当前值
IncrementInternal500Count
(
ctx
context
.
Context
,
accountID
int64
)
(
int64
,
error
)
// ResetInternal500Count 清零计数器(成功响应时调用)
ResetInternal500Count
(
ctx
context
.
Context
,
accountID
int64
)
error
}
backend/internal/service/openai_compat_model.go
0 → 100644
View file @
ca8692c7
package
service
import
(
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
)
func
NormalizeOpenAICompatRequestedModel
(
model
string
)
string
{
trimmed
:=
strings
.
TrimSpace
(
model
)
if
trimmed
==
""
{
return
""
}
normalized
,
_
,
ok
:=
splitOpenAICompatReasoningModel
(
trimmed
)
if
!
ok
||
normalized
==
""
{
return
trimmed
}
return
normalized
}
func
applyOpenAICompatModelNormalization
(
req
*
apicompat
.
AnthropicRequest
)
{
if
req
==
nil
{
return
}
originalModel
:=
strings
.
TrimSpace
(
req
.
Model
)
if
originalModel
==
""
{
return
}
normalizedModel
,
derivedEffort
,
hasReasoningSuffix
:=
splitOpenAICompatReasoningModel
(
originalModel
)
if
hasReasoningSuffix
&&
normalizedModel
!=
""
{
req
.
Model
=
normalizedModel
}
if
req
.
OutputConfig
!=
nil
&&
strings
.
TrimSpace
(
req
.
OutputConfig
.
Effort
)
!=
""
{
return
}
claudeEffort
:=
openAIReasoningEffortToClaudeOutputEffort
(
derivedEffort
)
if
claudeEffort
==
""
{
return
}
if
req
.
OutputConfig
==
nil
{
req
.
OutputConfig
=
&
apicompat
.
AnthropicOutputConfig
{}
}
req
.
OutputConfig
.
Effort
=
claudeEffort
}
func
splitOpenAICompatReasoningModel
(
model
string
)
(
normalizedModel
string
,
reasoningEffort
string
,
ok
bool
)
{
trimmed
:=
strings
.
TrimSpace
(
model
)
if
trimmed
==
""
{
return
""
,
""
,
false
}
modelID
:=
trimmed
if
strings
.
Contains
(
modelID
,
"/"
)
{
parts
:=
strings
.
Split
(
modelID
,
"/"
)
modelID
=
parts
[
len
(
parts
)
-
1
]
}
modelID
=
strings
.
TrimSpace
(
modelID
)
if
!
strings
.
HasPrefix
(
strings
.
ToLower
(
modelID
),
"gpt-"
)
{
return
trimmed
,
""
,
false
}
parts
:=
strings
.
FieldsFunc
(
strings
.
ToLower
(
modelID
),
func
(
r
rune
)
bool
{
switch
r
{
case
'-'
,
'_'
,
' '
:
return
true
default
:
return
false
}
})
if
len
(
parts
)
==
0
{
return
trimmed
,
""
,
false
}
last
:=
strings
.
NewReplacer
(
"-"
,
""
,
"_"
,
""
,
" "
,
""
)
.
Replace
(
parts
[
len
(
parts
)
-
1
])
switch
last
{
case
"none"
,
"minimal"
:
case
"low"
,
"medium"
,
"high"
:
reasoningEffort
=
last
case
"xhigh"
,
"extrahigh"
:
reasoningEffort
=
"xhigh"
default
:
return
trimmed
,
""
,
false
}
return
normalizeCodexModel
(
modelID
),
reasoningEffort
,
true
}
func
openAIReasoningEffortToClaudeOutputEffort
(
effort
string
)
string
{
switch
strings
.
TrimSpace
(
effort
)
{
case
"low"
,
"medium"
,
"high"
:
return
effort
case
"xhigh"
:
return
"max"
default
:
return
""
}
}
backend/internal/service/openai_compat_model_test.go
0 → 100644
View file @
ca8692c7
package
service
import
(
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func
TestNormalizeOpenAICompatRequestedModel
(
t
*
testing
.
T
)
{
t
.
Parallel
()
tests
:=
[]
struct
{
name
string
input
string
want
string
}{
{
name
:
"gpt reasoning alias strips xhigh"
,
input
:
"gpt-5.4-xhigh"
,
want
:
"gpt-5.4"
},
{
name
:
"gpt reasoning alias strips none"
,
input
:
"gpt-5.4-none"
,
want
:
"gpt-5.4"
},
{
name
:
"codex max model stays intact"
,
input
:
"gpt-5.1-codex-max"
,
want
:
"gpt-5.1-codex-max"
},
{
name
:
"non openai model unchanged"
,
input
:
"claude-opus-4-6"
,
want
:
"claude-opus-4-6"
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
tt
.
want
,
NormalizeOpenAICompatRequestedModel
(
tt
.
input
))
})
}
}
func
TestApplyOpenAICompatModelNormalization
(
t
*
testing
.
T
)
{
t
.
Parallel
()
t
.
Run
(
"derives xhigh from model suffix when output config missing"
,
func
(
t
*
testing
.
T
)
{
req
:=
&
apicompat
.
AnthropicRequest
{
Model
:
"gpt-5.4-xhigh"
}
applyOpenAICompatModelNormalization
(
req
)
require
.
Equal
(
t
,
"gpt-5.4"
,
req
.
Model
)
require
.
NotNil
(
t
,
req
.
OutputConfig
)
require
.
Equal
(
t
,
"max"
,
req
.
OutputConfig
.
Effort
)
})
t
.
Run
(
"explicit output config wins over model suffix"
,
func
(
t
*
testing
.
T
)
{
req
:=
&
apicompat
.
AnthropicRequest
{
Model
:
"gpt-5.4-xhigh"
,
OutputConfig
:
&
apicompat
.
AnthropicOutputConfig
{
Effort
:
"low"
},
}
applyOpenAICompatModelNormalization
(
req
)
require
.
Equal
(
t
,
"gpt-5.4"
,
req
.
Model
)
require
.
NotNil
(
t
,
req
.
OutputConfig
)
require
.
Equal
(
t
,
"low"
,
req
.
OutputConfig
.
Effort
)
})
t
.
Run
(
"non openai model is untouched"
,
func
(
t
*
testing
.
T
)
{
req
:=
&
apicompat
.
AnthropicRequest
{
Model
:
"claude-opus-4-6"
}
applyOpenAICompatModelNormalization
(
req
)
require
.
Equal
(
t
,
"claude-opus-4-6"
,
req
.
Model
)
require
.
Nil
(
t
,
req
.
OutputConfig
)
})
}
func
TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh
(
t
*
testing
.
T
)
{
t
.
Parallel
()
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
body
:=
[]
byte
(
`{"model":"gpt-5.4-xhigh","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
bytes
.
NewReader
(
body
))
c
.
Request
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
upstreamBody
:=
strings
.
Join
([]
string
{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`
,
""
,
"data: [DONE]"
,
""
,
},
"
\n
"
)
upstream
:=
&
httpUpstreamRecorder
{
resp
:
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{
"Content-Type"
:
[]
string
{
"text/event-stream"
},
"x-request-id"
:
[]
string
{
"rid_compat"
}},
Body
:
io
.
NopCloser
(
strings
.
NewReader
(
upstreamBody
)),
}}
svc
:=
&
OpenAIGatewayService
{
httpUpstream
:
upstream
}
account
:=
&
Account
{
ID
:
1
,
Name
:
"openai-oauth"
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Concurrency
:
1
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"oauth-token"
,
"chatgpt_account_id"
:
"chatgpt-acc"
,
"model_mapping"
:
map
[
string
]
any
{
"gpt-5.4"
:
"gpt-5.4"
,
},
},
}
result
,
err
:=
svc
.
ForwardAsAnthropic
(
context
.
Background
(),
c
,
account
,
body
,
""
,
"gpt-5.1"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"gpt-5.4-xhigh"
,
result
.
Model
)
require
.
Equal
(
t
,
"gpt-5.4"
,
result
.
UpstreamModel
)
require
.
Equal
(
t
,
"gpt-5.4"
,
result
.
BillingModel
)
require
.
NotNil
(
t
,
result
.
ReasoningEffort
)
require
.
Equal
(
t
,
"xhigh"
,
*
result
.
ReasoningEffort
)
require
.
Equal
(
t
,
"gpt-5.4"
,
gjson
.
GetBytes
(
upstream
.
lastBody
,
"model"
)
.
String
())
require
.
Equal
(
t
,
"xhigh"
,
gjson
.
GetBytes
(
upstream
.
lastBody
,
"reasoning.effort"
)
.
String
())
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
"gpt-5.4-xhigh"
,
gjson
.
GetBytes
(
rec
.
Body
.
Bytes
(),
"model"
)
.
String
())
require
.
Equal
(
t
,
"ok"
,
gjson
.
GetBytes
(
rec
.
Body
.
Bytes
(),
"content.0.text"
)
.
String
())
t
.
Logf
(
"upstream body: %s"
,
string
(
upstream
.
lastBody
))
t
.
Logf
(
"response body: %s"
,
rec
.
Body
.
String
())
}
backend/internal/service/openai_gateway_messages.go
View file @
ca8692c7
...
...
@@ -40,6 +40,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return
nil
,
fmt
.
Errorf
(
"parse anthropic request: %w"
,
err
)
}
originalModel
:=
anthropicReq
.
Model
applyOpenAICompatModelNormalization
(
&
anthropicReq
)
clientStream
:=
anthropicReq
.
Stream
// client's original stream preference
// 2. Convert Anthropic → Responses
...
...
backend/internal/service/openai_gateway_record_usage_test.go
View file @
ca8692c7
...
...
@@ -895,14 +895,16 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
require
.
Equal
(
t
,
1
,
userRepo
.
deductCalls
)
}
func
TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsing
UpstreamModelFallback
(
t
*
testing
.
T
)
{
func
TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsing
RequestedModel
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageLogRepoStub
{
inserted
:
true
}
userRepo
:=
&
openAIRecordUsageUserRepoStub
{}
subRepo
:=
&
openAIRecordUsageSubRepoStub
{}
svc
:=
newOpenAIRecordUsageServiceForTest
(
usageRepo
,
userRepo
,
subRepo
,
nil
)
usage
:=
OpenAIUsage
{
InputTokens
:
20
,
OutputTokens
:
10
}
expectedCost
,
err
:=
svc
.
billingService
.
CalculateCost
(
"gpt-5.1-codex"
,
UsageTokens
{
// Billing should use the requested model ("gpt-5.1"), not the upstream mapped model ("gpt-5.1-codex").
// This ensures pricing is always based on the model the user requested.
expectedCost
,
err
:=
svc
.
billingService
.
CalculateCost
(
"gpt-5.1"
,
UsageTokens
{
InputTokens
:
20
,
OutputTokens
:
10
,
},
1.1
)
...
...
backend/internal/service/openai_gateway_service.go
View file @
ca8692c7
...
...
@@ -4153,9 +4153,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
}
billingModel
:=
forwardResultBillingModel
(
result
.
Model
,
result
.
UpstreamModel
)
if
result
.
BillingModel
!=
""
{
billingModel
=
strings
.
TrimSpace
(
result
.
BillingModel
)
}
serviceTier
:=
""
if
result
.
ServiceTier
!=
nil
{
serviceTier
=
strings
.
TrimSpace
(
*
result
.
ServiceTier
)
...
...
backend/internal/service/openai_oauth_service.go
View file @
ca8692c7
...
...
@@ -502,6 +502,25 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
refreshToken
:=
account
.
GetCredential
(
"refresh_token"
)
if
refreshToken
==
""
{
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
accessToken
!=
""
{
tokenInfo
:=
&
OpenAITokenInfo
{
AccessToken
:
accessToken
,
RefreshToken
:
""
,
IDToken
:
account
.
GetCredential
(
"id_token"
),
ClientID
:
account
.
GetCredential
(
"client_id"
),
Email
:
account
.
GetCredential
(
"email"
),
ChatGPTAccountID
:
account
.
GetCredential
(
"chatgpt_account_id"
),
ChatGPTUserID
:
account
.
GetCredential
(
"chatgpt_user_id"
),
OrganizationID
:
account
.
GetCredential
(
"organization_id"
),
PlanType
:
account
.
GetCredential
(
"plan_type"
),
}
if
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
);
expiresAt
!=
nil
{
tokenInfo
.
ExpiresAt
=
expiresAt
.
Unix
()
tokenInfo
.
ExpiresIn
=
int64
(
time
.
Until
(
*
expiresAt
)
.
Seconds
())
}
return
tokenInfo
,
nil
}
return
nil
,
infraerrors
.
New
(
http
.
StatusBadRequest
,
"OPENAI_OAUTH_NO_REFRESH_TOKEN"
,
"no refresh token available"
)
}
...
...
backend/internal/service/openai_oauth_service_refresh_test.go
0 → 100644
View file @
ca8692c7
package
service
import
(
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
type
openaiOAuthClientRefreshStub
struct
{
refreshCalls
int32
}
func
(
s
*
openaiOAuthClientRefreshStub
)
ExchangeCode
(
ctx
context
.
Context
,
code
,
codeVerifier
,
redirectURI
,
proxyURL
,
clientID
string
)
(
*
openai
.
TokenResponse
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
openaiOAuthClientRefreshStub
)
RefreshToken
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
)
(
*
openai
.
TokenResponse
,
error
)
{
atomic
.
AddInt32
(
&
s
.
refreshCalls
,
1
)
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
s
*
openaiOAuthClientRefreshStub
)
RefreshTokenWithClientID
(
ctx
context
.
Context
,
refreshToken
,
proxyURL
string
,
clientID
string
)
(
*
openai
.
TokenResponse
,
error
)
{
atomic
.
AddInt32
(
&
s
.
refreshCalls
,
1
)
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccessToken
(
t
*
testing
.
T
)
{
client
:=
&
openaiOAuthClientRefreshStub
{}
svc
:=
NewOpenAIOAuthService
(
nil
,
client
)
expiresAt
:=
time
.
Now
()
.
Add
(
30
*
time
.
Minute
)
.
UTC
()
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
77
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"existing-access-token"
,
"expires_at"
:
expiresAt
,
"client_id"
:
"client-id-1"
,
},
}
info
,
err
:=
svc
.
RefreshAccountToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
info
)
require
.
Equal
(
t
,
"existing-access-token"
,
info
.
AccessToken
)
require
.
Equal
(
t
,
"client-id-1"
,
info
.
ClientID
)
require
.
Zero
(
t
,
atomic
.
LoadInt32
(
&
client
.
refreshCalls
),
"existing access token should be reused without calling refresh"
)
}
backend/internal/service/pricing_service.go
View file @
ca8692c7
...
...
@@ -189,10 +189,38 @@ func (s *PricingService) checkAndUpdatePricing() error {
return
s
.
downloadPricingData
()
}
// 检查文件是否过期
// 先加载本地文件(确保服务可用),再检查是否需要更新
if
err
:=
s
.
loadPricingData
(
pricingFile
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.pricing"
,
"[Pricing] Failed to load local file, downloading: %v"
,
err
)
return
s
.
downloadPricingData
()
}
// 如果配置了哈希URL,通过远程哈希检查是否有更新
if
s
.
cfg
.
Pricing
.
HashURL
!=
""
{
remoteHash
,
err
:=
s
.
fetchRemoteHash
()
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.pricing"
,
"[Pricing] Failed to fetch remote hash on startup: %v"
,
err
)
return
nil
// 已加载本地文件,哈希获取失败不影响启动
}
s
.
mu
.
RLock
()
localHash
:=
s
.
localHash
s
.
mu
.
RUnlock
()
if
localHash
==
""
||
remoteHash
!=
localHash
{
logger
.
LegacyPrintf
(
"service.pricing"
,
"[Pricing] Remote hash differs on startup (local=%s remote=%s), downloading..."
,
localHash
[
:
min
(
8
,
len
(
localHash
))],
remoteHash
[
:
min
(
8
,
len
(
remoteHash
))])
if
err
:=
s
.
downloadPricingData
();
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.pricing"
,
"[Pricing] Download failed, using existing file: %v"
,
err
)
}
}
return
nil
}
// 没有哈希URL时,基于文件年龄检查
info
,
err
:=
os
.
Stat
(
pricingFile
)
if
err
!=
nil
{
return
s
.
downloadPricingData
()
return
nil
// 已加载本地文件
}
fileAge
:=
time
.
Since
(
info
.
ModTime
())
...
...
@@ -205,21 +233,11 @@ func (s *PricingService) checkAndUpdatePricing() error {
}
}
// 加载本地文件
return
s
.
loadPricingData
(
pricingFile
)
return
nil
}
// syncWithRemote 与远程同步(基于哈希校验)
func
(
s
*
PricingService
)
syncWithRemote
()
error
{
pricingFile
:=
s
.
getPricingFilePath
()
// 计算本地文件哈希
localHash
,
err
:=
s
.
computeFileHash
(
pricingFile
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.pricing"
,
"[Pricing] Failed to compute local hash: %v"
,
err
)
return
s
.
downloadPricingData
()
}
// 如果配置了哈希URL,从远程获取哈希进行比对
if
s
.
cfg
.
Pricing
.
HashURL
!=
""
{
remoteHash
,
err
:=
s
.
fetchRemoteHash
()
...
...
@@ -228,8 +246,13 @@ func (s *PricingService) syncWithRemote() error {
return
nil
// 哈希获取失败不影响正常使用
}
if
remoteHash
!=
localHash
{
logger
.
LegacyPrintf
(
"service.pricing"
,
"%s"
,
"[Pricing] Remote hash differs, downloading new version..."
)
s
.
mu
.
RLock
()
localHash
:=
s
.
localHash
s
.
mu
.
RUnlock
()
if
localHash
==
""
||
remoteHash
!=
localHash
{
logger
.
LegacyPrintf
(
"service.pricing"
,
"[Pricing] Remote hash differs (local=%s remote=%s), downloading new version..."
,
localHash
[
:
min
(
8
,
len
(
localHash
))],
remoteHash
[
:
min
(
8
,
len
(
remoteHash
))])
return
s
.
downloadPricingData
()
}
logger
.
LegacyPrintf
(
"service.pricing"
,
"%s"
,
"[Pricing] Hash check passed, no update needed"
)
...
...
@@ -237,6 +260,7 @@ func (s *PricingService) syncWithRemote() error {
}
// 没有哈希URL时,基于时间检查
pricingFile
:=
s
.
getPricingFilePath
()
info
,
err
:=
os
.
Stat
(
pricingFile
)
if
err
!=
nil
{
return
s
.
downloadPricingData
()
...
...
@@ -264,11 +288,12 @@ func (s *PricingService) downloadPricingData() error {
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
cancel
()
var
expectedHash
string
// 获取远程哈希(用于同步锚点,不作为完整性校验)
var
remoteHash
string
if
strings
.
TrimSpace
(
s
.
cfg
.
Pricing
.
HashURL
)
!=
""
{
expec
te
d
Hash
,
err
=
s
.
fetchRemoteHash
()
remo
teHash
,
err
=
s
.
fetchRemoteHash
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"
fetch remote hash: %
w
"
,
err
)
logger
.
LegacyPrintf
(
"service.pricing"
,
"[Pricing] Failed to
fetch remote hash
(continuing)
: %
v
"
,
err
)
}
}
...
...
@@ -277,11 +302,13 @@ func (s *PricingService) downloadPricingData() error {
return
fmt
.
Errorf
(
"download failed: %w"
,
err
)
}
if
expectedHash
!=
""
{
actualHash
:=
sha256
.
Sum256
(
body
)
if
!
strings
.
EqualFold
(
expectedHash
,
hex
.
EncodeToString
(
actualHash
[
:
]))
{
return
fmt
.
Errorf
(
"pricing hash mismatch"
)
}
// 哈希校验:不匹配时仅告警,不阻止更新
// 远程哈希文件可能与数据文件不同步(如维护者更新了数据但未更新哈希文件)
dataHash
:=
sha256
.
Sum256
(
body
)
dataHashStr
:=
hex
.
EncodeToString
(
dataHash
[
:
])
if
remoteHash
!=
""
&&
!
strings
.
EqualFold
(
remoteHash
,
dataHashStr
)
{
logger
.
LegacyPrintf
(
"service.pricing"
,
"[Pricing] Hash mismatch warning: remote=%s data=%s (hash file may be out of sync)"
,
remoteHash
[
:
min
(
8
,
len
(
remoteHash
))],
dataHashStr
[
:
8
])
}
// 解析JSON数据(使用灵活的解析方式)
...
...
@@ -296,11 +323,14 @@ func (s *PricingService) downloadPricingData() error {
logger
.
LegacyPrintf
(
"service.pricing"
,
"[Pricing] Failed to save file: %v"
,
err
)
}
// 保存哈希
hash
:=
sha256
.
Sum256
(
body
)
hashStr
:=
hex
.
EncodeToString
(
hash
[
:
])
// 使用远程哈希作为同步锚点,防止重复下载
// 当远程哈希不可用时,回退到数据本身的哈希
syncHash
:=
dataHashStr
if
remoteHash
!=
""
{
syncHash
=
remoteHash
}
hashFile
:=
s
.
getHashFilePath
()
if
err
:=
os
.
WriteFile
(
hashFile
,
[]
byte
(
hashStr
+
"
\n
"
),
0644
);
err
!=
nil
{
if
err
:=
os
.
WriteFile
(
hashFile
,
[]
byte
(
syncHash
+
"
\n
"
),
0644
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.pricing"
,
"[Pricing] Failed to save hash: %v"
,
err
)
}
...
...
@@ -308,7 +338,7 @@ func (s *PricingService) downloadPricingData() error {
s
.
mu
.
Lock
()
s
.
pricingData
=
data
s
.
lastUpdated
=
time
.
Now
()
s
.
localHash
=
hashStr
s
.
localHash
=
syncHash
s
.
mu
.
Unlock
()
logger
.
LegacyPrintf
(
"service.pricing"
,
"[Pricing] Downloaded %d models successfully"
,
len
(
data
))
...
...
@@ -486,16 +516,6 @@ func (s *PricingService) validatePricingURL(raw string) (string, error) {
return
normalized
,
nil
}
// computeFileHash 计算文件哈希
func
(
s
*
PricingService
)
computeFileHash
(
filePath
string
)
(
string
,
error
)
{
data
,
err
:=
os
.
ReadFile
(
filePath
)
if
err
!=
nil
{
return
""
,
err
}
hash
:=
sha256
.
Sum256
(
data
)
return
hex
.
EncodeToString
(
hash
[
:
]),
nil
}
// GetModelPricing 获取模型价格(带模糊匹配)
func
(
s
*
PricingService
)
GetModelPricing
(
modelName
string
)
*
LiteLLMModelPricing
{
s
.
mu
.
RLock
()
...
...
backend/internal/service/token_refresh_service.go
View file @
ca8692c7
...
...
@@ -32,8 +32,9 @@ type TokenRefreshService struct {
privacyClientFactory
PrivacyClientFactory
proxyRepo
ProxyRepository
stopCh
chan
struct
{}
wg
sync
.
WaitGroup
stopCh
chan
struct
{}
stopOnce
sync
.
Once
wg
sync
.
WaitGroup
}
// NewTokenRefreshService 创建token刷新服务
...
...
@@ -130,7 +131,9 @@ func (s *TokenRefreshService) Start() {
// Stop 停止刷新服务(可安全多次调用)
func
(
s
*
TokenRefreshService
)
Stop
()
{
close
(
s
.
stopCh
)
s
.
stopOnce
.
Do
(
func
()
{
close
(
s
.
stopCh
)
})
s
.
wg
.
Wait
()
slog
.
Info
(
"token_refresh.service_stopped"
)
}
...
...
@@ -430,6 +433,7 @@ func isNonRetryableRefreshError(err error) bool {
"unauthorized_client"
,
// 客户端未授权
"access_denied"
,
// 访问被拒绝
"missing_project_id"
,
// 缺少 project_id
"no refresh token available"
,
}
for
_
,
needle
:=
range
nonRetryable
{
if
strings
.
Contains
(
msg
,
needle
)
{
...
...
backend/internal/service/token_refresh_service_test.go
View file @
ca8692c7
...
...
@@ -19,6 +19,7 @@ type tokenRefreshAccountRepo struct {
updateCredentialsCalls
int
setErrorCalls
int
clearTempCalls
int
setTempUnschedCalls
int
lastAccount
*
Account
updateErr
error
}
...
...
@@ -58,6 +59,11 @@ func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id
return
nil
}
func
(
r
*
tokenRefreshAccountRepo
)
SetTempUnschedulable
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
,
reason
string
)
error
{
r
.
setTempUnschedCalls
++
return
nil
}
type
tokenCacheInvalidatorStub
struct
{
calls
int
err
error
...
...
@@ -490,6 +496,31 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
}
}
func
TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedule
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
2
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
,
nil
)
account
:=
&
Account
{
ID
:
18
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
}
refresher
:=
&
tokenRefresherStub
{
err
:
errors
.
New
(
"no refresh token available"
),
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
0
,
repo
.
setTempUnschedCalls
,
"missing refresh token should not mark the account temp unschedulable"
)
require
.
Equal
(
t
,
1
,
repo
.
setErrorCalls
,
"missing refresh token should be treated as a non-retryable credential state"
)
}
// TestIsNonRetryableRefreshError 测试不可重试错误判断
func
TestIsNonRetryableRefreshError
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
...
...
@@ -503,6 +534,7 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
{
name
:
"invalid_client"
,
err
:
errors
.
New
(
"invalid_client"
),
expected
:
true
},
{
name
:
"unauthorized_client"
,
err
:
errors
.
New
(
"unauthorized_client"
),
expected
:
true
},
{
name
:
"access_denied"
,
err
:
errors
.
New
(
"access_denied"
),
expected
:
true
},
{
name
:
"no_refresh_token"
,
err
:
errors
.
New
(
"no refresh token available"
),
expected
:
true
},
{
name
:
"invalid_grant_with_desc"
,
err
:
errors
.
New
(
"Error: invalid_grant - token revoked"
),
expected
:
true
},
{
name
:
"case_insensitive"
,
err
:
errors
.
New
(
"INVALID_GRANT"
),
expected
:
true
},
}
...
...
backend/internal/service/usage_log_helpers.go
View file @
ca8692c7
...
...
@@ -21,8 +21,8 @@ func optionalNonEqualStringPtr(value, compare string) *string {
}
func
forwardResultBillingModel
(
requestedModel
,
upstreamModel
string
)
string
{
if
trimmed
Upstream
:=
strings
.
TrimSpace
(
upstream
Model
);
trimmed
Upstream
!=
""
{
return
trimmed
Upstream
if
trimmed
:=
strings
.
TrimSpace
(
requested
Model
);
trimmed
!=
""
{
return
trimmed
}
return
strings
.
TrimSpace
(
requested
Model
)
return
strings
.
TrimSpace
(
upstream
Model
)
}
deploy/config.example.yaml
View file @
ca8692c7
...
...
@@ -865,10 +865,10 @@ rate_limit:
pricing
:
# URL to fetch model pricing data (default: pinned model-price-repo commit)
# 获取模型定价数据的 URL(默认:固定 commit 的 model-price-repo)
remote_url
:
"
https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/
c7947e9871687e664180bc971d4837f1fc2784a9
/model_prices_and_context_window.json"
remote_url
:
"
https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/
refs/heads/main/
/model_prices_and_context_window.json"
# Hash verification URL (optional)
# 哈希校验 URL(可选)
hash_url
:
"
https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/
c7947e9871687e664180bc971d4837f1fc2784a9
/model_prices_and_context_window.sha256"
hash_url
:
"
https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/
refs/heads/main/
/model_prices_and_context_window.sha256"
# Local data directory for caching
# 本地数据缓存目录
data_dir
:
"
./data"
...
...
frontend/src/components/account/CreateAccountModal.vue
View file @
ca8692c7
...
...
@@ -2245,6 +2245,41 @@
<
/p
>
<
/div
>
<
/div
>
<!--
Custom
Base
URL
Relay
-->
<
div
class
=
"
rounded-lg border border-gray-200 p-4 dark:border-dark-600
"
>
<
div
class
=
"
flex items-center justify-between
"
>
<
div
>
<
label
class
=
"
input-label mb-0
"
>
{{
t
(
'
admin.accounts.quotaControl.customBaseUrl.label
'
)
}}
<
/label
>
<
p
class
=
"
mt-1 text-xs text-gray-500 dark:text-gray-400
"
>
{{
t
(
'
admin.accounts.quotaControl.customBaseUrl.hint
'
)
}}
<
/p
>
<
/div
>
<
button
type
=
"
button
"
@
click
=
"
customBaseUrlEnabled = !customBaseUrlEnabled
"
:
class
=
"
[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
customBaseUrlEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]
"
>
<
span
:
class
=
"
[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
customBaseUrlEnabled ? 'translate-x-5' : 'translate-x-0'
]
"
/>
<
/button
>
<
/div
>
<
div
v
-
if
=
"
customBaseUrlEnabled
"
class
=
"
mt-3
"
>
<
input
v
-
model
=
"
customBaseUrl
"
type
=
"
text
"
class
=
"
input
"
:
placeholder
=
"
t('admin.accounts.quotaControl.customBaseUrl.urlHint')
"
/>
<
/div
>
<
/div
>
<
/div
>
<
div
>
...
...
@@ -3095,6 +3130,8 @@ const tlsFingerprintProfiles = ref<{ id: number; name: string }[]>([])
const
sessionIdMaskingEnabled
=
ref
(
false
)
const
cacheTTLOverrideEnabled
=
ref
(
false
)
const
cacheTTLOverrideTarget
=
ref
<
string
>
(
'
5m
'
)
const
customBaseUrlEnabled
=
ref
(
false
)
const
customBaseUrl
=
ref
(
''
)
// Gemini tier selection (used as fallback when auto-detection is unavailable/fails)
const
geminiTierGoogleOne
=
ref
<
'
google_one_free
'
|
'
google_ai_pro
'
|
'
google_ai_ultra
'
>
(
'
google_one_free
'
)
...
...
@@ -3765,6 +3802,8 @@ const resetForm = () => {
sessionIdMaskingEnabled
.
value
=
false
cacheTTLOverrideEnabled
.
value
=
false
cacheTTLOverrideTarget
.
value
=
'
5m
'
customBaseUrlEnabled
.
value
=
false
customBaseUrl
.
value
=
''
allowOverages
.
value
=
false
antigravityAccountType
.
value
=
'
oauth
'
upstreamBaseUrl
.
value
=
''
...
...
@@ -4856,6 +4895,12 @@ const handleAnthropicExchange = async (authCode: string) => {
extra
.
cache_ttl_override_target
=
cacheTTLOverrideTarget
.
value
}
// Add custom base URL settings
if
(
customBaseUrlEnabled
.
value
&&
customBaseUrl
.
value
.
trim
())
{
extra
.
custom_base_url_enabled
=
true
extra
.
custom_base_url
=
customBaseUrl
.
value
.
trim
()
}
const
credentials
:
Record
<
string
,
unknown
>
=
{
...
tokenInfo
}
applyInterceptWarmup
(
credentials
,
interceptWarmupRequests
.
value
,
'
create
'
)
await
createAccountAndFinish
(
form
.
platform
,
addMethod
.
value
as
AccountType
,
credentials
,
extra
)
...
...
@@ -4974,6 +5019,12 @@ const handleCookieAuth = async (sessionKey: string) => {
extra
.
cache_ttl_override_target
=
cacheTTLOverrideTarget
.
value
}
// Add custom base URL settings
if
(
customBaseUrlEnabled
.
value
&&
customBaseUrl
.
value
.
trim
())
{
extra
.
custom_base_url_enabled
=
true
extra
.
custom_base_url
=
customBaseUrl
.
value
.
trim
()
}
const
accountName
=
keys
.
length
>
1
?
`${form.name
}
#${i + 1
}
`
:
form
.
name
const
credentials
:
Record
<
string
,
unknown
>
=
{
...
tokenInfo
}
...
...
frontend/src/components/account/EditAccountModal.vue
View file @
ca8692c7
...
...
@@ -1580,6 +1580,41 @@
<
/p
>
<
/div
>
<
/div
>
<!--
Custom
Base
URL
Relay
-->
<
div
class
=
"
rounded-lg border border-gray-200 p-4 dark:border-dark-600
"
>
<
div
class
=
"
flex items-center justify-between
"
>
<
div
>
<
label
class
=
"
input-label mb-0
"
>
{{
t
(
'
admin.accounts.quotaControl.customBaseUrl.label
'
)
}}
<
/label
>
<
p
class
=
"
mt-1 text-xs text-gray-500 dark:text-gray-400
"
>
{{
t
(
'
admin.accounts.quotaControl.customBaseUrl.hint
'
)
}}
<
/p
>
<
/div
>
<
button
type
=
"
button
"
@
click
=
"
customBaseUrlEnabled = !customBaseUrlEnabled
"
:
class
=
"
[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
customBaseUrlEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]
"
>
<
span
:
class
=
"
[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
customBaseUrlEnabled ? 'translate-x-5' : 'translate-x-0'
]
"
/>
<
/button
>
<
/div
>
<
div
v
-
if
=
"
customBaseUrlEnabled
"
class
=
"
mt-3
"
>
<
input
v
-
model
=
"
customBaseUrl
"
type
=
"
text
"
class
=
"
input
"
:
placeholder
=
"
t('admin.accounts.quotaControl.customBaseUrl.urlHint')
"
/>
<
/div
>
<
/div
>
<
/div
>
<
div
class
=
"
border-t border-gray-200 pt-4 dark:border-dark-600
"
>
...
...
@@ -1854,6 +1889,8 @@ const tlsFingerprintProfiles = ref<{ id: number; name: string }[]>([])
const
sessionIdMaskingEnabled
=
ref
(
false
)
const
cacheTTLOverrideEnabled
=
ref
(
false
)
const
cacheTTLOverrideTarget
=
ref
<
string
>
(
'
5m
'
)
const
customBaseUrlEnabled
=
ref
(
false
)
const
customBaseUrl
=
ref
(
''
)
// OpenAI 自动透传开关(OAuth/API Key)
const
openaiPassthroughEnabled
=
ref
(
false
)
...
...
@@ -2482,6 +2519,8 @@ function loadQuotaControlSettings(account: Account) {
sessionIdMaskingEnabled
.
value
=
false
cacheTTLOverrideEnabled
.
value
=
false
cacheTTLOverrideTarget
.
value
=
'
5m
'
customBaseUrlEnabled
.
value
=
false
customBaseUrl
.
value
=
''
// Only applies to Anthropic OAuth/SetupToken accounts
if
(
account
.
platform
!==
'
anthropic
'
||
(
account
.
type
!==
'
oauth
'
&&
account
.
type
!==
'
setup-token
'
))
{
...
...
@@ -2528,6 +2567,12 @@ function loadQuotaControlSettings(account: Account) {
cacheTTLOverrideEnabled
.
value
=
true
cacheTTLOverrideTarget
.
value
=
account
.
cache_ttl_override_target
||
'
5m
'
}
// Load custom base URL setting
if
(
account
.
custom_base_url_enabled
===
true
)
{
customBaseUrlEnabled
.
value
=
true
customBaseUrl
.
value
=
account
.
custom_base_url
||
''
}
}
function
formatTempUnschedKeywords
(
value
:
unknown
)
{
...
...
@@ -2980,6 +3025,15 @@ const handleSubmit = async () => {
delete
newExtra
.
cache_ttl_override_target
}
// Custom base URL relay setting
if
(
customBaseUrlEnabled
.
value
&&
customBaseUrl
.
value
.
trim
())
{
newExtra
.
custom_base_url_enabled
=
true
newExtra
.
custom_base_url
=
customBaseUrl
.
value
.
trim
()
}
else
{
delete
newExtra
.
custom_base_url_enabled
delete
newExtra
.
custom_base_url
}
updatePayload
.
extra
=
newExtra
}
...
...
frontend/src/components/charts/TokenUsageTrend.vue
View file @
ca8692c7
...
...
@@ -64,7 +64,8 @@ const chartColors = computed(() => ({
input
:
'
#3b82f6
'
,
output
:
'
#10b981
'
,
cacheCreation
:
'
#f59e0b
'
,
cacheRead
:
'
#06b6d4
'
cacheRead
:
'
#06b6d4
'
,
cacheHitRate
:
'
#8b5cf6
'
}))
const
chartData
=
computed
(()
=>
{
...
...
@@ -104,6 +105,19 @@ const chartData = computed(() => {
backgroundColor
:
`
${
chartColors
.
value
.
cacheRead
}
20`
,
fill
:
true
,
tension
:
0.3
},
{
label
:
'
Cache Hit Rate
'
,
data
:
props
.
trendData
.
map
((
d
)
=>
{
const
total
=
d
.
cache_read_tokens
+
d
.
cache_creation_tokens
return
total
>
0
?
(
d
.
cache_read_tokens
/
total
)
*
100
:
0
}),
borderColor
:
chartColors
.
value
.
cacheHitRate
,
backgroundColor
:
`
${
chartColors
.
value
.
cacheHitRate
}
20`
,
borderDash
:
[
5
,
5
],
fill
:
false
,
tension
:
0.3
,
yAxisID
:
'
yPercent
'
}
]
}
...
...
@@ -132,6 +146,9 @@ const lineOptions = computed(() => ({
tooltip
:
{
callbacks
:
{
label
:
(
context
:
any
)
=>
{
if
(
context
.
dataset
.
yAxisID
===
'
yPercent
'
)
{
return
`
${
context
.
dataset
.
label
}
:
${
context
.
raw
.
toFixed
(
1
)}
%`
}
return
`
${
context
.
dataset
.
label
}
:
${
formatTokens
(
context
.
raw
)}
`
},
footer
:
(
tooltipItems
:
any
)
=>
{
...
...
@@ -168,6 +185,21 @@ const lineOptions = computed(() => ({
},
callback
:
(
value
:
string
|
number
)
=>
formatTokens
(
Number
(
value
))
}
},
yPercent
:
{
position
:
'
right
'
as
const
,
min
:
0
,
max
:
100
,
grid
:
{
drawOnChartArea
:
false
},
ticks
:
{
color
:
chartColors
.
value
.
cacheHitRate
,
font
:
{
size
:
10
},
callback
:
(
value
:
string
|
number
)
=>
`
${
value
}
%`
}
}
}
}))
...
...
frontend/src/i18n/locales/en.ts
View file @
ca8692c7
...
...
@@ -2318,6 +2318,11 @@ export default {
target
:
'
Target TTL
'
,
targetHint
:
'
Select the TTL tier for billing
'
},
customBaseUrl
:
{
label
:
'
Custom Relay URL
'
,
hint
:
'
Forward requests to a custom relay service. Proxy URL will be passed as a query parameter.
'
,
urlHint
:
'
Relay service URL (e.g., https://relay.example.com)
'
,
},
clientAffinity
:
{
label
:
'
Client Affinity Scheduling
'
,
hint
:
'
When enabled, new sessions prefer accounts previously used by this client to reduce account switching
'
...
...
@@ -4378,6 +4383,7 @@ export default {
provider
:
'
Type
'
,
active
:
'
Active
'
,
endpoint
:
'
Endpoint
'
,
bucket
:
'
Bucket
'
,
storagePath
:
'
Storage Path
'
,
capacityUsage
:
'
Capacity / Used
'
,
capacityUnlimited
:
'
Unlimited
'
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment