Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
de7ff902
Commit
de7ff902
authored
Feb 04, 2026
by
yangjianbo
Browse files
Merge branch 'main' into test
parents
317f26f0
dd96ada3
Changes
90
Show whitespace changes
Inline
Side-by-side
backend/internal/handler/admin/account_handler.go
View file @
de7ff902
...
...
@@ -84,7 +84,7 @@ type CreateAccountRequest struct {
Name
string
`json:"name" binding:"required"`
Notes
*
string
`json:"notes"`
Platform
string
`json:"platform" binding:"required"`
Type
string
`json:"type" binding:"required,oneof=oauth setup-token apikey"`
Type
string
`json:"type" binding:"required,oneof=oauth setup-token apikey
upstream
"`
Credentials
map
[
string
]
any
`json:"credentials" binding:"required"`
Extra
map
[
string
]
any
`json:"extra"`
ProxyID
*
int64
`json:"proxy_id"`
...
...
@@ -102,7 +102,7 @@ type CreateAccountRequest struct {
type
UpdateAccountRequest
struct
{
Name
string
`json:"name"`
Notes
*
string
`json:"notes"`
Type
string
`json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Type
string
`json:"type" binding:"omitempty,oneof=oauth setup-token apikey
upstream
"`
Credentials
map
[
string
]
any
`json:"credentials"`
Extra
map
[
string
]
any
`json:"extra"`
ProxyID
*
int64
`json:"proxy_id"`
...
...
backend/internal/handler/admin/group_handler.go
View file @
de7ff902
...
...
@@ -44,9 +44,13 @@ type CreateGroupRequest struct {
SoraVideoPricePerRequestHD
*
float64
`json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest
*
int64
`json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
`json:"model_routing"`
ModelRoutingEnabled
bool
`json:"model_routing_enabled"`
MCPXMLInject
*
bool
`json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
`json:"supported_model_scopes"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
}
...
...
@@ -73,9 +77,13 @@ type UpdateGroupRequest struct {
SoraVideoPricePerRequestHD
*
float64
`json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly
*
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest
*
int64
`json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
`json:"model_routing"`
ModelRoutingEnabled
*
bool
`json:"model_routing_enabled"`
MCPXMLInject
*
bool
`json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
*
[]
string
`json:"supported_model_scopes"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
}
...
...
@@ -185,8 +193,11 @@ func (h *GroupHandler) Create(c *gin.Context) {
SoraVideoPricePerRequestHD
:
req
.
SoraVideoPricePerRequestHD
,
ClaudeCodeOnly
:
req
.
ClaudeCodeOnly
,
FallbackGroupID
:
req
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
req
.
FallbackGroupIDOnInvalidRequest
,
ModelRouting
:
req
.
ModelRouting
,
ModelRoutingEnabled
:
req
.
ModelRoutingEnabled
,
MCPXMLInject
:
req
.
MCPXMLInject
,
SupportedModelScopes
:
req
.
SupportedModelScopes
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
})
if
err
!=
nil
{
...
...
@@ -232,8 +243,11 @@ func (h *GroupHandler) Update(c *gin.Context) {
SoraVideoPricePerRequestHD
:
req
.
SoraVideoPricePerRequestHD
,
ClaudeCodeOnly
:
req
.
ClaudeCodeOnly
,
FallbackGroupID
:
req
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
req
.
FallbackGroupIDOnInvalidRequest
,
ModelRouting
:
req
.
ModelRouting
,
ModelRoutingEnabled
:
req
.
ModelRoutingEnabled
,
MCPXMLInject
:
req
.
MCPXMLInject
,
SupportedModelScopes
:
req
.
SupportedModelScopes
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
})
if
err
!=
nil
{
...
...
backend/internal/handler/api_key_handler.go
View file @
de7ff902
...
...
@@ -3,6 +3,7 @@ package handler
import
(
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
...
...
@@ -32,6 +33,8 @@ type CreateAPIKeyRequest struct {
CustomKey
*
string
`json:"custom_key"`
// 可选的自定义key
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
Quota
*
float64
`json:"quota"`
// 配额限制 (USD)
ExpiresInDays
*
int
`json:"expires_in_days"`
// 过期天数
}
// UpdateAPIKeyRequest represents the update API key request payload
...
...
@@ -41,6 +44,9 @@ type UpdateAPIKeyRequest struct {
Status
string
`json:"status" binding:"omitempty,oneof=active inactive"`
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
Quota
*
float64
`json:"quota"`
// 配额限制 (USD), 0=无限制
ExpiresAt
*
string
`json:"expires_at"`
// 过期时间 (ISO 8601)
ResetQuota
*
bool
`json:"reset_quota"`
// 重置已用配额
}
// List handles listing user's API keys with pagination
...
...
@@ -119,6 +125,10 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
CustomKey
:
req
.
CustomKey
,
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
ExpiresInDays
:
req
.
ExpiresInDays
,
}
if
req
.
Quota
!=
nil
{
svcReq
.
Quota
=
*
req
.
Quota
}
key
,
err
:=
h
.
apiKeyService
.
Create
(
c
.
Request
.
Context
(),
subject
.
UserID
,
svcReq
)
if
err
!=
nil
{
...
...
@@ -153,6 +163,8 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
svcReq
:=
service
.
UpdateAPIKeyRequest
{
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
Quota
:
req
.
Quota
,
ResetQuota
:
req
.
ResetQuota
,
}
if
req
.
Name
!=
""
{
svcReq
.
Name
=
&
req
.
Name
...
...
@@ -161,6 +173,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
if
req
.
Status
!=
""
{
svcReq
.
Status
=
&
req
.
Status
}
// Parse expires_at if provided
if
req
.
ExpiresAt
!=
nil
{
if
*
req
.
ExpiresAt
==
""
{
// Empty string means clear expiration
svcReq
.
ExpiresAt
=
nil
svcReq
.
ClearExpiration
=
true
}
else
{
t
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
*
req
.
ExpiresAt
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid expires_at format: "
+
err
.
Error
())
return
}
svcReq
.
ExpiresAt
=
&
t
}
}
key
,
err
:=
h
.
apiKeyService
.
Update
(
c
.
Request
.
Context
(),
keyID
,
subject
.
UserID
,
svcReq
)
if
err
!=
nil
{
...
...
backend/internal/handler/dto/mappers.go
View file @
de7ff902
...
...
@@ -76,6 +76,9 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
Status
:
k
.
Status
,
IPWhitelist
:
k
.
IPWhitelist
,
IPBlacklist
:
k
.
IPBlacklist
,
Quota
:
k
.
Quota
,
QuotaUsed
:
k
.
QuotaUsed
,
ExpiresAt
:
k
.
ExpiresAt
,
CreatedAt
:
k
.
CreatedAt
,
UpdatedAt
:
k
.
UpdatedAt
,
User
:
UserFromServiceShallow
(
k
.
User
),
...
...
@@ -108,6 +111,8 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
Group
:
groupFromServiceBase
(
g
),
ModelRouting
:
g
.
ModelRouting
,
ModelRoutingEnabled
:
g
.
ModelRoutingEnabled
,
MCPXMLInject
:
g
.
MCPXMLInject
,
SupportedModelScopes
:
g
.
SupportedModelScopes
,
AccountCount
:
g
.
AccountCount
,
}
if
len
(
g
.
AccountGroups
)
>
0
{
...
...
@@ -142,6 +147,7 @@ func groupFromServiceBase(g *service.Group) Group {
SoraVideoPricePerRequestHD
:
g
.
SoraVideoPricePerRequestHD
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
FallbackGroupID
:
g
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
g
.
FallbackGroupIDOnInvalidRequest
,
CreatedAt
:
g
.
CreatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
}
...
...
backend/internal/handler/dto/types.go
View file @
de7ff902
...
...
@@ -40,6 +40,9 @@ type APIKey struct {
Status
string
`json:"status"`
IPWhitelist
[]
string
`json:"ip_whitelist"`
IPBlacklist
[]
string
`json:"ip_blacklist"`
Quota
float64
`json:"quota"`
// Quota limit in USD (0 = unlimited)
QuotaUsed
float64
`json:"quota_used"`
// Used quota amount in USD
ExpiresAt
*
time
.
Time
`json:"expires_at"`
// Expiration time (nil = never expires)
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
...
...
@@ -75,6 +78,8 @@ type Group struct {
// Claude Code 客户端限制
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id"`
// 无效请求兜底分组
FallbackGroupIDOnInvalidRequest
*
int64
`json:"fallback_group_id_on_invalid_request"`
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
...
...
@@ -89,6 +94,11 @@ type AdminGroup struct {
ModelRouting
map
[
string
][]
int64
`json:"model_routing"`
ModelRoutingEnabled
bool
`json:"model_routing_enabled"`
// MCP XML 协议注入(仅 antigravity 平台使用)
MCPXMLInject
bool
`json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
`json:"supported_model_scopes"`
AccountGroups
[]
AccountGroup
`json:"account_groups,omitempty"`
AccountCount
int64
`json:"account_count,omitempty"`
}
...
...
backend/internal/handler/gateway_handler.go
View file @
de7ff902
...
...
@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
...
...
@@ -31,6 +32,7 @@ type GatewayHandler struct {
userService
*
service
.
UserService
billingCacheService
*
service
.
BillingCacheService
usageService
*
service
.
UsageService
apiKeyService
*
service
.
APIKeyService
concurrencyHelper
*
ConcurrencyHelper
maxAccountSwitches
int
maxAccountSwitchesGemini
int
...
...
@@ -46,6 +48,7 @@ func NewGatewayHandler(
concurrencyService
*
service
.
ConcurrencyService
,
billingCacheService
*
service
.
BillingCacheService
,
usageService
*
service
.
UsageService
,
apiKeyService
*
service
.
APIKeyService
,
cfg
*
config
.
Config
,
)
*
GatewayHandler
{
pingInterval
:=
time
.
Duration
(
0
)
...
...
@@ -67,6 +70,7 @@ func NewGatewayHandler(
userService
:
userService
,
billingCacheService
:
billingCacheService
,
usageService
:
usageService
,
apiKeyService
:
apiKeyService
,
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatClaude
,
pingInterval
),
maxAccountSwitches
:
maxAccountSwitches
,
maxAccountSwitchesGemini
:
maxAccountSwitchesGemini
,
...
...
@@ -283,10 +287,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
requestCtx
:=
c
.
Request
.
Context
()
if
switchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
switchCount
)
}
if
account
.
Platform
==
service
.
PlatformAntigravity
{
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
c
.
R
equest
.
Context
()
,
c
,
account
,
reqModel
,
"generateContent"
,
reqStream
,
body
)
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
r
equest
Ctx
,
c
,
account
,
reqModel
,
"generateContent"
,
reqStream
,
body
)
}
else
{
result
,
err
=
h
.
geminiCompatService
.
Forward
(
c
.
R
equest
.
Context
()
,
c
,
account
,
body
)
result
,
err
=
h
.
geminiCompatService
.
Forward
(
r
equest
Ctx
,
c
,
account
,
body
)
}
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
...
...
@@ -325,6 +333,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Subscription
:
subscription
,
UserAgent
:
ua
,
IPAddress
:
clientIP
,
APIKeyService
:
h
.
apiKeyService
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
...
...
@@ -333,14 +342,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
currentAPIKey
:=
apiKey
currentSubscription
:=
subscription
var
fallbackGroupID
*
int64
if
apiKey
.
Group
!=
nil
{
fallbackGroupID
=
apiKey
.
Group
.
FallbackGroupIDOnInvalidRequest
}
fallbackUsed
:=
false
for
{
maxAccountSwitches
:=
h
.
maxAccountSwitches
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
lastFailoverStatus
:=
0
retryWithFallback
:=
false
for
{
// 选择支持该模型的账号
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
api
Key
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
parsedReq
.
MetadataUserID
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
currentAPI
Key
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
parsedReq
.
MetadataUserID
)
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
@@ -410,7 +429,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
api
Key
.
GroupID
,
sessionKey
,
account
.
ID
);
err
!=
nil
{
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
currentAPI
Key
.
GroupID
,
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
...
...
@@ -419,15 +438,54 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
requestCtx
:=
c
.
Request
.
Context
()
if
switchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
switchCount
)
}
if
account
.
Platform
==
service
.
PlatformAntigravity
{
result
,
err
=
h
.
antigravityGatewayService
.
Forward
(
c
.
R
equest
.
Context
()
,
c
,
account
,
body
)
result
,
err
=
h
.
antigravityGatewayService
.
Forward
(
r
equest
Ctx
,
c
,
account
,
body
)
}
else
{
result
,
err
=
h
.
gatewayService
.
Forward
(
c
.
R
equest
.
Context
()
,
c
,
account
,
parsedReq
)
result
,
err
=
h
.
gatewayService
.
Forward
(
r
equest
Ctx
,
c
,
account
,
parsedReq
)
}
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
err
!=
nil
{
var
promptTooLongErr
*
service
.
PromptTooLongError
if
errors
.
As
(
err
,
&
promptTooLongErr
)
{
log
.
Printf
(
"Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v"
,
currentAPIKey
.
GroupID
,
fallbackGroupID
,
fallbackUsed
)
if
!
fallbackUsed
&&
fallbackGroupID
!=
nil
&&
*
fallbackGroupID
>
0
{
fallbackGroup
,
err
:=
h
.
gatewayService
.
ResolveGroupByID
(
c
.
Request
.
Context
(),
*
fallbackGroupID
)
if
err
!=
nil
{
log
.
Printf
(
"Resolve fallback group failed: %v"
,
err
)
_
=
h
.
antigravityGatewayService
.
WriteMappedClaudeError
(
c
,
account
,
promptTooLongErr
.
StatusCode
,
promptTooLongErr
.
RequestID
,
promptTooLongErr
.
Body
)
return
}
if
fallbackGroup
.
Platform
!=
service
.
PlatformAnthropic
||
fallbackGroup
.
SubscriptionType
==
service
.
SubscriptionTypeSubscription
||
fallbackGroup
.
FallbackGroupIDOnInvalidRequest
!=
nil
{
log
.
Printf
(
"Fallback group invalid: group=%d platform=%s subscription=%s"
,
fallbackGroup
.
ID
,
fallbackGroup
.
Platform
,
fallbackGroup
.
SubscriptionType
)
_
=
h
.
antigravityGatewayService
.
WriteMappedClaudeError
(
c
,
account
,
promptTooLongErr
.
StatusCode
,
promptTooLongErr
.
RequestID
,
promptTooLongErr
.
Body
)
return
}
fallbackAPIKey
:=
cloneAPIKeyWithGroup
(
apiKey
,
fallbackGroup
)
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
fallbackAPIKey
.
User
,
fallbackAPIKey
,
fallbackGroup
,
nil
);
err
!=
nil
{
status
,
code
,
message
:=
billingErrorDetails
(
err
)
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
}
// 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
ForcePlatform
,
""
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
currentAPIKey
=
fallbackAPIKey
currentSubscription
=
nil
fallbackUsed
=
true
retryWithFallback
=
true
break
}
_
=
h
.
antigravityGatewayService
.
WriteMappedClaudeError
(
c
,
account
,
promptTooLongErr
.
StatusCode
,
promptTooLongErr
.
RequestID
,
promptTooLongErr
.
Body
)
return
}
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
...
...
@@ -455,18 +513,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer
cancel
()
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
Result
:
result
,
APIKey
:
api
Key
,
User
:
api
Key
.
User
,
APIKey
:
currentAPI
Key
,
User
:
currentAPI
Key
.
User
,
Account
:
usedAccount
,
Subscription
:
s
ubscription
,
Subscription
:
currentS
ubscription
,
UserAgent
:
ua
,
IPAddress
:
clientIP
,
APIKeyService
:
h
.
apiKeyService
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
}(
result
,
account
,
userAgent
,
clientIP
)
return
}
if
!
retryWithFallback
{
return
}
}
}
// Models handles listing available models
...
...
@@ -540,6 +603,17 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
})
}
func
cloneAPIKeyWithGroup
(
apiKey
*
service
.
APIKey
,
group
*
service
.
Group
)
*
service
.
APIKey
{
if
apiKey
==
nil
||
group
==
nil
{
return
apiKey
}
cloned
:=
*
apiKey
groupID
:=
group
.
ID
cloned
.
GroupID
=
&
groupID
cloned
.
Group
=
group
return
&
cloned
}
// Usage handles getting account balance and usage statistics for CC Switch integration
// GET /v1/usage
func
(
h
*
GatewayHandler
)
Usage
(
c
*
gin
.
Context
)
{
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
de7ff902
...
...
@@ -14,6 +14,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
...
...
@@ -335,10 +336,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 5) forward (根据平台分流)
var
result
*
service
.
ForwardResult
requestCtx
:=
c
.
Request
.
Context
()
if
switchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
switchCount
)
}
if
account
.
Platform
==
service
.
PlatformAntigravity
{
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
c
.
R
equest
.
Context
()
,
c
,
account
,
modelName
,
action
,
stream
,
body
)
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
r
equest
Ctx
,
c
,
account
,
modelName
,
action
,
stream
,
body
)
}
else
{
result
,
err
=
h
.
geminiCompatService
.
ForwardNative
(
c
.
R
equest
.
Context
()
,
c
,
account
,
modelName
,
action
,
stream
,
body
)
result
,
err
=
h
.
geminiCompatService
.
ForwardNative
(
r
equest
Ctx
,
c
,
account
,
modelName
,
action
,
stream
,
body
)
}
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
...
...
@@ -381,6 +386,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
IPAddress
:
ip
,
LongContextThreshold
:
200000
,
// Gemini 200K 阈值
LongContextMultiplier
:
2.0
,
// 超出部分双倍计费
APIKeyService
:
h
.
apiKeyService
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
de7ff902
...
...
@@ -24,6 +24,7 @@ import (
type
OpenAIGatewayHandler
struct
{
gatewayService
*
service
.
OpenAIGatewayService
billingCacheService
*
service
.
BillingCacheService
apiKeyService
*
service
.
APIKeyService
concurrencyHelper
*
ConcurrencyHelper
maxAccountSwitches
int
}
...
...
@@ -33,6 +34,7 @@ func NewOpenAIGatewayHandler(
gatewayService
*
service
.
OpenAIGatewayService
,
concurrencyService
*
service
.
ConcurrencyService
,
billingCacheService
*
service
.
BillingCacheService
,
apiKeyService
*
service
.
APIKeyService
,
cfg
*
config
.
Config
,
)
*
OpenAIGatewayHandler
{
pingInterval
:=
time
.
Duration
(
0
)
...
...
@@ -46,6 +48,7 @@ func NewOpenAIGatewayHandler(
return
&
OpenAIGatewayHandler
{
gatewayService
:
gatewayService
,
billingCacheService
:
billingCacheService
,
apiKeyService
:
apiKeyService
,
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatComment
,
pingInterval
),
maxAccountSwitches
:
maxAccountSwitches
,
}
...
...
@@ -306,6 +309,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
Subscription
:
subscription
,
UserAgent
:
ua
,
IPAddress
:
ip
,
APIKeyService
:
h
.
apiKeyService
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
...
...
backend/internal/pkg/antigravity/oauth.go
View file @
de7ff902
...
...
@@ -40,17 +40,48 @@ const (
// URL 可用性 TTL(不可用 URL 的恢复时间)
URLAvailabilityTTL
=
5
*
time
.
Minute
// Antigravity API 端点
antigravityProdBaseURL
=
"https://cloudcode-pa.googleapis.com"
antigravityDailyBaseURL
=
"https://daily-cloudcode-pa.sandbox.googleapis.com"
)
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
var
BaseURLs
=
[]
string
{
"https://cloudcode-pa.googleapis.com"
,
// prod (优先)
"https://daily-cloudcode-pa.sandbox.googleapis.com"
,
// daily sandbox (备用)
antigravityProdBaseURL
,
// prod (优先)
antigravityDailyBaseURL
,
// daily sandbox (备用)
}
// BaseURL 默认 URL(保持向后兼容)
var
BaseURL
=
BaseURLs
[
0
]
// ForwardBaseURLs 返回 API 转发用的 URL 顺序(daily 优先)
func
ForwardBaseURLs
()
[]
string
{
if
len
(
BaseURLs
)
==
0
{
return
nil
}
urls
:=
append
([]
string
(
nil
),
BaseURLs
...
)
dailyIndex
:=
-
1
for
i
,
url
:=
range
urls
{
if
url
==
antigravityDailyBaseURL
{
dailyIndex
=
i
break
}
}
if
dailyIndex
<=
0
{
return
urls
}
reordered
:=
make
([]
string
,
0
,
len
(
urls
))
reordered
=
append
(
reordered
,
urls
[
dailyIndex
])
for
i
,
url
:=
range
urls
{
if
i
==
dailyIndex
{
continue
}
reordered
=
append
(
reordered
,
url
)
}
return
reordered
}
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级)
type
URLAvailability
struct
{
mu
sync
.
RWMutex
...
...
@@ -100,22 +131,37 @@ func (u *URLAvailability) IsAvailable(url string) bool {
// GetAvailableURLs 返回可用的 URL 列表
// 最近成功的 URL 优先,其他按默认顺序
func
(
u
*
URLAvailability
)
GetAvailableURLs
()
[]
string
{
return
u
.
GetAvailableURLsWithBase
(
BaseURLs
)
}
// GetAvailableURLsWithBase 返回可用的 URL 列表(使用自定义顺序)
// 最近成功的 URL 优先,其他按传入顺序
func
(
u
*
URLAvailability
)
GetAvailableURLsWithBase
(
baseURLs
[]
string
)
[]
string
{
u
.
mu
.
RLock
()
defer
u
.
mu
.
RUnlock
()
now
:=
time
.
Now
()
result
:=
make
([]
string
,
0
,
len
(
B
aseURLs
))
result
:=
make
([]
string
,
0
,
len
(
b
aseURLs
))
// 如果有最近成功的 URL 且可用,放在最前面
if
u
.
lastSuccess
!=
""
{
found
:=
false
for
_
,
url
:=
range
baseURLs
{
if
url
==
u
.
lastSuccess
{
found
=
true
break
}
}
if
found
{
expiry
,
exists
:=
u
.
unavailable
[
u
.
lastSuccess
]
if
!
exists
||
now
.
After
(
expiry
)
{
result
=
append
(
result
,
u
.
lastSuccess
)
}
}
}
// 添加其他可用的 URL(按
默认
顺序)
for
_
,
url
:=
range
B
aseURLs
{
// 添加其他可用的 URL(按
传入
顺序)
for
_
,
url
:=
range
b
aseURLs
{
// 跳过已添加的 lastSuccess
if
url
==
u
.
lastSuccess
{
continue
...
...
backend/internal/pkg/antigravity/request_transformer.go
View file @
de7ff902
...
...
@@ -44,11 +44,13 @@ type TransformOptions struct {
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
// 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。
IdentityPatch
string
EnableMCPXML
bool
}
func
DefaultTransformOptions
()
TransformOptions
{
return
TransformOptions
{
EnableIdentityPatch
:
true
,
EnableMCPXML
:
true
,
}
}
...
...
@@ -257,8 +259,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
// 添加用户的 system prompt
parts
=
append
(
parts
,
userSystemParts
...
)
// 检测是否有 MCP 工具,如有则注入 XML 调用协议
if
hasMCPTools
(
tools
)
{
// 检测是否有 MCP 工具,如有
且启用了 MCP XML 注入
则注入 XML 调用协议
if
opts
.
EnableMCPXML
&&
hasMCPTools
(
tools
)
{
parts
=
append
(
parts
,
GeminiPart
{
Text
:
mcpXMLProtocol
})
}
...
...
@@ -312,7 +314,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
parts
=
append
([]
GeminiPart
{{
Text
:
"Thinking..."
,
Thought
:
true
,
ThoughtSignature
:
d
ummyThoughtSignature
,
ThoughtSignature
:
D
ummyThoughtSignature
,
}},
parts
...
)
}
}
...
...
@@ -330,9 +332,10 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
return
contents
,
strippedThinking
,
nil
}
//
d
ummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证
//
D
ummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const
dummyThoughtSignature
=
"skip_thought_signature_validator"
// 导出供跨包使用(如 gemini_native_signature_cleaner 跨账号修复)
const
DummyThoughtSignature
=
"skip_thought_signature_validator"
// buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
...
...
@@ -370,7 +373,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
// signature 处理:
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
if
block
.
Signature
!=
""
&&
(
allowDummyThought
||
block
.
Signature
!=
d
ummyThoughtSignature
)
{
if
block
.
Signature
!=
""
&&
(
allowDummyThought
||
block
.
Signature
!=
D
ummyThoughtSignature
)
{
part
.
ThoughtSignature
=
block
.
Signature
}
else
if
!
allowDummyThought
{
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
...
...
@@ -381,7 +384,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
continue
}
else
{
// Gemini 模型使用 dummy signature
part
.
ThoughtSignature
=
d
ummyThoughtSignature
part
.
ThoughtSignature
=
D
ummyThoughtSignature
}
parts
=
append
(
parts
,
part
)
...
...
@@ -411,10 +414,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
// tool_use 的 signature 处理:
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
if
block
.
Signature
!=
""
&&
(
allowDummyThought
||
block
.
Signature
!=
d
ummyThoughtSignature
)
{
if
block
.
Signature
!=
""
&&
(
allowDummyThought
||
block
.
Signature
!=
D
ummyThoughtSignature
)
{
part
.
ThoughtSignature
=
block
.
Signature
}
else
if
allowDummyThought
{
part
.
ThoughtSignature
=
d
ummyThoughtSignature
part
.
ThoughtSignature
=
D
ummyThoughtSignature
}
parts
=
append
(
parts
,
part
)
...
...
@@ -492,9 +495,23 @@ func parseToolResultContent(content json.RawMessage, isError bool) string {
}
// buildGenerationConfig 构建 generationConfig
const
(
defaultMaxOutputTokens
=
64000
maxOutputTokensUpperBound
=
65000
maxOutputTokensClaude
=
64000
)
func
maxOutputTokensLimit
(
model
string
)
int
{
if
strings
.
HasPrefix
(
model
,
"claude-"
)
{
return
maxOutputTokensClaude
}
return
maxOutputTokensUpperBound
}
func
buildGenerationConfig
(
req
*
ClaudeRequest
)
*
GeminiGenerationConfig
{
maxLimit
:=
maxOutputTokensLimit
(
req
.
Model
)
config
:=
&
GeminiGenerationConfig
{
MaxOutputTokens
:
64000
,
// 默认最大输出
MaxOutputTokens
:
defaultMaxOutputTokens
,
// 默认最大输出
StopSequences
:
DefaultStopSequences
,
}
...
...
@@ -518,6 +535,10 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
}
}
if
config
.
MaxOutputTokens
>
maxLimit
{
config
.
MaxOutputTokens
=
maxLimit
}
// 其他参数
if
req
.
Temperature
!=
nil
{
config
.
Temperature
=
req
.
Temperature
...
...
backend/internal/pkg/antigravity/request_transformer_test.go
View file @
de7ff902
...
...
@@ -86,7 +86,7 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
if
len
(
parts
)
!=
3
{
t
.
Fatalf
(
"expected 3 parts, got %d"
,
len
(
parts
))
}
if
!
parts
[
1
]
.
Thought
||
parts
[
1
]
.
ThoughtSignature
!=
d
ummyThoughtSignature
{
if
!
parts
[
1
]
.
Thought
||
parts
[
1
]
.
ThoughtSignature
!=
D
ummyThoughtSignature
{
t
.
Fatalf
(
"expected dummy thought signature, got thought=%v signature=%q"
,
parts
[
1
]
.
Thought
,
parts
[
1
]
.
ThoughtSignature
)
}
...
...
@@ -126,8 +126,8 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if
len
(
parts
)
!=
1
||
parts
[
0
]
.
FunctionCall
==
nil
{
t
.
Fatalf
(
"expected 1 functionCall part, got %+v"
,
parts
)
}
if
parts
[
0
]
.
ThoughtSignature
!=
d
ummyThoughtSignature
{
t
.
Fatalf
(
"expected dummy tool signature %q, got %q"
,
d
ummyThoughtSignature
,
parts
[
0
]
.
ThoughtSignature
)
if
parts
[
0
]
.
ThoughtSignature
!=
D
ummyThoughtSignature
{
t
.
Fatalf
(
"expected dummy tool signature %q, got %q"
,
D
ummyThoughtSignature
,
parts
[
0
]
.
ThoughtSignature
)
}
})
...
...
backend/internal/pkg/ctxkey/ctxkey.go
View file @
de7ff902
...
...
@@ -14,6 +14,9 @@ const (
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
RetryCount
Key
=
"ctx_retry_count"
// AccountSwitchCount 表示请求过程中发生的账号切换次数
AccountSwitchCount
Key
=
"ctx_account_switch_count"
// IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端
IsClaudeCodeClient
Key
=
"ctx_is_claude_code_client"
// Group 认证后的分组信息,由 API Key 认证中间件设置
...
...
backend/internal/repository/api_key_repo.go
View file @
de7ff902
...
...
@@ -33,7 +33,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
SetKey
(
key
.
Key
)
.
SetName
(
key
.
Name
)
.
SetStatus
(
key
.
Status
)
.
SetNillableGroupID
(
key
.
GroupID
)
SetNillableGroupID
(
key
.
GroupID
)
.
SetQuota
(
key
.
Quota
)
.
SetQuotaUsed
(
key
.
QuotaUsed
)
.
SetNillableExpiresAt
(
key
.
ExpiresAt
)
if
len
(
key
.
IPWhitelist
)
>
0
{
builder
.
SetIPWhitelist
(
key
.
IPWhitelist
)
...
...
@@ -110,6 +113,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
apikey
.
FieldStatus
,
apikey
.
FieldIPWhitelist
,
apikey
.
FieldIPBlacklist
,
apikey
.
FieldQuota
,
apikey
.
FieldQuotaUsed
,
apikey
.
FieldExpiresAt
,
)
.
WithUser
(
func
(
q
*
dbent
.
UserQuery
)
{
q
.
Select
(
...
...
@@ -140,8 +146,11 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group
.
FieldSoraVideoPricePerRequestHd
,
group
.
FieldClaudeCodeOnly
,
group
.
FieldFallbackGroupID
,
group
.
FieldFallbackGroupIDOnInvalidRequest
,
group
.
FieldModelRoutingEnabled
,
group
.
FieldModelRouting
,
group
.
FieldMcpXMLInject
,
group
.
FieldSupportedModelScopes
,
)
})
.
Only
(
ctx
)
...
...
@@ -165,6 +174,8 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
Where
(
apikey
.
IDEQ
(
key
.
ID
),
apikey
.
DeletedAtIsNil
())
.
SetName
(
key
.
Name
)
.
SetStatus
(
key
.
Status
)
.
SetQuota
(
key
.
Quota
)
.
SetQuotaUsed
(
key
.
QuotaUsed
)
.
SetUpdatedAt
(
now
)
if
key
.
GroupID
!=
nil
{
builder
.
SetGroupID
(
*
key
.
GroupID
)
...
...
@@ -172,6 +183,13 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder
.
ClearGroupID
()
}
// Expiration time
if
key
.
ExpiresAt
!=
nil
{
builder
.
SetExpiresAt
(
*
key
.
ExpiresAt
)
}
else
{
builder
.
ClearExpiresAt
()
}
// IP 限制字段
if
len
(
key
.
IPWhitelist
)
>
0
{
builder
.
SetIPWhitelist
(
key
.
IPWhitelist
)
...
...
@@ -361,6 +379,38 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
return
keys
,
nil
}
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
func
(
r
*
apiKeyRepository
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
// Use raw SQL for atomic increment to avoid race conditions
// First get current value
m
,
err
:=
r
.
activeQuery
()
.
Where
(
apikey
.
IDEQ
(
id
))
.
Select
(
apikey
.
FieldQuotaUsed
)
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
0
,
service
.
ErrAPIKeyNotFound
}
return
0
,
err
}
newValue
:=
m
.
QuotaUsed
+
amount
// Update with new value
affected
,
err
:=
r
.
client
.
APIKey
.
Update
()
.
Where
(
apikey
.
IDEQ
(
id
),
apikey
.
DeletedAtIsNil
())
.
SetQuotaUsed
(
newValue
)
.
Save
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
if
affected
==
0
{
return
0
,
service
.
ErrAPIKeyNotFound
}
return
newValue
,
nil
}
func
apiKeyEntityToService
(
m
*
dbent
.
APIKey
)
*
service
.
APIKey
{
if
m
==
nil
{
return
nil
...
...
@@ -376,6 +426,9 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
CreatedAt
:
m
.
CreatedAt
,
UpdatedAt
:
m
.
UpdatedAt
,
GroupID
:
m
.
GroupID
,
Quota
:
m
.
Quota
,
QuotaUsed
:
m
.
QuotaUsed
,
ExpiresAt
:
m
.
ExpiresAt
,
}
if
m
.
Edges
.
User
!=
nil
{
out
.
User
=
userEntityToService
(
m
.
Edges
.
User
)
...
...
@@ -435,8 +488,11 @@ func groupEntityToService(g *dbent.Group) *service.Group {
DefaultValidityDays
:
g
.
DefaultValidityDays
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
FallbackGroupID
:
g
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
g
.
FallbackGroupIDOnInvalidRequest
,
ModelRouting
:
g
.
ModelRouting
,
ModelRoutingEnabled
:
g
.
ModelRoutingEnabled
,
MCPXMLInject
:
g
.
McpXMLInject
,
SupportedModelScopes
:
g
.
SupportedModelScopes
,
CreatedAt
:
g
.
CreatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
}
...
...
backend/internal/repository/group_repo.go
View file @
de7ff902
...
...
@@ -54,13 +54,18 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
.
SetNillableFallbackGroupID
(
groupIn
.
FallbackGroupID
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
SetNillableFallbackGroupIDOnInvalidRequest
(
groupIn
.
FallbackGroupIDOnInvalidRequest
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
.
SetMcpXMLInject
(
groupIn
.
MCPXMLInject
)
// 设置模型路由配置
if
groupIn
.
ModelRouting
!=
nil
{
builder
=
builder
.
SetModelRouting
(
groupIn
.
ModelRouting
)
}
// 设置支持的模型系列(始终设置,空数组表示不限制)
builder
=
builder
.
SetSupportedModelScopes
(
groupIn
.
SupportedModelScopes
)
created
,
err
:=
builder
.
Save
(
ctx
)
if
err
==
nil
{
groupIn
.
ID
=
created
.
ID
...
...
@@ -91,7 +96,6 @@ func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.G
if
err
!=
nil
{
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
nil
)
}
return
groupEntityToService
(
m
),
nil
}
...
...
@@ -116,7 +120,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableSoraVideoPricePerRequestHd
(
groupIn
.
SoraVideoPricePerRequestHD
)
.
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
.
SetMcpXMLInject
(
groupIn
.
MCPXMLInject
)
// 处理 FallbackGroupID:nil 时清除,否则设置
if
groupIn
.
FallbackGroupID
!=
nil
{
...
...
@@ -124,6 +129,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
}
else
{
builder
=
builder
.
ClearFallbackGroupID
()
}
// 处理 FallbackGroupIDOnInvalidRequest:nil 时清除,否则设置
if
groupIn
.
FallbackGroupIDOnInvalidRequest
!=
nil
{
builder
=
builder
.
SetFallbackGroupIDOnInvalidRequest
(
*
groupIn
.
FallbackGroupIDOnInvalidRequest
)
}
else
{
builder
=
builder
.
ClearFallbackGroupIDOnInvalidRequest
()
}
// 处理 ModelRouting:nil 时清除,否则设置
if
groupIn
.
ModelRouting
!=
nil
{
...
...
@@ -132,6 +143,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
builder
=
builder
.
ClearModelRouting
()
}
// 处理 SupportedModelScopes(始终设置,空数组表示不限制)
builder
=
builder
.
SetSupportedModelScopes
(
groupIn
.
SupportedModelScopes
)
updated
,
err
:=
builder
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
service
.
ErrGroupExists
)
...
...
backend/internal/repository/ops_repo_metrics.go
View file @
de7ff902
...
...
@@ -43,6 +43,7 @@ INSERT INTO ops_system_metrics (
upstream_529_count,
token_consumed,
account_switch_count,
qps,
tps,
...
...
@@ -81,14 +82,14 @@ INSERT INTO ops_system_metrics (
$1,$2,$3,$4,
$5,$6,$7,$8,
$9,$10,$11,
$12,$13,$14,
$15,
$16,$17,$18,$19,$20,
$21,
$22,$23,$24,$25,$26,
$27,
$28,$29,$30,
$3
1
,$3
2
,
$3
3
,$3
4
,
$35,
$36,$37,
$3
8
,$
39
$12,$13,$14,
$15,
$16,$17,$18,$19,$20,
$21,
$22,$23,$24,$25,$26,
$27,
$28,$29,$30,
$31,
$3
2
,$3
3
,
$3
4
,$3
5
,
$36,$37,
$38,
$3
9
,$
40
)`
_
,
err
:=
r
.
db
.
ExecContext
(
...
...
@@ -109,6 +110,7 @@ INSERT INTO ops_system_metrics (
input
.
Upstream529Count
,
input
.
TokenConsumed
,
input
.
AccountSwitchCount
,
opsNullFloat64
(
input
.
QPS
),
opsNullFloat64
(
input
.
TPS
),
...
...
@@ -177,7 +179,8 @@ SELECT
db_conn_waiting,
goroutine_count,
concurrency_queue_depth
concurrency_queue_depth,
account_switch_count
FROM ops_system_metrics
WHERE window_minutes = $1
AND platform IS NULL
...
...
@@ -199,6 +202,7 @@ LIMIT 1`
var
dbWaiting
sql
.
NullInt64
var
goroutines
sql
.
NullInt64
var
queueDepth
sql
.
NullInt64
var
accountSwitchCount
sql
.
NullInt64
if
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
q
,
windowMinutes
)
.
Scan
(
&
out
.
ID
,
...
...
@@ -217,6 +221,7 @@ LIMIT 1`
&
dbWaiting
,
&
goroutines
,
&
queueDepth
,
&
accountSwitchCount
,
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -273,6 +278,10 @@ LIMIT 1`
v
:=
int
(
queueDepth
.
Int64
)
out
.
ConcurrencyQueueDepth
=
&
v
}
if
accountSwitchCount
.
Valid
{
v
:=
accountSwitchCount
.
Int64
out
.
AccountSwitchCount
=
&
v
}
return
&
out
,
nil
}
...
...
backend/internal/repository/ops_repo_trends.go
View file @
de7ff902
...
...
@@ -56,18 +56,44 @@ error_buckets AS (
AND COALESCE(status_code, 0) >= 400
GROUP BY 1
),
switch_buckets AS (
SELECT `
+
errorBucketExpr
+
` AS bucket,
COALESCE(SUM(CASE
WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1
ELSE 0
END), 0) AS switch_count
FROM ops_error_logs
CROSS JOIN LATERAL jsonb_array_elements(
COALESCE(NULLIF(upstream_errors, 'null'::jsonb), '[]'::jsonb)
) AS ev
`
+
errorWhere
+
`
AND upstream_errors IS NOT NULL
GROUP BY 1
),
combined AS (
SELECT COALESCE(u.bucket, e.bucket) AS bucket,
COALESCE(u.success_count, 0) AS success_count,
COALESCE(e.error_count, 0) AS error_count,
COALESCE(u.token_consumed, 0) AS token_consumed
FROM usage_buckets u
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
SELECT
bucket,
SUM(success_count) AS success_count,
SUM(error_count) AS error_count,
SUM(token_consumed) AS token_consumed,
SUM(switch_count) AS switch_count
FROM (
SELECT bucket, success_count, 0 AS error_count, token_consumed, 0 AS switch_count
FROM usage_buckets
UNION ALL
SELECT bucket, 0, error_count, 0, 0
FROM error_buckets
UNION ALL
SELECT bucket, 0, 0, 0, switch_count
FROM switch_buckets
) t
GROUP BY bucket
)
SELECT
bucket,
(success_count + error_count) AS request_count,
token_consumed
token_consumed,
switch_count
FROM combined
ORDER BY bucket ASC`
...
...
@@ -84,13 +110,18 @@ ORDER BY bucket ASC`
var
bucket
time
.
Time
var
requests
int64
var
tokens
sql
.
NullInt64
if
err
:=
rows
.
Scan
(
&
bucket
,
&
requests
,
&
tokens
);
err
!=
nil
{
var
switches
sql
.
NullInt64
if
err
:=
rows
.
Scan
(
&
bucket
,
&
requests
,
&
tokens
,
&
switches
);
err
!=
nil
{
return
nil
,
err
}
tokenConsumed
:=
int64
(
0
)
if
tokens
.
Valid
{
tokenConsumed
=
tokens
.
Int64
}
switchCount
:=
int64
(
0
)
if
switches
.
Valid
{
switchCount
=
switches
.
Int64
}
denom
:=
float64
(
bucketSeconds
)
if
denom
<=
0
{
...
...
@@ -103,6 +134,7 @@ ORDER BY bucket ASC`
BucketStart
:
bucket
.
UTC
(),
RequestCount
:
requests
,
TokenConsumed
:
tokenConsumed
,
SwitchCount
:
switchCount
,
QPS
:
qps
,
TPS
:
tps
,
})
...
...
@@ -385,6 +417,7 @@ func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points []
BucketStart
:
cursor
,
RequestCount
:
0
,
TokenConsumed
:
0
,
SwitchCount
:
0
,
QPS
:
0
,
TPS
:
0
,
})
...
...
backend/internal/server/api_contract_test.go
View file @
de7ff902
...
...
@@ -83,6 +83,9 @@ func TestAPIContracts(t *testing.T) {
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"quota": 0,
"quota_used": 0,
"expires_at": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
...
...
@@ -119,6 +122,9 @@ func TestAPIContracts(t *testing.T) {
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"quota": 0,
"quota_used": 0,
"expires_at": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
...
...
@@ -184,6 +190,7 @@ func TestAPIContracts(t *testing.T) {
"sora_video_price_per_request_hd": null,
"claude_code_only": false,
"fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
...
...
@@ -1451,6 +1458,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) (
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
type
stubUsageLogRepo
struct
{
userLogs
map
[
int64
][]
service
.
UsageLog
}
...
...
backend/internal/server/middleware/api_key_auth.go
View file @
de7ff902
...
...
@@ -70,7 +70,27 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 检查API key是否激活
if
!
apiKey
.
IsActive
()
{
// Provide more specific error message based on status
switch
apiKey
.
Status
{
case
service
.
StatusAPIKeyQuotaExhausted
:
AbortWithError
(
c
,
429
,
"API_KEY_QUOTA_EXHAUSTED"
,
"API key 额度已用完"
)
case
service
.
StatusAPIKeyExpired
:
AbortWithError
(
c
,
403
,
"API_KEY_EXPIRED"
,
"API key 已过期"
)
default
:
AbortWithError
(
c
,
401
,
"API_KEY_DISABLED"
,
"API key is disabled"
)
}
return
}
// 检查API Key是否过期(即使状态是active,也要检查时间)
if
apiKey
.
IsExpired
()
{
AbortWithError
(
c
,
403
,
"API_KEY_EXPIRED"
,
"API key 已过期"
)
return
}
// 检查API Key配额是否耗尽
if
apiKey
.
IsQuotaExhausted
()
{
AbortWithError
(
c
,
429
,
"API_KEY_QUOTA_EXHAUSTED"
,
"API key 额度已用完"
)
return
}
...
...
backend/internal/server/middleware/api_key_auth_google.go
View file @
de7ff902
...
...
@@ -26,7 +26,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
abortWithGoogleError
(
c
,
400
,
"Query parameter api_key is deprecated. Use Authorization header or key instead."
)
return
}
apiKeyString
:=
extractAPIKeyF
romRequest
(
c
)
apiKeyString
:=
extractAPIKeyF
orGoogle
(
c
)
if
apiKeyString
==
""
{
abortWithGoogleError
(
c
,
401
,
"API key is required"
)
return
...
...
@@ -108,25 +108,38 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
}
}
func
extractAPIKeyFromRequest
(
c
*
gin
.
Context
)
string
{
authHeader
:=
c
.
GetHeader
(
"Authorization"
)
if
authHeader
!=
""
{
parts
:=
strings
.
SplitN
(
authHeader
,
" "
,
2
)
if
len
(
parts
)
==
2
&&
parts
[
0
]
==
"Bearer"
&&
strings
.
TrimSpace
(
parts
[
1
])
!=
""
{
return
strings
.
TrimSpace
(
parts
[
1
])
// extractAPIKeyForGoogle extracts API key for Google/Gemini endpoints.
// Priority: x-goog-api-key > Authorization: Bearer > x-api-key > query key
// This allows OpenClaw and other clients using Bearer auth to work with Gemini endpoints.
func
extractAPIKeyForGoogle
(
c
*
gin
.
Context
)
string
{
// 1) preferred: Gemini native header
if
k
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"x-goog-api-key"
));
k
!=
""
{
return
k
}
// 2) fallback: Authorization: Bearer <key>
auth
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"Authorization"
))
if
auth
!=
""
{
parts
:=
strings
.
SplitN
(
auth
,
" "
,
2
)
if
len
(
parts
)
==
2
&&
strings
.
EqualFold
(
parts
[
0
],
"Bearer"
)
{
if
k
:=
strings
.
TrimSpace
(
parts
[
1
]);
k
!=
""
{
return
k
}
if
v
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"x-api-key"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"x-goog-api-key"
));
v
!=
""
{
return
v
}
// 3) x-api-key header (backward compatibility)
if
k
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"x-api-key"
));
k
!=
""
{
return
k
}
// 4) query parameter key (for specific paths)
if
allowGoogleQueryKey
(
c
.
Request
.
URL
.
Path
)
{
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"key"
));
v
!=
""
{
return
v
}
}
return
""
}
...
...
backend/internal/server/middleware/api_key_auth_google_test.go
View file @
de7ff902
...
...
@@ -75,6 +75,9 @@ func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]s
func
(
f
fakeAPIKeyRepo
)
ListKeysByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
string
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeAPIKeyRepo
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
type
googleErrorResponse
struct
{
Error
struct
{
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment