Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
d367d1cd
"frontend/vscode:/vscode.git/clone" did not exist on "668e1647932e3d414532ea6817da06562e7c4e83"
Commit
d367d1cd
authored
Feb 09, 2026
by
yangjianbo
Browse files
Merge branch 'main' into test-sora
parents
d7011163
3c46f7d2
Changes
104
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/dto/mappers.go
View file @
d367d1cd
...
@@ -115,6 +115,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
...
@@ -115,6 +115,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
MCPXMLInject
:
g
.
MCPXMLInject
,
MCPXMLInject
:
g
.
MCPXMLInject
,
SupportedModelScopes
:
g
.
SupportedModelScopes
,
SupportedModelScopes
:
g
.
SupportedModelScopes
,
AccountCount
:
g
.
AccountCount
,
AccountCount
:
g
.
AccountCount
,
SortOrder
:
g
.
SortOrder
,
}
}
if
len
(
g
.
AccountGroups
)
>
0
{
if
len
(
g
.
AccountGroups
)
>
0
{
out
.
AccountGroups
=
make
([]
AccountGroup
,
0
,
len
(
g
.
AccountGroups
))
out
.
AccountGroups
=
make
([]
AccountGroup
,
0
,
len
(
g
.
AccountGroups
))
...
...
backend/internal/handler/dto/types.go
View file @
d367d1cd
...
@@ -2,11 +2,6 @@ package dto
...
@@ -2,11 +2,6 @@ package dto
import
"time"
import
"time"
type
ScopeRateLimitInfo
struct
{
ResetAt
time
.
Time
`json:"reset_at"`
RemainingSec
int64
`json:"remaining_sec"`
}
type
User
struct
{
type
User
struct
{
ID
int64
`json:"id"`
ID
int64
`json:"id"`
Email
string
`json:"email"`
Email
string
`json:"email"`
...
@@ -104,6 +99,9 @@ type AdminGroup struct {
...
@@ -104,6 +99,9 @@ type AdminGroup struct {
SupportedModelScopes
[]
string
`json:"supported_model_scopes"`
SupportedModelScopes
[]
string
`json:"supported_model_scopes"`
AccountGroups
[]
AccountGroup
`json:"account_groups,omitempty"`
AccountGroups
[]
AccountGroup
`json:"account_groups,omitempty"`
AccountCount
int64
`json:"account_count,omitempty"`
AccountCount
int64
`json:"account_count,omitempty"`
// 分组排序
SortOrder
int
`json:"sort_order"`
}
}
type
Account
struct
{
type
Account
struct
{
...
@@ -132,9 +130,6 @@ type Account struct {
...
@@ -132,9 +130,6 @@ type Account struct {
RateLimitResetAt
*
time
.
Time
`json:"rate_limit_reset_at"`
RateLimitResetAt
*
time
.
Time
`json:"rate_limit_reset_at"`
OverloadUntil
*
time
.
Time
`json:"overload_until"`
OverloadUntil
*
time
.
Time
`json:"overload_until"`
// Antigravity scope 级限流状态(从 extra 提取)
ScopeRateLimits
map
[
string
]
ScopeRateLimitInfo
`json:"scope_rate_limits,omitempty"`
TempUnschedulableUntil
*
time
.
Time
`json:"temp_unschedulable_until"`
TempUnschedulableUntil
*
time
.
Time
`json:"temp_unschedulable_until"`
TempUnschedulableReason
string
`json:"temp_unschedulable_reason"`
TempUnschedulableReason
string
`json:"temp_unschedulable_reason"`
...
...
backend/internal/handler/gateway_handler.go
View file @
d367d1cd
...
@@ -13,6 +13,7 @@ import (
...
@@ -13,6 +13,7 @@ import (
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
...
@@ -116,7 +117,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -116,7 +117,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext
(
c
,
""
,
false
,
body
)
setOpsRequestContext
(
c
,
""
,
false
,
body
)
parsedReq
,
err
:=
service
.
ParseGatewayRequest
(
body
)
parsedReq
,
err
:=
service
.
ParseGatewayRequest
(
body
,
domain
.
PlatformAnthropic
)
if
err
!=
nil
{
if
err
!=
nil
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to parse request body"
)
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to parse request body"
)
return
return
...
@@ -205,6 +206,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -205,6 +206,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
// 计算粘性会话hash
// 计算粘性会话hash
parsedReq
.
SessionContext
=
&
service
.
SessionContext
{
ClientIP
:
ip
.
GetClientIP
(
c
),
UserAgent
:
c
.
GetHeader
(
"User-Agent"
),
APIKeyID
:
apiKey
.
ID
,
}
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
...
@@ -336,7 +342,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -336,7 +342,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
errors
.
As
(
err
,
&
failoverErr
)
{
if
errors
.
As
(
err
,
&
failoverErr
)
{
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
lastFailoverErr
=
failoverErr
lastFailoverErr
=
failoverErr
if
failoverErr
.
ForceCacheBilling
{
if
need
ForceCacheBilling
(
hasBoundSession
,
failoverErr
)
{
forceCacheBilling
=
true
forceCacheBilling
=
true
}
}
if
switchCount
>=
maxAccountSwitches
{
if
switchCount
>=
maxAccountSwitches
{
...
@@ -345,6 +351,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -345,6 +351,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
switchCount
++
switchCount
++
log
.
Printf
(
"Account %d: upstream error %d, switching account %d/%d"
,
account
.
ID
,
failoverErr
.
StatusCode
,
switchCount
,
maxAccountSwitches
)
log
.
Printf
(
"Account %d: upstream error %d, switching account %d/%d"
,
account
.
ID
,
failoverErr
.
StatusCode
,
switchCount
,
maxAccountSwitches
)
if
account
.
Platform
==
service
.
PlatformAntigravity
{
if
!
sleepFailoverDelay
(
c
.
Request
.
Context
(),
switchCount
)
{
return
}
}
continue
continue
}
}
// 错误响应已在Forward中处理,这里只记录日志
// 错误响应已在Forward中处理,这里只记录日志
...
@@ -484,7 +495,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -484,7 +495,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
switchCount
>
0
{
if
switchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
switchCount
)
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
switchCount
)
}
}
if
account
.
Platform
==
service
.
PlatformAntigravity
{
if
account
.
Platform
==
service
.
PlatformAntigravity
&&
account
.
Type
!=
service
.
AccountTypeAPIKey
{
result
,
err
=
h
.
antigravityGatewayService
.
Forward
(
requestCtx
,
c
,
account
,
body
,
hasBoundSession
)
result
,
err
=
h
.
antigravityGatewayService
.
Forward
(
requestCtx
,
c
,
account
,
body
,
hasBoundSession
)
}
else
{
}
else
{
result
,
err
=
h
.
gatewayService
.
Forward
(
requestCtx
,
c
,
account
,
parsedReq
)
result
,
err
=
h
.
gatewayService
.
Forward
(
requestCtx
,
c
,
account
,
parsedReq
)
...
@@ -532,7 +543,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -532,7 +543,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
errors
.
As
(
err
,
&
failoverErr
)
{
if
errors
.
As
(
err
,
&
failoverErr
)
{
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
lastFailoverErr
=
failoverErr
lastFailoverErr
=
failoverErr
if
failoverErr
.
ForceCacheBilling
{
if
need
ForceCacheBilling
(
hasBoundSession
,
failoverErr
)
{
forceCacheBilling
=
true
forceCacheBilling
=
true
}
}
if
switchCount
>=
maxAccountSwitches
{
if
switchCount
>=
maxAccountSwitches
{
...
@@ -541,6 +552,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -541,6 +552,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
switchCount
++
switchCount
++
log
.
Printf
(
"Account %d: upstream error %d, switching account %d/%d"
,
account
.
ID
,
failoverErr
.
StatusCode
,
switchCount
,
maxAccountSwitches
)
log
.
Printf
(
"Account %d: upstream error %d, switching account %d/%d"
,
account
.
ID
,
failoverErr
.
StatusCode
,
switchCount
,
maxAccountSwitches
)
if
account
.
Platform
==
service
.
PlatformAntigravity
{
if
!
sleepFailoverDelay
(
c
.
Request
.
Context
(),
switchCount
)
{
return
}
}
continue
continue
}
}
// 错误响应已在Forward中处理,这里只记录日志
// 错误响应已在Forward中处理,这里只记录日志
...
@@ -814,6 +830,27 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
...
@@ -814,6 +830,27 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt
.
Sprintf
(
"Concurrency limit exceeded for %s, please retry later"
,
slotType
),
streamStarted
)
fmt
.
Sprintf
(
"Concurrency limit exceeded for %s, please retry later"
,
slotType
),
streamStarted
)
}
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
func
needForceCacheBilling
(
hasBoundSession
bool
,
failoverErr
*
service
.
UpstreamFailoverError
)
bool
{
return
hasBoundSession
||
(
failoverErr
!=
nil
&&
failoverErr
.
ForceCacheBilling
)
}
// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s…
// 返回 false 表示 context 已取消。
func
sleepFailoverDelay
(
ctx
context
.
Context
,
switchCount
int
)
bool
{
delay
:=
time
.
Duration
(
switchCount
-
1
)
*
time
.
Second
if
delay
<=
0
{
return
true
}
select
{
case
<-
ctx
.
Done
()
:
return
false
case
<-
time
.
After
(
delay
)
:
return
true
}
}
func
(
h
*
GatewayHandler
)
handleFailoverExhausted
(
c
*
gin
.
Context
,
failoverErr
*
service
.
UpstreamFailoverError
,
platform
string
,
streamStarted
bool
)
{
func
(
h
*
GatewayHandler
)
handleFailoverExhausted
(
c
*
gin
.
Context
,
failoverErr
*
service
.
UpstreamFailoverError
,
platform
string
,
streamStarted
bool
)
{
statusCode
:=
failoverErr
.
StatusCode
statusCode
:=
failoverErr
.
StatusCode
responseBody
:=
failoverErr
.
ResponseBody
responseBody
:=
failoverErr
.
ResponseBody
...
@@ -947,7 +984,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
...
@@ -947,7 +984,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
setOpsRequestContext
(
c
,
""
,
false
,
body
)
setOpsRequestContext
(
c
,
""
,
false
,
body
)
parsedReq
,
err
:=
service
.
ParseGatewayRequest
(
body
)
parsedReq
,
err
:=
service
.
ParseGatewayRequest
(
body
,
domain
.
PlatformAnthropic
)
if
err
!=
nil
{
if
err
!=
nil
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to parse request body"
)
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to parse request body"
)
return
return
...
@@ -975,6 +1012,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
...
@@ -975,6 +1012,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
}
// 计算粘性会话 hash
// 计算粘性会话 hash
parsedReq
.
SessionContext
=
&
service
.
SessionContext
{
ClientIP
:
ip
.
GetClientIP
(
c
),
UserAgent
:
c
.
GetHeader
(
"User-Agent"
),
APIKeyID
:
apiKey
.
ID
,
}
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
// 选择支持该模型的账号
// 选择支持该模型的账号
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
d367d1cd
...
@@ -14,6 +14,7 @@ import (
...
@@ -14,6 +14,7 @@ import (
"strings"
"strings"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
...
@@ -30,13 +31,6 @@ import (
...
@@ -30,13 +31,6 @@ import (
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
var
geminiCLITmpDirRegex
=
regexp
.
MustCompile
(
`/\.gemini/tmp/([A-Fa-f0-9]{64})`
)
var
geminiCLITmpDirRegex
=
regexp
.
MustCompile
(
`/\.gemini/tmp/([A-Fa-f0-9]{64})`
)
func
isGeminiCLIRequest
(
c
*
gin
.
Context
,
body
[]
byte
)
bool
{
if
strings
.
TrimSpace
(
c
.
GetHeader
(
"x-gemini-api-privileged-user-id"
))
!=
""
{
return
true
}
return
geminiCLITmpDirRegex
.
Match
(
body
)
}
// GeminiV1BetaListModels proxies:
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
// GET /v1beta/models
func
(
h
*
GatewayHandler
)
GeminiV1BetaListModels
(
c
*
gin
.
Context
)
{
func
(
h
*
GatewayHandler
)
GeminiV1BetaListModels
(
c
*
gin
.
Context
)
{
...
@@ -239,7 +233,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -239,7 +233,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
sessionHash
:=
extractGeminiCLISessionHash
(
c
,
body
)
sessionHash
:=
extractGeminiCLISessionHash
(
c
,
body
)
if
sessionHash
==
""
{
if
sessionHash
==
""
{
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq
,
_
:=
service
.
ParseGatewayRequest
(
body
)
parsedReq
,
_
:=
service
.
ParseGatewayRequest
(
body
,
domain
.
PlatformGemini
)
if
parsedReq
!=
nil
{
parsedReq
.
SessionContext
=
&
service
.
SessionContext
{
ClientIP
:
ip
.
GetClientIP
(
c
),
UserAgent
:
c
.
GetHeader
(
"User-Agent"
),
APIKeyID
:
apiKey
.
ID
,
}
}
sessionHash
=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
sessionHash
=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
}
}
sessionKey
:=
sessionHash
sessionKey
:=
sessionHash
...
@@ -258,6 +259,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -258,6 +259,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var
geminiDigestChain
string
var
geminiDigestChain
string
var
geminiPrefixHash
string
var
geminiPrefixHash
string
var
geminiSessionUUID
string
var
geminiSessionUUID
string
var
matchedDigestChain
string
useDigestFallback
:=
sessionBoundAccountID
==
0
useDigestFallback
:=
sessionBoundAccountID
==
0
if
useDigestFallback
{
if
useDigestFallback
{
...
@@ -284,13 +286,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -284,13 +286,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
)
)
// 查找会话
// 查找会话
foundUUID
,
foundAccountID
,
found
:=
h
.
gatewayService
.
FindGeminiSession
(
foundUUID
,
foundAccountID
,
foundMatchedChain
,
found
:=
h
.
gatewayService
.
FindGeminiSession
(
c
.
Request
.
Context
(),
c
.
Request
.
Context
(),
derefGroupID
(
apiKey
.
GroupID
),
derefGroupID
(
apiKey
.
GroupID
),
geminiPrefixHash
,
geminiPrefixHash
,
geminiDigestChain
,
geminiDigestChain
,
)
)
if
found
{
if
found
{
matchedDigestChain
=
foundMatchedChain
sessionBoundAccountID
=
foundAccountID
sessionBoundAccountID
=
foundAccountID
geminiSessionUUID
=
foundUUID
geminiSessionUUID
=
foundUUID
log
.
Printf
(
"[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s"
,
log
.
Printf
(
"[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s"
,
...
@@ -316,7 +319,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -316,7 +319,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession
:=
sessionKey
!=
""
&&
sessionBoundAccountID
>
0
hasBoundSession
:=
sessionKey
!=
""
&&
sessionBoundAccountID
>
0
isCLI
:=
isGeminiCLIRequest
(
c
,
body
)
cleanedForUnknownBinding
:=
false
cleanedForUnknownBinding
:=
false
maxAccountSwitches
:=
h
.
maxAccountSwitchesGemini
maxAccountSwitches
:=
h
.
maxAccountSwitchesGemini
...
@@ -344,10 +346,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -344,10 +346,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
log
.
Printf
(
"[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature"
,
sessionBoundAccountID
,
account
.
ID
)
log
.
Printf
(
"[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature"
,
sessionBoundAccountID
,
account
.
ID
)
body
=
service
.
CleanGeminiNativeThoughtSignatures
(
body
)
body
=
service
.
CleanGeminiNativeThoughtSignatures
(
body
)
sessionBoundAccountID
=
account
.
ID
sessionBoundAccountID
=
account
.
ID
}
else
if
sessionKey
!=
""
&&
sessionBoundAccountID
==
0
&&
isCLI
&&
!
cleanedForUnknownBinding
&&
bytes
.
Contains
(
body
,
[]
byte
(
`"thoughtSignature"`
))
{
}
else
if
sessionKey
!=
""
&&
sessionBoundAccountID
==
0
&&
!
cleanedForUnknownBinding
&&
bytes
.
Contains
(
body
,
[]
byte
(
`"thoughtSignature"`
))
{
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,
CLI
继续携带旧签名。
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,
客户端
继续携带旧签名。
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
log
.
Printf
(
"[Gemini] Sticky session binding missing
for CLI request
, cleaning thoughtSignature proactively"
)
log
.
Printf
(
"[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively"
)
body
=
service
.
CleanGeminiNativeThoughtSignatures
(
body
)
body
=
service
.
CleanGeminiNativeThoughtSignatures
(
body
)
cleanedForUnknownBinding
=
true
cleanedForUnknownBinding
=
true
sessionBoundAccountID
=
account
.
ID
sessionBoundAccountID
=
account
.
ID
...
@@ -410,7 +412,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -410,7 +412,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if
switchCount
>
0
{
if
switchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
switchCount
)
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
switchCount
)
}
}
if
account
.
Platform
==
service
.
PlatformAntigravity
{
if
account
.
Platform
==
service
.
PlatformAntigravity
&&
account
.
Type
!=
service
.
AccountTypeAPIKey
{
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
requestCtx
,
c
,
account
,
modelName
,
action
,
stream
,
body
,
hasBoundSession
)
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
requestCtx
,
c
,
account
,
modelName
,
action
,
stream
,
body
,
hasBoundSession
)
}
else
{
}
else
{
result
,
err
=
h
.
geminiCompatService
.
ForwardNative
(
requestCtx
,
c
,
account
,
modelName
,
action
,
stream
,
body
)
result
,
err
=
h
.
geminiCompatService
.
ForwardNative
(
requestCtx
,
c
,
account
,
modelName
,
action
,
stream
,
body
)
...
@@ -422,7 +424,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -422,7 +424,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var
failoverErr
*
service
.
UpstreamFailoverError
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
if
errors
.
As
(
err
,
&
failoverErr
)
{
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
if
failoverErr
.
ForceCacheBilling
{
if
need
ForceCacheBilling
(
hasBoundSession
,
failoverErr
)
{
forceCacheBilling
=
true
forceCacheBilling
=
true
}
}
if
switchCount
>=
maxAccountSwitches
{
if
switchCount
>=
maxAccountSwitches
{
...
@@ -433,6 +435,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -433,6 +435,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
lastFailoverErr
=
failoverErr
lastFailoverErr
=
failoverErr
switchCount
++
switchCount
++
log
.
Printf
(
"Gemini account %d: upstream error %d, switching account %d/%d"
,
account
.
ID
,
failoverErr
.
StatusCode
,
switchCount
,
maxAccountSwitches
)
log
.
Printf
(
"Gemini account %d: upstream error %d, switching account %d/%d"
,
account
.
ID
,
failoverErr
.
StatusCode
,
switchCount
,
maxAccountSwitches
)
if
account
.
Platform
==
service
.
PlatformAntigravity
{
if
!
sleepFailoverDelay
(
c
.
Request
.
Context
(),
switchCount
)
{
return
}
}
continue
continue
}
}
// ForwardNative already wrote the response
// ForwardNative already wrote the response
...
@@ -453,6 +460,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -453,6 +460,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
geminiDigestChain
,
geminiDigestChain
,
geminiSessionUUID
,
geminiSessionUUID
,
account
.
ID
,
account
.
ID
,
matchedDigestChain
,
);
err
!=
nil
{
);
err
!=
nil
{
log
.
Printf
(
"[Gemini] Failed to save digest session: %v"
,
err
)
log
.
Printf
(
"[Gemini] Failed to save digest session: %v"
,
err
)
}
}
...
...
backend/internal/repository/account_repo.go
View file @
d367d1cd
...
@@ -282,6 +282,34 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
...
@@ -282,6 +282,34 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
return
&
accounts
[
0
],
nil
return
&
accounts
[
0
],
nil
}
}
func
(
r
*
accountRepository
)
ListCRSAccountIDs
(
ctx
context
.
Context
)
(
map
[
string
]
int64
,
error
)
{
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
`
SELECT id, extra->>'crs_account_id'
FROM accounts
WHERE deleted_at IS NULL
AND extra->>'crs_account_id' IS NOT NULL
AND extra->>'crs_account_id' != ''
`
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
result
:=
make
(
map
[
string
]
int64
)
for
rows
.
Next
()
{
var
id
int64
var
crsID
string
if
err
:=
rows
.
Scan
(
&
id
,
&
crsID
);
err
!=
nil
{
return
nil
,
err
}
result
[
crsID
]
=
id
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
result
,
nil
}
func
(
r
*
accountRepository
)
Update
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
func
(
r
*
accountRepository
)
Update
(
ctx
context
.
Context
,
account
*
service
.
Account
)
error
{
if
account
==
nil
{
if
account
==
nil
{
return
nil
return
nil
...
@@ -798,53 +826,6 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
...
@@ -798,53 +826,6 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return
nil
return
nil
}
}
func
(
r
*
accountRepository
)
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
service
.
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
{
now
:=
time
.
Now
()
.
UTC
()
payload
:=
map
[
string
]
string
{
"rate_limited_at"
:
now
.
Format
(
time
.
RFC3339
),
"rate_limit_reset_at"
:
resetAt
.
UTC
()
.
Format
(
time
.
RFC3339
),
}
raw
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
return
err
}
scopeKey
:=
string
(
scope
)
client
:=
clientFromContext
(
ctx
,
r
.
client
)
result
,
err
:=
client
.
ExecContext
(
ctx
,
`UPDATE accounts SET
extra = jsonb_set(
jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true),
ARRAY['antigravity_quota_scopes', $1]::text[],
$2::jsonb,
true
),
updated_at = NOW(),
last_used_at = NOW()
WHERE id = $3 AND deleted_at IS NULL`
,
scopeKey
,
raw
,
id
,
)
if
err
!=
nil
{
return
err
}
affected
,
err
:=
result
.
RowsAffected
()
if
err
!=
nil
{
return
err
}
if
affected
==
0
{
return
service
.
ErrAccountNotFound
}
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
r
.
sql
,
service
.
SchedulerOutboxEventAccountChanged
,
&
id
,
nil
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v"
,
id
,
err
)
}
return
nil
}
func
(
r
*
accountRepository
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
func
(
r
*
accountRepository
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
if
scope
==
""
{
if
scope
==
""
{
return
nil
return
nil
...
...
backend/internal/repository/api_key_repo.go
View file @
d367d1cd
...
@@ -476,6 +476,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
...
@@ -476,6 +476,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ModelRoutingEnabled
:
g
.
ModelRoutingEnabled
,
ModelRoutingEnabled
:
g
.
ModelRoutingEnabled
,
MCPXMLInject
:
g
.
McpXMLInject
,
MCPXMLInject
:
g
.
McpXMLInject
,
SupportedModelScopes
:
g
.
SupportedModelScopes
,
SupportedModelScopes
:
g
.
SupportedModelScopes
,
SortOrder
:
g
.
SortOrder
,
CreatedAt
:
g
.
CreatedAt
,
CreatedAt
:
g
.
CreatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
}
}
...
...
backend/internal/repository/gateway_cache.go
View file @
d367d1cd
...
@@ -11,63 +11,6 @@ import (
...
@@ -11,63 +11,6 @@ import (
const
stickySessionPrefix
=
"sticky_session:"
const
stickySessionPrefix
=
"sticky_session:"
// Gemini Trie Lua 脚本
const
(
// geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本
// KEYS[1] = trie key
// ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d")
// ARGV[2] = TTL seconds (用于刷新)
// 返回: 最长匹配的 value (uuid:accountID) 或 nil
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
geminiTrieFindScript
=
`
local chain = ARGV[1]
local ttl = tonumber(ARGV[2])
local lastMatch = nil
local path = ""
for part in string.gmatch(chain, "[^-]+") do
path = path == "" and part or path .. "-" .. part
local val = redis.call('HGET', KEYS[1], path)
if val and val ~= "" then
lastMatch = val
end
end
if lastMatch then
redis.call('EXPIRE', KEYS[1], ttl)
end
return lastMatch
`
// geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本
// KEYS[1] = trie key
// ARGV[1] = digestChain
// ARGV[2] = value (uuid:accountID)
// ARGV[3] = TTL seconds
geminiTrieSaveScript
=
`
local chain = ARGV[1]
local value = ARGV[2]
local ttl = tonumber(ARGV[3])
local path = ""
for part in string.gmatch(chain, "[^-]+") do
path = path == "" and part or path .. "-" .. part
end
redis.call('HSET', KEYS[1], path, value)
redis.call('EXPIRE', KEYS[1], ttl)
return "OK"
`
)
// 模型负载统计相关常量
const
(
modelLoadKeyPrefix
=
"ag:model_load:"
// 模型调用次数 key 前缀
modelLastUsedKeyPrefix
=
"ag:model_last_used:"
// 模型最后调度时间 key 前缀
modelLoadTTL
=
24
*
time
.
Hour
// 调用次数 TTL(24 小时无调用后清零)
modelLastUsedTTL
=
24
*
time
.
Hour
// 最后调度时间 TTL
)
type
gatewayCache
struct
{
type
gatewayCache
struct
{
rdb
*
redis
.
Client
rdb
*
redis
.
Client
}
}
...
@@ -108,133 +51,3 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64
...
@@ -108,133 +51,3 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64
key
:=
buildSessionKey
(
groupID
,
sessionHash
)
key
:=
buildSessionKey
(
groupID
,
sessionHash
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
}
// ============ Antigravity 模型负载统计方法 ============
// modelLoadKey 构建模型调用次数 key
// 格式: ag:model_load:{accountID}:{model}
func
modelLoadKey
(
accountID
int64
,
model
string
)
string
{
return
fmt
.
Sprintf
(
"%s%d:%s"
,
modelLoadKeyPrefix
,
accountID
,
model
)
}
// modelLastUsedKey 构建模型最后调度时间 key
// 格式: ag:model_last_used:{accountID}:{model}
func
modelLastUsedKey
(
accountID
int64
,
model
string
)
string
{
return
fmt
.
Sprintf
(
"%s%d:%s"
,
modelLastUsedKeyPrefix
,
accountID
,
model
)
}
// IncrModelCallCount 增加模型调用次数并更新最后调度时间
// 返回更新后的调用次数
func
(
c
*
gatewayCache
)
IncrModelCallCount
(
ctx
context
.
Context
,
accountID
int64
,
model
string
)
(
int64
,
error
)
{
loadKey
:=
modelLoadKey
(
accountID
,
model
)
lastUsedKey
:=
modelLastUsedKey
(
accountID
,
model
)
pipe
:=
c
.
rdb
.
Pipeline
()
incrCmd
:=
pipe
.
Incr
(
ctx
,
loadKey
)
pipe
.
Expire
(
ctx
,
loadKey
,
modelLoadTTL
)
// 每次调用刷新 TTL
pipe
.
Set
(
ctx
,
lastUsedKey
,
time
.
Now
()
.
Unix
(),
modelLastUsedTTL
)
if
_
,
err
:=
pipe
.
Exec
(
ctx
);
err
!=
nil
{
return
0
,
err
}
return
incrCmd
.
Val
(),
nil
}
// GetModelLoadBatch 批量获取账号的模型负载信息
func
(
c
*
gatewayCache
)
GetModelLoadBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
model
string
)
(
map
[
int64
]
*
service
.
ModelLoadInfo
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
make
(
map
[
int64
]
*
service
.
ModelLoadInfo
),
nil
}
loadCmds
,
lastUsedCmds
:=
c
.
pipelineModelLoadGet
(
ctx
,
accountIDs
,
model
)
return
c
.
parseModelLoadResults
(
accountIDs
,
loadCmds
,
lastUsedCmds
),
nil
}
// pipelineModelLoadGet 批量获取模型负载的 Pipeline 操作
func
(
c
*
gatewayCache
)
pipelineModelLoadGet
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
model
string
,
)
(
map
[
int64
]
*
redis
.
StringCmd
,
map
[
int64
]
*
redis
.
StringCmd
)
{
pipe
:=
c
.
rdb
.
Pipeline
()
loadCmds
:=
make
(
map
[
int64
]
*
redis
.
StringCmd
,
len
(
accountIDs
))
lastUsedCmds
:=
make
(
map
[
int64
]
*
redis
.
StringCmd
,
len
(
accountIDs
))
for
_
,
id
:=
range
accountIDs
{
loadCmds
[
id
]
=
pipe
.
Get
(
ctx
,
modelLoadKey
(
id
,
model
))
lastUsedCmds
[
id
]
=
pipe
.
Get
(
ctx
,
modelLastUsedKey
(
id
,
model
))
}
_
,
_
=
pipe
.
Exec
(
ctx
)
// 忽略错误,key 不存在是正常的
return
loadCmds
,
lastUsedCmds
}
// parseModelLoadResults 解析 Pipeline 结果
func
(
c
*
gatewayCache
)
parseModelLoadResults
(
accountIDs
[]
int64
,
loadCmds
map
[
int64
]
*
redis
.
StringCmd
,
lastUsedCmds
map
[
int64
]
*
redis
.
StringCmd
,
)
map
[
int64
]
*
service
.
ModelLoadInfo
{
result
:=
make
(
map
[
int64
]
*
service
.
ModelLoadInfo
,
len
(
accountIDs
))
for
_
,
id
:=
range
accountIDs
{
result
[
id
]
=
&
service
.
ModelLoadInfo
{
CallCount
:
getInt64OrZero
(
loadCmds
[
id
]),
LastUsedAt
:
getTimeOrZero
(
lastUsedCmds
[
id
]),
}
}
return
result
}
// getInt64OrZero 从 StringCmd 获取 int64 值,失败返回 0
func
getInt64OrZero
(
cmd
*
redis
.
StringCmd
)
int64
{
val
,
_
:=
cmd
.
Int64
()
return
val
}
// getTimeOrZero 从 StringCmd 获取 time.Time,失败返回零值
func
getTimeOrZero
(
cmd
*
redis
.
StringCmd
)
time
.
Time
{
val
,
err
:=
cmd
.
Int64
()
if
err
!=
nil
{
return
time
.
Time
{}
}
return
time
.
Unix
(
val
,
0
)
}
// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============
// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询)
// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL
func
(
c
*
gatewayCache
)
FindGeminiSession
(
ctx
context
.
Context
,
groupID
int64
,
prefixHash
,
digestChain
string
)
(
uuid
string
,
accountID
int64
,
found
bool
)
{
if
digestChain
==
""
{
return
""
,
0
,
false
}
trieKey
:=
service
.
BuildGeminiTrieKey
(
groupID
,
prefixHash
)
ttlSeconds
:=
int
(
service
.
GeminiSessionTTL
()
.
Seconds
())
// 使用 Lua 脚本在 Redis 端执行 Trie 查找,O(L) 次 HGET,1 次网络往返
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
result
,
err
:=
c
.
rdb
.
Eval
(
ctx
,
geminiTrieFindScript
,
[]
string
{
trieKey
},
digestChain
,
ttlSeconds
)
.
Result
()
if
err
!=
nil
||
result
==
nil
{
return
""
,
0
,
false
}
value
,
ok
:=
result
.
(
string
)
if
!
ok
||
value
==
""
{
return
""
,
0
,
false
}
uuid
,
accountID
,
ok
=
service
.
ParseGeminiSessionValue
(
value
)
return
uuid
,
accountID
,
ok
}
// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本)
func
(
c
*
gatewayCache
)
SaveGeminiSession
(
ctx
context
.
Context
,
groupID
int64
,
prefixHash
,
digestChain
,
uuid
string
,
accountID
int64
)
error
{
if
digestChain
==
""
{
return
nil
}
trieKey
:=
service
.
BuildGeminiTrieKey
(
groupID
,
prefixHash
)
value
:=
service
.
FormatGeminiSessionValue
(
uuid
,
accountID
)
ttlSeconds
:=
int
(
service
.
GeminiSessionTTL
()
.
Seconds
())
return
c
.
rdb
.
Eval
(
ctx
,
geminiTrieSaveScript
,
[]
string
{
trieKey
},
digestChain
,
value
,
ttlSeconds
)
.
Err
()
}
backend/internal/repository/gateway_cache_integration_test.go
View file @
d367d1cd
...
@@ -104,157 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
...
@@ -104,157 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
require
.
False
(
s
.
T
(),
errors
.
Is
(
err
,
redis
.
Nil
),
"expected parsing error, not redis.Nil"
)
require
.
False
(
s
.
T
(),
errors
.
Is
(
err
,
redis
.
Nil
),
"expected parsing error, not redis.Nil"
)
}
}
// ============ Gemini Trie 会话测试 ============
func
(
s
*
GatewayCacheSuite
)
TestGeminiSessionTrie_SaveAndFind
()
{
groupID
:=
int64
(
1
)
prefixHash
:=
"testprefix"
digestChain
:=
"u:hash1-m:hash2-u:hash3"
uuid
:=
"test-uuid-123"
accountID
:=
int64
(
42
)
// 保存会话
err
:=
s
.
cache
.
SaveGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
digestChain
,
uuid
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
,
"SaveGeminiSession"
)
// 精确匹配查找
foundUUID
,
foundAccountID
,
found
:=
s
.
cache
.
FindGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
digestChain
)
require
.
True
(
s
.
T
(),
found
,
"should find exact match"
)
require
.
Equal
(
s
.
T
(),
uuid
,
foundUUID
)
require
.
Equal
(
s
.
T
(),
accountID
,
foundAccountID
)
}
func
(
s
*
GatewayCacheSuite
)
TestGeminiSessionTrie_PrefixMatch
()
{
groupID
:=
int64
(
1
)
prefixHash
:=
"prefixmatch"
shortChain
:=
"u:a-m:b"
longChain
:=
"u:a-m:b-u:c-m:d"
uuid
:=
"uuid-prefix"
accountID
:=
int64
(
100
)
// 保存短链
err
:=
s
.
cache
.
SaveGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
shortChain
,
uuid
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
// 用长链查找,应该匹配到短链(前缀匹配)
foundUUID
,
foundAccountID
,
found
:=
s
.
cache
.
FindGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
longChain
)
require
.
True
(
s
.
T
(),
found
,
"should find prefix match"
)
require
.
Equal
(
s
.
T
(),
uuid
,
foundUUID
)
require
.
Equal
(
s
.
T
(),
accountID
,
foundAccountID
)
}
func
(
s
*
GatewayCacheSuite
)
TestGeminiSessionTrie_LongestPrefixMatch
()
{
groupID
:=
int64
(
1
)
prefixHash
:=
"longestmatch"
// 保存多个不同长度的链
err
:=
s
.
cache
.
SaveGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
"u:a"
,
"uuid-short"
,
1
)
require
.
NoError
(
s
.
T
(),
err
)
err
=
s
.
cache
.
SaveGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
"u:a-m:b"
,
"uuid-medium"
,
2
)
require
.
NoError
(
s
.
T
(),
err
)
err
=
s
.
cache
.
SaveGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
"u:a-m:b-u:c"
,
"uuid-long"
,
3
)
require
.
NoError
(
s
.
T
(),
err
)
// 查找更长的链,应该匹配到最长的前缀
foundUUID
,
foundAccountID
,
found
:=
s
.
cache
.
FindGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
"u:a-m:b-u:c-m:d-u:e"
)
require
.
True
(
s
.
T
(),
found
,
"should find longest prefix match"
)
require
.
Equal
(
s
.
T
(),
"uuid-long"
,
foundUUID
)
require
.
Equal
(
s
.
T
(),
int64
(
3
),
foundAccountID
)
// 查找中等长度的链
foundUUID
,
foundAccountID
,
found
=
s
.
cache
.
FindGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
"u:a-m:b-u:x"
)
require
.
True
(
s
.
T
(),
found
)
require
.
Equal
(
s
.
T
(),
"uuid-medium"
,
foundUUID
)
require
.
Equal
(
s
.
T
(),
int64
(
2
),
foundAccountID
)
}
func
(
s
*
GatewayCacheSuite
)
TestGeminiSessionTrie_NoMatch
()
{
groupID
:=
int64
(
1
)
prefixHash
:=
"nomatch"
digestChain
:=
"u:a-m:b"
// 保存一个会话
err
:=
s
.
cache
.
SaveGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
digestChain
,
"uuid"
,
1
)
require
.
NoError
(
s
.
T
(),
err
)
// 用不同的链查找,应该找不到
_
,
_
,
found
:=
s
.
cache
.
FindGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
"u:x-m:y"
)
require
.
False
(
s
.
T
(),
found
,
"should not find non-matching chain"
)
}
func
(
s
*
GatewayCacheSuite
)
TestGeminiSessionTrie_DifferentPrefixHash
()
{
groupID
:=
int64
(
1
)
digestChain
:=
"u:a-m:b"
// 保存到 prefixHash1
err
:=
s
.
cache
.
SaveGeminiSession
(
s
.
ctx
,
groupID
,
"prefix1"
,
digestChain
,
"uuid1"
,
1
)
require
.
NoError
(
s
.
T
(),
err
)
// 用 prefixHash2 查找,应该找不到(不同用户/客户端隔离)
_
,
_
,
found
:=
s
.
cache
.
FindGeminiSession
(
s
.
ctx
,
groupID
,
"prefix2"
,
digestChain
)
require
.
False
(
s
.
T
(),
found
,
"different prefixHash should be isolated"
)
}
func
(
s
*
GatewayCacheSuite
)
TestGeminiSessionTrie_DifferentGroupID
()
{
prefixHash
:=
"sameprefix"
digestChain
:=
"u:a-m:b"
// 保存到 groupID 1
err
:=
s
.
cache
.
SaveGeminiSession
(
s
.
ctx
,
1
,
prefixHash
,
digestChain
,
"uuid1"
,
1
)
require
.
NoError
(
s
.
T
(),
err
)
// 用 groupID 2 查找,应该找不到(分组隔离)
_
,
_
,
found
:=
s
.
cache
.
FindGeminiSession
(
s
.
ctx
,
2
,
prefixHash
,
digestChain
)
require
.
False
(
s
.
T
(),
found
,
"different groupID should be isolated"
)
}
func
(
s
*
GatewayCacheSuite
)
TestGeminiSessionTrie_EmptyDigestChain
()
{
groupID
:=
int64
(
1
)
prefixHash
:=
"emptytest"
// 空链不应该保存
err
:=
s
.
cache
.
SaveGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
""
,
"uuid"
,
1
)
require
.
NoError
(
s
.
T
(),
err
,
"empty chain should not error"
)
// 空链查找应该返回 false
_
,
_
,
found
:=
s
.
cache
.
FindGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
""
)
require
.
False
(
s
.
T
(),
found
,
"empty chain should not match"
)
}
func
(
s
*
GatewayCacheSuite
)
TestGeminiSessionTrie_MultipleSessions
()
{
groupID
:=
int64
(
1
)
prefixHash
:=
"multisession"
// 保存多个不同会话(模拟 1000 个并发会话的场景)
sessions
:=
[]
struct
{
chain
string
uuid
string
accountID
int64
}{
{
"u:session1"
,
"uuid-1"
,
1
},
{
"u:session2-m:reply2"
,
"uuid-2"
,
2
},
{
"u:session3-m:reply3-u:msg3"
,
"uuid-3"
,
3
},
}
for
_
,
sess
:=
range
sessions
{
err
:=
s
.
cache
.
SaveGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
sess
.
chain
,
sess
.
uuid
,
sess
.
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
}
// 验证每个会话都能正确查找
for
_
,
sess
:=
range
sessions
{
foundUUID
,
foundAccountID
,
found
:=
s
.
cache
.
FindGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
sess
.
chain
)
require
.
True
(
s
.
T
(),
found
,
"should find session: %s"
,
sess
.
chain
)
require
.
Equal
(
s
.
T
(),
sess
.
uuid
,
foundUUID
)
require
.
Equal
(
s
.
T
(),
sess
.
accountID
,
foundAccountID
)
}
// 验证继续对话的场景
foundUUID
,
foundAccountID
,
found
:=
s
.
cache
.
FindGeminiSession
(
s
.
ctx
,
groupID
,
prefixHash
,
"u:session2-m:reply2-u:newmsg"
)
require
.
True
(
s
.
T
(),
found
)
require
.
Equal
(
s
.
T
(),
"uuid-2"
,
foundUUID
)
require
.
Equal
(
s
.
T
(),
int64
(
2
),
foundAccountID
)
}
func
TestGatewayCacheSuite
(
t
*
testing
.
T
)
{
func
TestGatewayCacheSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
GatewayCacheSuite
))
suite
.
Run
(
t
,
new
(
GatewayCacheSuite
))
...
...
backend/internal/repository/gateway_cache_model_load_integration_test.go
deleted
100644 → 0
View file @
d7011163
//go:build integration
package
repository
import
(
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// ============ Gateway Cache 模型负载统计集成测试 ============
type
GatewayCacheModelLoadSuite
struct
{
suite
.
Suite
}
func
TestGatewayCacheModelLoadSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
GatewayCacheModelLoadSuite
))
}
func
(
s
*
GatewayCacheModelLoadSuite
)
TestIncrModelCallCount_Basic
()
{
t
:=
s
.
T
()
rdb
:=
testRedis
(
t
)
cache
:=
&
gatewayCache
{
rdb
:
rdb
}
ctx
:=
context
.
Background
()
accountID
:=
int64
(
123
)
model
:=
"claude-sonnet-4-20250514"
// 首次调用应返回 1
count1
,
err
:=
cache
.
IncrModelCallCount
(
ctx
,
accountID
,
model
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
count1
)
// 第二次调用应返回 2
count2
,
err
:=
cache
.
IncrModelCallCount
(
ctx
,
accountID
,
model
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
2
),
count2
)
// 第三次调用应返回 3
count3
,
err
:=
cache
.
IncrModelCallCount
(
ctx
,
accountID
,
model
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
3
),
count3
)
}
func
(
s
*
GatewayCacheModelLoadSuite
)
TestIncrModelCallCount_DifferentModels
()
{
t
:=
s
.
T
()
rdb
:=
testRedis
(
t
)
cache
:=
&
gatewayCache
{
rdb
:
rdb
}
ctx
:=
context
.
Background
()
accountID
:=
int64
(
456
)
model1
:=
"claude-sonnet-4-20250514"
model2
:=
"claude-opus-4-5-20251101"
// 不同模型应该独立计数
count1
,
err
:=
cache
.
IncrModelCallCount
(
ctx
,
accountID
,
model1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
count1
)
count2
,
err
:=
cache
.
IncrModelCallCount
(
ctx
,
accountID
,
model2
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
count2
)
count1Again
,
err
:=
cache
.
IncrModelCallCount
(
ctx
,
accountID
,
model1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
2
),
count1Again
)
}
func
(
s
*
GatewayCacheModelLoadSuite
)
TestIncrModelCallCount_DifferentAccounts
()
{
t
:=
s
.
T
()
rdb
:=
testRedis
(
t
)
cache
:=
&
gatewayCache
{
rdb
:
rdb
}
ctx
:=
context
.
Background
()
account1
:=
int64
(
111
)
account2
:=
int64
(
222
)
model
:=
"gemini-2.5-pro"
// 不同账号应该独立计数
count1
,
err
:=
cache
.
IncrModelCallCount
(
ctx
,
account1
,
model
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
count1
)
count2
,
err
:=
cache
.
IncrModelCallCount
(
ctx
,
account2
,
model
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
count2
)
}
func
(
s
*
GatewayCacheModelLoadSuite
)
TestGetModelLoadBatch_Empty
()
{
t
:=
s
.
T
()
rdb
:=
testRedis
(
t
)
cache
:=
&
gatewayCache
{
rdb
:
rdb
}
ctx
:=
context
.
Background
()
result
,
err
:=
cache
.
GetModelLoadBatch
(
ctx
,
[]
int64
{},
"any-model"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Empty
(
t
,
result
)
}
func
(
s
*
GatewayCacheModelLoadSuite
)
TestGetModelLoadBatch_NonExistent
()
{
t
:=
s
.
T
()
rdb
:=
testRedis
(
t
)
cache
:=
&
gatewayCache
{
rdb
:
rdb
}
ctx
:=
context
.
Background
()
// 查询不存在的账号应返回零值
result
,
err
:=
cache
.
GetModelLoadBatch
(
ctx
,
[]
int64
{
9999
,
9998
},
"claude-sonnet-4-20250514"
)
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
result
,
2
)
require
.
Equal
(
t
,
int64
(
0
),
result
[
9999
]
.
CallCount
)
require
.
True
(
t
,
result
[
9999
]
.
LastUsedAt
.
IsZero
())
require
.
Equal
(
t
,
int64
(
0
),
result
[
9998
]
.
CallCount
)
require
.
True
(
t
,
result
[
9998
]
.
LastUsedAt
.
IsZero
())
}
func
(
s
*
GatewayCacheModelLoadSuite
)
TestGetModelLoadBatch_AfterIncrement
()
{
t
:=
s
.
T
()
rdb
:=
testRedis
(
t
)
cache
:=
&
gatewayCache
{
rdb
:
rdb
}
ctx
:=
context
.
Background
()
accountID
:=
int64
(
789
)
model
:=
"claude-sonnet-4-20250514"
// 先增加调用次数
beforeIncr
:=
time
.
Now
()
_
,
err
:=
cache
.
IncrModelCallCount
(
ctx
,
accountID
,
model
)
require
.
NoError
(
t
,
err
)
_
,
err
=
cache
.
IncrModelCallCount
(
ctx
,
accountID
,
model
)
require
.
NoError
(
t
,
err
)
_
,
err
=
cache
.
IncrModelCallCount
(
ctx
,
accountID
,
model
)
require
.
NoError
(
t
,
err
)
afterIncr
:=
time
.
Now
()
// 获取负载信息
result
,
err
:=
cache
.
GetModelLoadBatch
(
ctx
,
[]
int64
{
accountID
},
model
)
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
result
,
1
)
loadInfo
:=
result
[
accountID
]
require
.
NotNil
(
t
,
loadInfo
)
require
.
Equal
(
t
,
int64
(
3
),
loadInfo
.
CallCount
)
require
.
False
(
t
,
loadInfo
.
LastUsedAt
.
IsZero
())
// LastUsedAt 应该在 beforeIncr 和 afterIncr 之间
require
.
True
(
t
,
loadInfo
.
LastUsedAt
.
After
(
beforeIncr
.
Add
(
-
time
.
Second
))
||
loadInfo
.
LastUsedAt
.
Equal
(
beforeIncr
))
require
.
True
(
t
,
loadInfo
.
LastUsedAt
.
Before
(
afterIncr
.
Add
(
time
.
Second
))
||
loadInfo
.
LastUsedAt
.
Equal
(
afterIncr
))
}
func
(
s
*
GatewayCacheModelLoadSuite
)
TestGetModelLoadBatch_MultipleAccounts
()
{
t
:=
s
.
T
()
rdb
:=
testRedis
(
t
)
cache
:=
&
gatewayCache
{
rdb
:
rdb
}
ctx
:=
context
.
Background
()
model
:=
"claude-opus-4-5-20251101"
account1
:=
int64
(
1001
)
account2
:=
int64
(
1002
)
account3
:=
int64
(
1003
)
// 不调用
// account1 调用 2 次
_
,
err
:=
cache
.
IncrModelCallCount
(
ctx
,
account1
,
model
)
require
.
NoError
(
t
,
err
)
_
,
err
=
cache
.
IncrModelCallCount
(
ctx
,
account1
,
model
)
require
.
NoError
(
t
,
err
)
// account2 调用 5 次
for
i
:=
0
;
i
<
5
;
i
++
{
_
,
err
=
cache
.
IncrModelCallCount
(
ctx
,
account2
,
model
)
require
.
NoError
(
t
,
err
)
}
// 批量获取
result
,
err
:=
cache
.
GetModelLoadBatch
(
ctx
,
[]
int64
{
account1
,
account2
,
account3
},
model
)
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
result
,
3
)
require
.
Equal
(
t
,
int64
(
2
),
result
[
account1
]
.
CallCount
)
require
.
False
(
t
,
result
[
account1
]
.
LastUsedAt
.
IsZero
())
require
.
Equal
(
t
,
int64
(
5
),
result
[
account2
]
.
CallCount
)
require
.
False
(
t
,
result
[
account2
]
.
LastUsedAt
.
IsZero
())
require
.
Equal
(
t
,
int64
(
0
),
result
[
account3
]
.
CallCount
)
require
.
True
(
t
,
result
[
account3
]
.
LastUsedAt
.
IsZero
())
}
func
(
s
*
GatewayCacheModelLoadSuite
)
TestGetModelLoadBatch_ModelIsolation
()
{
t
:=
s
.
T
()
rdb
:=
testRedis
(
t
)
cache
:=
&
gatewayCache
{
rdb
:
rdb
}
ctx
:=
context
.
Background
()
accountID
:=
int64
(
2001
)
model1
:=
"claude-sonnet-4-20250514"
model2
:=
"gemini-2.5-pro"
// 对 model1 调用 3 次
for
i
:=
0
;
i
<
3
;
i
++
{
_
,
err
:=
cache
.
IncrModelCallCount
(
ctx
,
accountID
,
model1
)
require
.
NoError
(
t
,
err
)
}
// 获取 model1 的负载
result1
,
err
:=
cache
.
GetModelLoadBatch
(
ctx
,
[]
int64
{
accountID
},
model1
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
3
),
result1
[
accountID
]
.
CallCount
)
// 获取 model2 的负载(应该为 0)
result2
,
err
:=
cache
.
GetModelLoadBatch
(
ctx
,
[]
int64
{
accountID
},
model2
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
0
),
result2
[
accountID
]
.
CallCount
)
}
// ============ 辅助函数测试 ============
func
(
s
*
GatewayCacheModelLoadSuite
)
TestModelLoadKey_Format
()
{
t
:=
s
.
T
()
key
:=
modelLoadKey
(
123
,
"claude-sonnet-4"
)
require
.
Equal
(
t
,
"ag:model_load:123:claude-sonnet-4"
,
key
)
}
func
(
s
*
GatewayCacheModelLoadSuite
)
TestModelLastUsedKey_Format
()
{
t
:=
s
.
T
()
key
:=
modelLastUsedKey
(
456
,
"gemini-2.5-pro"
)
require
.
Equal
(
t
,
"ag:model_last_used:456:gemini-2.5-pro"
,
key
)
}
backend/internal/repository/group_repo.go
View file @
d367d1cd
...
@@ -199,7 +199,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
...
@@ -199,7 +199,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
groups
,
err
:=
q
.
groups
,
err
:=
q
.
Offset
(
params
.
Offset
())
.
Offset
(
params
.
Offset
())
.
Limit
(
params
.
Limit
())
.
Limit
(
params
.
Limit
())
.
Order
(
dbent
.
Asc
(
group
.
FieldID
))
.
Order
(
dbent
.
Asc
(
group
.
FieldSortOrder
),
dbent
.
Asc
(
group
.
FieldID
))
.
All
(
ctx
)
All
(
ctx
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
nil
,
err
return
nil
,
nil
,
err
...
@@ -226,7 +226,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
...
@@ -226,7 +226,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
func
(
r
*
groupRepository
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Group
,
error
)
{
func
(
r
*
groupRepository
)
ListActive
(
ctx
context
.
Context
)
([]
service
.
Group
,
error
)
{
groups
,
err
:=
r
.
client
.
Group
.
Query
()
.
groups
,
err
:=
r
.
client
.
Group
.
Query
()
.
Where
(
group
.
StatusEQ
(
service
.
StatusActive
))
.
Where
(
group
.
StatusEQ
(
service
.
StatusActive
))
.
Order
(
dbent
.
Asc
(
group
.
FieldID
))
.
Order
(
dbent
.
Asc
(
group
.
FieldSortOrder
),
dbent
.
Asc
(
group
.
FieldID
))
.
All
(
ctx
)
All
(
ctx
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
...
@@ -253,7 +253,7 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro
...
@@ -253,7 +253,7 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro
func
(
r
*
groupRepository
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Group
,
error
)
{
func
(
r
*
groupRepository
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
service
.
Group
,
error
)
{
groups
,
err
:=
r
.
client
.
Group
.
Query
()
.
groups
,
err
:=
r
.
client
.
Group
.
Query
()
.
Where
(
group
.
StatusEQ
(
service
.
StatusActive
),
group
.
PlatformEQ
(
platform
))
.
Where
(
group
.
StatusEQ
(
service
.
StatusActive
),
group
.
PlatformEQ
(
platform
))
.
Order
(
dbent
.
Asc
(
group
.
FieldID
))
.
Order
(
dbent
.
Asc
(
group
.
FieldSortOrder
),
dbent
.
Asc
(
group
.
FieldID
))
.
All
(
ctx
)
All
(
ctx
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
...
@@ -505,3 +505,29 @@ func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64
...
@@ -505,3 +505,29 @@ func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64
return
nil
return
nil
}
}
// UpdateSortOrders 批量更新分组排序
func
(
r
*
groupRepository
)
UpdateSortOrders
(
ctx
context
.
Context
,
updates
[]
service
.
GroupSortOrderUpdate
)
error
{
if
len
(
updates
)
==
0
{
return
nil
}
// 使用事务批量更新
tx
,
err
:=
r
.
client
.
Tx
(
ctx
)
if
err
!=
nil
{
return
err
}
defer
func
()
{
_
=
tx
.
Rollback
()
}()
for
_
,
u
:=
range
updates
{
if
_
,
err
:=
tx
.
Group
.
UpdateOneID
(
u
.
ID
)
.
SetSortOrder
(
u
.
SortOrder
)
.
Save
(
ctx
);
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
nil
)
}
}
if
err
:=
tx
.
Commit
();
err
!=
nil
{
return
err
}
return
nil
}
backend/internal/server/api_contract_test.go
View file @
d367d1cd
...
@@ -901,6 +901,10 @@ func (stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int
...
@@ -901,6 +901,10 @@ func (stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int
return
nil
,
errors
.
New
(
"not implemented"
)
return
nil
,
errors
.
New
(
"not implemented"
)
}
}
func
(
stubGroupRepo
)
UpdateSortOrders
(
ctx
context
.
Context
,
updates
[]
service
.
GroupSortOrderUpdate
)
error
{
return
nil
}
type
stubAccountRepo
struct
{
type
stubAccountRepo
struct
{
bulkUpdateIDs
[]
int64
bulkUpdateIDs
[]
int64
}
}
...
@@ -1013,10 +1017,6 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt
...
@@ -1013,10 +1017,6 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt
return
errors
.
New
(
"not implemented"
)
return
errors
.
New
(
"not implemented"
)
}
}
func
(
s
*
stubAccountRepo
)
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
service
.
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
s
*
stubAccountRepo
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
func
(
s
*
stubAccountRepo
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
return
errors
.
New
(
"not implemented"
)
return
errors
.
New
(
"not implemented"
)
}
}
...
@@ -1058,6 +1058,10 @@ func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s
...
@@ -1058,6 +1058,10 @@ func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s
return
int64
(
len
(
ids
)),
nil
return
int64
(
len
(
ids
)),
nil
}
}
func
(
s
*
stubAccountRepo
)
ListCRSAccountIDs
(
ctx
context
.
Context
)
(
map
[
string
]
int64
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
type
stubProxyRepo
struct
{}
type
stubProxyRepo
struct
{}
func
(
stubProxyRepo
)
Create
(
ctx
context
.
Context
,
proxy
*
service
.
Proxy
)
error
{
func
(
stubProxyRepo
)
Create
(
ctx
context
.
Context
,
proxy
*
service
.
Proxy
)
error
{
...
...
backend/internal/server/routes/admin.go
View file @
d367d1cd
...
@@ -192,6 +192,7 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
...
@@ -192,6 +192,7 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
{
groups
.
GET
(
""
,
h
.
Admin
.
Group
.
List
)
groups
.
GET
(
""
,
h
.
Admin
.
Group
.
List
)
groups
.
GET
(
"/all"
,
h
.
Admin
.
Group
.
GetAll
)
groups
.
GET
(
"/all"
,
h
.
Admin
.
Group
.
GetAll
)
groups
.
PUT
(
"/sort-order"
,
h
.
Admin
.
Group
.
UpdateSortOrder
)
groups
.
GET
(
"/:id"
,
h
.
Admin
.
Group
.
GetByID
)
groups
.
GET
(
"/:id"
,
h
.
Admin
.
Group
.
GetByID
)
groups
.
POST
(
""
,
h
.
Admin
.
Group
.
Create
)
groups
.
POST
(
""
,
h
.
Admin
.
Group
.
Create
)
groups
.
PUT
(
"/:id"
,
h
.
Admin
.
Group
.
Update
)
groups
.
PUT
(
"/:id"
,
h
.
Admin
.
Group
.
Update
)
...
@@ -208,6 +209,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
...
@@ -208,6 +209,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts
.
GET
(
"/:id"
,
h
.
Admin
.
Account
.
GetByID
)
accounts
.
GET
(
"/:id"
,
h
.
Admin
.
Account
.
GetByID
)
accounts
.
POST
(
""
,
h
.
Admin
.
Account
.
Create
)
accounts
.
POST
(
""
,
h
.
Admin
.
Account
.
Create
)
accounts
.
POST
(
"/sync/crs"
,
h
.
Admin
.
Account
.
SyncFromCRS
)
accounts
.
POST
(
"/sync/crs"
,
h
.
Admin
.
Account
.
SyncFromCRS
)
accounts
.
POST
(
"/sync/crs/preview"
,
h
.
Admin
.
Account
.
PreviewFromCRS
)
accounts
.
PUT
(
"/:id"
,
h
.
Admin
.
Account
.
Update
)
accounts
.
PUT
(
"/:id"
,
h
.
Admin
.
Account
.
Update
)
accounts
.
DELETE
(
"/:id"
,
h
.
Admin
.
Account
.
Delete
)
accounts
.
DELETE
(
"/:id"
,
h
.
Admin
.
Account
.
Delete
)
accounts
.
POST
(
"/:id/test"
,
h
.
Admin
.
Account
.
Test
)
accounts
.
POST
(
"/:id/test"
,
h
.
Admin
.
Account
.
Test
)
...
...
backend/internal/service/account.go
View file @
d367d1cd
...
@@ -425,6 +425,22 @@ func (a *Account) GetBaseURL() string {
...
@@ -425,6 +425,22 @@ func (a *Account) GetBaseURL() string {
if
baseURL
==
""
{
if
baseURL
==
""
{
return
"https://api.anthropic.com"
return
"https://api.anthropic.com"
}
}
if
a
.
Platform
==
PlatformAntigravity
{
return
strings
.
TrimRight
(
baseURL
,
"/"
)
+
"/antigravity"
}
return
baseURL
}
// GetGeminiBaseURL 返回 Gemini 兼容端点的 base URL。
// Antigravity 平台的 APIKey 账号自动拼接 /antigravity。
func
(
a
*
Account
)
GetGeminiBaseURL
(
defaultBaseURL
string
)
string
{
baseURL
:=
strings
.
TrimSpace
(
a
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
return
defaultBaseURL
}
if
a
.
Platform
==
PlatformAntigravity
&&
a
.
Type
==
AccountTypeAPIKey
{
return
strings
.
TrimRight
(
baseURL
,
"/"
)
+
"/antigravity"
}
return
baseURL
return
baseURL
}
}
...
...
backend/internal/service/account_base_url_test.go
0 → 100644
View file @
d367d1cd
//go:build unit
package
service
import
(
"testing"
)
func
TestGetBaseURL
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
Account
expected
string
}{
{
name
:
"non-apikey type returns empty"
,
account
:
Account
{
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAnthropic
,
},
expected
:
""
,
},
{
name
:
"apikey without base_url returns default anthropic"
,
account
:
Account
{
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformAnthropic
,
Credentials
:
map
[
string
]
any
{},
},
expected
:
"https://api.anthropic.com"
,
},
{
name
:
"apikey with custom base_url"
,
account
:
Account
{
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformAnthropic
,
Credentials
:
map
[
string
]
any
{
"base_url"
:
"https://custom.example.com"
},
},
expected
:
"https://custom.example.com"
,
},
{
name
:
"antigravity apikey auto-appends /antigravity"
,
account
:
Account
{
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{
"base_url"
:
"https://upstream.example.com"
},
},
expected
:
"https://upstream.example.com/antigravity"
,
},
{
name
:
"antigravity apikey trims trailing slash before appending"
,
account
:
Account
{
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{
"base_url"
:
"https://upstream.example.com/"
},
},
expected
:
"https://upstream.example.com/antigravity"
,
},
{
name
:
"antigravity non-apikey returns empty"
,
account
:
Account
{
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{
"base_url"
:
"https://upstream.example.com"
},
},
expected
:
""
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
tt
.
account
.
GetBaseURL
()
if
result
!=
tt
.
expected
{
t
.
Errorf
(
"GetBaseURL() = %q, want %q"
,
result
,
tt
.
expected
)
}
})
}
}
func
TestGetGeminiBaseURL
(
t
*
testing
.
T
)
{
const
defaultGeminiURL
=
"https://generativelanguage.googleapis.com"
tests
:=
[]
struct
{
name
string
account
Account
expected
string
}{
{
name
:
"apikey without base_url returns default"
,
account
:
Account
{
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{},
},
expected
:
defaultGeminiURL
,
},
{
name
:
"apikey with custom base_url"
,
account
:
Account
{
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"base_url"
:
"https://custom-gemini.example.com"
},
},
expected
:
"https://custom-gemini.example.com"
,
},
{
name
:
"antigravity apikey auto-appends /antigravity"
,
account
:
Account
{
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{
"base_url"
:
"https://upstream.example.com"
},
},
expected
:
"https://upstream.example.com/antigravity"
,
},
{
name
:
"antigravity apikey trims trailing slash"
,
account
:
Account
{
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{
"base_url"
:
"https://upstream.example.com/"
},
},
expected
:
"https://upstream.example.com/antigravity"
,
},
{
name
:
"antigravity oauth does NOT append /antigravity"
,
account
:
Account
{
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{
"base_url"
:
"https://upstream.example.com"
},
},
expected
:
"https://upstream.example.com"
,
},
{
name
:
"oauth without base_url returns default"
,
account
:
Account
{
Type
:
AccountTypeOAuth
,
Platform
:
PlatformAntigravity
,
Credentials
:
map
[
string
]
any
{},
},
expected
:
defaultGeminiURL
,
},
{
name
:
"nil credentials returns default"
,
account
:
Account
{
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
},
expected
:
defaultGeminiURL
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
tt
.
account
.
GetGeminiBaseURL
(
defaultGeminiURL
)
if
result
!=
tt
.
expected
{
t
.
Errorf
(
"GetGeminiBaseURL() = %q, want %q"
,
result
,
tt
.
expected
)
}
})
}
}
backend/internal/service/account_service.go
View file @
d367d1cd
...
@@ -28,6 +28,9 @@ type AccountRepository interface {
...
@@ -28,6 +28,9 @@ type AccountRepository interface {
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora')
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora')
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
FindByExtraField
(
ctx
context
.
Context
,
key
string
,
value
any
)
([]
Account
,
error
)
FindByExtraField
(
ctx
context
.
Context
,
key
string
,
value
any
)
([]
Account
,
error
)
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
// for all accounts that have been synced from CRS.
ListCRSAccountIDs
(
ctx
context
.
Context
)
(
map
[
string
]
int64
,
error
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
...
@@ -53,7 +56,6 @@ type AccountRepository interface {
...
@@ -53,7 +56,6 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
SetTempUnschedulable
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
,
reason
string
)
error
SetTempUnschedulable
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
,
reason
string
)
error
...
...
backend/internal/service/account_service_delete_test.go
View file @
d367d1cd
...
@@ -54,10 +54,14 @@ func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID st
...
@@ -54,10 +54,14 @@ func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID st
panic
(
"unexpected GetByCRSAccountID call"
)
panic
(
"unexpected GetByCRSAccountID call"
)
}
}
func
(
s
*
accountRepoStub
)
FindByExtraField
(
ctx
context
.
Context
,
key
string
,
value
interface
{}
)
([]
Account
,
error
)
{
func
(
s
*
accountRepoStub
)
FindByExtraField
(
ctx
context
.
Context
,
key
string
,
value
any
)
([]
Account
,
error
)
{
panic
(
"unexpected FindByExtraField call"
)
panic
(
"unexpected FindByExtraField call"
)
}
}
func
(
s
*
accountRepoStub
)
ListCRSAccountIDs
(
ctx
context
.
Context
)
(
map
[
string
]
int64
,
error
)
{
panic
(
"unexpected ListCRSAccountIDs call"
)
}
func
(
s
*
accountRepoStub
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
func
(
s
*
accountRepoStub
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
panic
(
"unexpected Update call"
)
panic
(
"unexpected Update call"
)
}
}
...
@@ -147,10 +151,6 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
...
@@ -147,10 +151,6 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
panic
(
"unexpected SetRateLimited call"
)
panic
(
"unexpected SetRateLimited call"
)
}
}
func
(
s
*
accountRepoStub
)
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
{
panic
(
"unexpected SetAntigravityQuotaScopeLimit call"
)
}
func
(
s
*
accountRepoStub
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
func
(
s
*
accountRepoStub
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
panic
(
"unexpected SetModelRateLimit call"
)
panic
(
"unexpected SetModelRateLimit call"
)
}
}
...
...
backend/internal/service/account_test_service.go
View file @
d367d1cd
...
@@ -250,7 +250,6 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
...
@@ -250,7 +250,6 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
// Set common headers
// Set common headers
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
req
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
req
.
Header
.
Set
(
"anthropic-beta"
,
claude
.
DefaultBetaHeader
)
// Apply Claude Code client headers
// Apply Claude Code client headers
for
key
,
value
:=
range
claude
.
DefaultHeaders
{
for
key
,
value
:=
range
claude
.
DefaultHeaders
{
...
@@ -259,8 +258,10 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
...
@@ -259,8 +258,10 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
// Set authentication header
// Set authentication header
if
useBearer
{
if
useBearer
{
req
.
Header
.
Set
(
"anthropic-beta"
,
claude
.
DefaultBetaHeader
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
authToken
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
authToken
)
}
else
{
}
else
{
req
.
Header
.
Set
(
"anthropic-beta"
,
claude
.
APIKeyBetaHeader
)
req
.
Header
.
Set
(
"x-api-key"
,
authToken
)
req
.
Header
.
Set
(
"x-api-key"
,
authToken
)
}
}
...
...
backend/internal/service/admin_service.go
View file @
d367d1cd
...
@@ -36,6 +36,7 @@ type AdminService interface {
...
@@ -36,6 +36,7 @@ type AdminService interface {
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
DeleteGroup
(
ctx
context
.
Context
,
id
int64
)
error
DeleteGroup
(
ctx
context
.
Context
,
id
int64
)
error
GetGroupAPIKeys
(
ctx
context
.
Context
,
groupID
int64
,
page
,
pageSize
int
)
([]
APIKey
,
int64
,
error
)
GetGroupAPIKeys
(
ctx
context
.
Context
,
groupID
int64
,
page
,
pageSize
int
)
([]
APIKey
,
int64
,
error
)
UpdateGroupSortOrders
(
ctx
context
.
Context
,
updates
[]
GroupSortOrderUpdate
)
error
// Account management
// Account management
ListAccounts
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
accountType
,
status
,
search
string
)
([]
Account
,
int64
,
error
)
ListAccounts
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
accountType
,
status
,
search
string
)
([]
Account
,
int64
,
error
)
...
@@ -1048,6 +1049,10 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
...
@@ -1048,6 +1049,10 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
return
keys
,
result
.
Total
,
nil
return
keys
,
result
.
Total
,
nil
}
}
func
(
s
*
adminServiceImpl
)
UpdateGroupSortOrders
(
ctx
context
.
Context
,
updates
[]
GroupSortOrderUpdate
)
error
{
return
s
.
groupRepo
.
UpdateSortOrders
(
ctx
,
updates
)
}
// Account management implementations
// Account management implementations
func
(
s
*
adminServiceImpl
)
ListAccounts
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
accountType
,
status
,
search
string
)
([]
Account
,
int64
,
error
)
{
func
(
s
*
adminServiceImpl
)
ListAccounts
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
accountType
,
status
,
search
string
)
([]
Account
,
int64
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
...
...
backend/internal/service/admin_service_delete_test.go
View file @
d367d1cd
...
@@ -172,6 +172,10 @@ func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []
...
@@ -172,6 +172,10 @@ func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
}
func
(
s
*
groupRepoStub
)
UpdateSortOrders
(
ctx
context
.
Context
,
updates
[]
GroupSortOrderUpdate
)
error
{
return
nil
}
type
proxyRepoStub
struct
{
type
proxyRepoStub
struct
{
deleteErr
error
deleteErr
error
countErr
error
countErr
error
...
...
backend/internal/service/admin_service_group_test.go
View file @
d367d1cd
...
@@ -116,6 +116,10 @@ func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []i
...
@@ -116,6 +116,10 @@ func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []i
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
}
func
(
s
*
groupRepoStubForAdmin
)
UpdateSortOrders
(
_
context
.
Context
,
_
[]
GroupSortOrderUpdate
)
error
{
return
nil
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func
TestAdminService_CreateGroup_WithImagePricing
(
t
*
testing
.
T
)
{
func
TestAdminService_CreateGroup_WithImagePricing
(
t
*
testing
.
T
)
{
repo
:=
&
groupRepoStubForAdmin
{}
repo
:=
&
groupRepoStubForAdmin
{}
...
@@ -395,6 +399,10 @@ func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Contex
...
@@ -395,6 +399,10 @@ func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Contex
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
panic
(
"unexpected GetAccountIDsByGroupIDs call"
)
}
}
func
(
s
*
groupRepoStubForFallbackCycle
)
UpdateSortOrders
(
_
context
.
Context
,
_
[]
GroupSortOrderUpdate
)
error
{
return
nil
}
type
groupRepoStubForInvalidRequestFallback
struct
{
type
groupRepoStubForInvalidRequestFallback
struct
{
groups
map
[
int64
]
*
Group
groups
map
[
int64
]
*
Group
created
*
Group
created
*
Group
...
@@ -466,6 +474,10 @@ func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.C
...
@@ -466,6 +474,10 @@ func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.C
panic
(
"unexpected BindAccountsToGroup call"
)
panic
(
"unexpected BindAccountsToGroup call"
)
}
}
func
(
s
*
groupRepoStubForInvalidRequestFallback
)
UpdateSortOrders
(
_
context
.
Context
,
_
[]
GroupSortOrderUpdate
)
error
{
return
nil
}
func
TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform
(
t
*
testing
.
T
)
{
func
TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform
(
t
*
testing
.
T
)
{
fallbackID
:=
int64
(
10
)
fallbackID
:=
int64
(
10
)
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
repo
:=
&
groupRepoStubForInvalidRequestFallback
{
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment