Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
dd96ada3
Unverified
Commit
dd96ada3
authored
Feb 04, 2026
by
程序猿MT
Committed by
GitHub
Feb 04, 2026
Browse files
Merge branch 'Wei-Shaw:main' into main
parents
31fe0178
8f397548
Changes
90
Show whitespace changes
Inline
Side-by-side
backend/internal/handler/admin/account_handler.go
View file @
dd96ada3
...
...
@@ -84,7 +84,7 @@ type CreateAccountRequest struct {
Name
string
`json:"name" binding:"required"`
Notes
*
string
`json:"notes"`
Platform
string
`json:"platform" binding:"required"`
Type
string
`json:"type" binding:"required,oneof=oauth setup-token apikey"`
Type
string
`json:"type" binding:"required,oneof=oauth setup-token apikey
upstream
"`
Credentials
map
[
string
]
any
`json:"credentials" binding:"required"`
Extra
map
[
string
]
any
`json:"extra"`
ProxyID
*
int64
`json:"proxy_id"`
...
...
@@ -102,7 +102,7 @@ type CreateAccountRequest struct {
type
UpdateAccountRequest
struct
{
Name
string
`json:"name"`
Notes
*
string
`json:"notes"`
Type
string
`json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Type
string
`json:"type" binding:"omitempty,oneof=oauth setup-token apikey
upstream
"`
Credentials
map
[
string
]
any
`json:"credentials"`
Extra
map
[
string
]
any
`json:"extra"`
ProxyID
*
int64
`json:"proxy_id"`
...
...
backend/internal/handler/admin/group_handler.go
View file @
dd96ada3
...
...
@@ -40,9 +40,13 @@ type CreateGroupRequest struct {
ImagePrice4K
*
float64
`json:"image_price_4k"`
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest
*
int64
`json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
`json:"model_routing"`
ModelRoutingEnabled
bool
`json:"model_routing_enabled"`
MCPXMLInject
*
bool
`json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
`json:"supported_model_scopes"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
}
...
...
@@ -65,9 +69,13 @@ type UpdateGroupRequest struct {
ImagePrice4K
*
float64
`json:"image_price_4k"`
ClaudeCodeOnly
*
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest
*
int64
`json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
`json:"model_routing"`
ModelRoutingEnabled
*
bool
`json:"model_routing_enabled"`
MCPXMLInject
*
bool
`json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
*
[]
string
`json:"supported_model_scopes"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
}
...
...
@@ -173,8 +181,11 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice4K
:
req
.
ImagePrice4K
,
ClaudeCodeOnly
:
req
.
ClaudeCodeOnly
,
FallbackGroupID
:
req
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
req
.
FallbackGroupIDOnInvalidRequest
,
ModelRouting
:
req
.
ModelRouting
,
ModelRoutingEnabled
:
req
.
ModelRoutingEnabled
,
MCPXMLInject
:
req
.
MCPXMLInject
,
SupportedModelScopes
:
req
.
SupportedModelScopes
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
})
if
err
!=
nil
{
...
...
@@ -216,8 +227,11 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice4K
:
req
.
ImagePrice4K
,
ClaudeCodeOnly
:
req
.
ClaudeCodeOnly
,
FallbackGroupID
:
req
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
req
.
FallbackGroupIDOnInvalidRequest
,
ModelRouting
:
req
.
ModelRouting
,
ModelRoutingEnabled
:
req
.
ModelRoutingEnabled
,
MCPXMLInject
:
req
.
MCPXMLInject
,
SupportedModelScopes
:
req
.
SupportedModelScopes
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
})
if
err
!=
nil
{
...
...
backend/internal/handler/api_key_handler.go
View file @
dd96ada3
...
...
@@ -3,6 +3,7 @@ package handler
import
(
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
...
...
@@ -32,6 +33,8 @@ type CreateAPIKeyRequest struct {
CustomKey
*
string
`json:"custom_key"`
// 可选的自定义key
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
Quota
*
float64
`json:"quota"`
// 配额限制 (USD)
ExpiresInDays
*
int
`json:"expires_in_days"`
// 过期天数
}
// UpdateAPIKeyRequest represents the update API key request payload
...
...
@@ -41,6 +44,9 @@ type UpdateAPIKeyRequest struct {
Status
string
`json:"status" binding:"omitempty,oneof=active inactive"`
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
Quota
*
float64
`json:"quota"`
// 配额限制 (USD), 0=无限制
ExpiresAt
*
string
`json:"expires_at"`
// 过期时间 (ISO 8601)
ResetQuota
*
bool
`json:"reset_quota"`
// 重置已用配额
}
// List handles listing user's API keys with pagination
...
...
@@ -119,6 +125,10 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
CustomKey
:
req
.
CustomKey
,
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
ExpiresInDays
:
req
.
ExpiresInDays
,
}
if
req
.
Quota
!=
nil
{
svcReq
.
Quota
=
*
req
.
Quota
}
key
,
err
:=
h
.
apiKeyService
.
Create
(
c
.
Request
.
Context
(),
subject
.
UserID
,
svcReq
)
if
err
!=
nil
{
...
...
@@ -153,6 +163,8 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
svcReq
:=
service
.
UpdateAPIKeyRequest
{
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
Quota
:
req
.
Quota
,
ResetQuota
:
req
.
ResetQuota
,
}
if
req
.
Name
!=
""
{
svcReq
.
Name
=
&
req
.
Name
...
...
@@ -161,6 +173,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
if
req
.
Status
!=
""
{
svcReq
.
Status
=
&
req
.
Status
}
// Parse expires_at if provided
if
req
.
ExpiresAt
!=
nil
{
if
*
req
.
ExpiresAt
==
""
{
// Empty string means clear expiration
svcReq
.
ExpiresAt
=
nil
svcReq
.
ClearExpiration
=
true
}
else
{
t
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
*
req
.
ExpiresAt
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid expires_at format: "
+
err
.
Error
())
return
}
svcReq
.
ExpiresAt
=
&
t
}
}
key
,
err
:=
h
.
apiKeyService
.
Update
(
c
.
Request
.
Context
(),
keyID
,
subject
.
UserID
,
svcReq
)
if
err
!=
nil
{
...
...
backend/internal/handler/dto/mappers.go
View file @
dd96ada3
...
...
@@ -76,6 +76,9 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
Status
:
k
.
Status
,
IPWhitelist
:
k
.
IPWhitelist
,
IPBlacklist
:
k
.
IPBlacklist
,
Quota
:
k
.
Quota
,
QuotaUsed
:
k
.
QuotaUsed
,
ExpiresAt
:
k
.
ExpiresAt
,
CreatedAt
:
k
.
CreatedAt
,
UpdatedAt
:
k
.
UpdatedAt
,
User
:
UserFromServiceShallow
(
k
.
User
),
...
...
@@ -108,6 +111,8 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
Group
:
groupFromServiceBase
(
g
),
ModelRouting
:
g
.
ModelRouting
,
ModelRoutingEnabled
:
g
.
ModelRoutingEnabled
,
MCPXMLInject
:
g
.
MCPXMLInject
,
SupportedModelScopes
:
g
.
SupportedModelScopes
,
AccountCount
:
g
.
AccountCount
,
}
if
len
(
g
.
AccountGroups
)
>
0
{
...
...
@@ -138,6 +143,8 @@ func groupFromServiceBase(g *service.Group) Group {
ImagePrice4K
:
g
.
ImagePrice4K
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
FallbackGroupID
:
g
.
FallbackGroupID
,
// 无效请求兜底分组
FallbackGroupIDOnInvalidRequest
:
g
.
FallbackGroupIDOnInvalidRequest
,
CreatedAt
:
g
.
CreatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
}
...
...
backend/internal/handler/dto/types.go
View file @
dd96ada3
...
...
@@ -40,6 +40,9 @@ type APIKey struct {
Status
string
`json:"status"`
IPWhitelist
[]
string
`json:"ip_whitelist"`
IPBlacklist
[]
string
`json:"ip_blacklist"`
Quota
float64
`json:"quota"`
// Quota limit in USD (0 = unlimited)
QuotaUsed
float64
`json:"quota_used"`
// Used quota amount in USD
ExpiresAt
*
time
.
Time
`json:"expires_at"`
// Expiration time (nil = never expires)
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
...
...
@@ -69,6 +72,8 @@ type Group struct {
// Claude Code 客户端限制
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id"`
// 无效请求兜底分组
FallbackGroupIDOnInvalidRequest
*
int64
`json:"fallback_group_id_on_invalid_request"`
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
...
...
@@ -83,6 +88,11 @@ type AdminGroup struct {
ModelRouting
map
[
string
][]
int64
`json:"model_routing"`
ModelRoutingEnabled
bool
`json:"model_routing_enabled"`
// MCP XML 协议注入(仅 antigravity 平台使用)
MCPXMLInject
bool
`json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
`json:"supported_model_scopes"`
AccountGroups
[]
AccountGroup
`json:"account_groups,omitempty"`
AccountCount
int64
`json:"account_count,omitempty"`
}
...
...
backend/internal/handler/gateway_handler.go
View file @
dd96ada3
...
...
@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
...
...
@@ -31,6 +32,7 @@ type GatewayHandler struct {
userService
*
service
.
UserService
billingCacheService
*
service
.
BillingCacheService
usageService
*
service
.
UsageService
apiKeyService
*
service
.
APIKeyService
concurrencyHelper
*
ConcurrencyHelper
maxAccountSwitches
int
maxAccountSwitchesGemini
int
...
...
@@ -45,6 +47,7 @@ func NewGatewayHandler(
concurrencyService
*
service
.
ConcurrencyService
,
billingCacheService
*
service
.
BillingCacheService
,
usageService
*
service
.
UsageService
,
apiKeyService
*
service
.
APIKeyService
,
cfg
*
config
.
Config
,
)
*
GatewayHandler
{
pingInterval
:=
time
.
Duration
(
0
)
...
...
@@ -66,6 +69,7 @@ func NewGatewayHandler(
userService
:
userService
,
billingCacheService
:
billingCacheService
,
usageService
:
usageService
,
apiKeyService
:
apiKeyService
,
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatClaude
,
pingInterval
),
maxAccountSwitches
:
maxAccountSwitches
,
maxAccountSwitchesGemini
:
maxAccountSwitchesGemini
,
...
...
@@ -281,10 +285,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
requestCtx
:=
c
.
Request
.
Context
()
if
switchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
switchCount
)
}
if
account
.
Platform
==
service
.
PlatformAntigravity
{
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
c
.
R
equest
.
Context
()
,
c
,
account
,
reqModel
,
"generateContent"
,
reqStream
,
body
)
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
r
equest
Ctx
,
c
,
account
,
reqModel
,
"generateContent"
,
reqStream
,
body
)
}
else
{
result
,
err
=
h
.
geminiCompatService
.
Forward
(
c
.
R
equest
.
Context
()
,
c
,
account
,
body
)
result
,
err
=
h
.
geminiCompatService
.
Forward
(
r
equest
Ctx
,
c
,
account
,
body
)
}
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
...
...
@@ -323,6 +331,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Subscription
:
subscription
,
UserAgent
:
ua
,
IPAddress
:
clientIP
,
APIKeyService
:
h
.
apiKeyService
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
...
...
@@ -331,14 +340,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
currentAPIKey
:=
apiKey
currentSubscription
:=
subscription
var
fallbackGroupID
*
int64
if
apiKey
.
Group
!=
nil
{
fallbackGroupID
=
apiKey
.
Group
.
FallbackGroupIDOnInvalidRequest
}
fallbackUsed
:=
false
for
{
maxAccountSwitches
:=
h
.
maxAccountSwitches
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
lastFailoverStatus
:=
0
retryWithFallback
:=
false
for
{
// 选择支持该模型的账号
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
api
Key
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
parsedReq
.
MetadataUserID
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
currentAPI
Key
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
parsedReq
.
MetadataUserID
)
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
@@ -408,7 +427,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
api
Key
.
GroupID
,
sessionKey
,
account
.
ID
);
err
!=
nil
{
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
currentAPI
Key
.
GroupID
,
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
...
...
@@ -417,15 +436,54 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
requestCtx
:=
c
.
Request
.
Context
()
if
switchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
switchCount
)
}
if
account
.
Platform
==
service
.
PlatformAntigravity
{
result
,
err
=
h
.
antigravityGatewayService
.
Forward
(
c
.
R
equest
.
Context
()
,
c
,
account
,
body
)
result
,
err
=
h
.
antigravityGatewayService
.
Forward
(
r
equest
Ctx
,
c
,
account
,
body
)
}
else
{
result
,
err
=
h
.
gatewayService
.
Forward
(
c
.
R
equest
.
Context
()
,
c
,
account
,
parsedReq
)
result
,
err
=
h
.
gatewayService
.
Forward
(
r
equest
Ctx
,
c
,
account
,
parsedReq
)
}
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
err
!=
nil
{
var
promptTooLongErr
*
service
.
PromptTooLongError
if
errors
.
As
(
err
,
&
promptTooLongErr
)
{
log
.
Printf
(
"Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v"
,
currentAPIKey
.
GroupID
,
fallbackGroupID
,
fallbackUsed
)
if
!
fallbackUsed
&&
fallbackGroupID
!=
nil
&&
*
fallbackGroupID
>
0
{
fallbackGroup
,
err
:=
h
.
gatewayService
.
ResolveGroupByID
(
c
.
Request
.
Context
(),
*
fallbackGroupID
)
if
err
!=
nil
{
log
.
Printf
(
"Resolve fallback group failed: %v"
,
err
)
_
=
h
.
antigravityGatewayService
.
WriteMappedClaudeError
(
c
,
account
,
promptTooLongErr
.
StatusCode
,
promptTooLongErr
.
RequestID
,
promptTooLongErr
.
Body
)
return
}
if
fallbackGroup
.
Platform
!=
service
.
PlatformAnthropic
||
fallbackGroup
.
SubscriptionType
==
service
.
SubscriptionTypeSubscription
||
fallbackGroup
.
FallbackGroupIDOnInvalidRequest
!=
nil
{
log
.
Printf
(
"Fallback group invalid: group=%d platform=%s subscription=%s"
,
fallbackGroup
.
ID
,
fallbackGroup
.
Platform
,
fallbackGroup
.
SubscriptionType
)
_
=
h
.
antigravityGatewayService
.
WriteMappedClaudeError
(
c
,
account
,
promptTooLongErr
.
StatusCode
,
promptTooLongErr
.
RequestID
,
promptTooLongErr
.
Body
)
return
}
fallbackAPIKey
:=
cloneAPIKeyWithGroup
(
apiKey
,
fallbackGroup
)
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
fallbackAPIKey
.
User
,
fallbackAPIKey
,
fallbackGroup
,
nil
);
err
!=
nil
{
status
,
code
,
message
:=
billingErrorDetails
(
err
)
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
}
// 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
ForcePlatform
,
""
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
currentAPIKey
=
fallbackAPIKey
currentSubscription
=
nil
fallbackUsed
=
true
retryWithFallback
=
true
break
}
_
=
h
.
antigravityGatewayService
.
WriteMappedClaudeError
(
c
,
account
,
promptTooLongErr
.
StatusCode
,
promptTooLongErr
.
RequestID
,
promptTooLongErr
.
Body
)
return
}
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
...
...
@@ -453,18 +511,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer
cancel
()
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
Result
:
result
,
APIKey
:
api
Key
,
User
:
api
Key
.
User
,
APIKey
:
currentAPI
Key
,
User
:
currentAPI
Key
.
User
,
Account
:
usedAccount
,
Subscription
:
s
ubscription
,
Subscription
:
currentS
ubscription
,
UserAgent
:
ua
,
IPAddress
:
clientIP
,
APIKeyService
:
h
.
apiKeyService
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
}(
result
,
account
,
userAgent
,
clientIP
)
return
}
if
!
retryWithFallback
{
return
}
}
}
// Models handles listing available models
...
...
@@ -527,6 +590,17 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
})
}
func
cloneAPIKeyWithGroup
(
apiKey
*
service
.
APIKey
,
group
*
service
.
Group
)
*
service
.
APIKey
{
if
apiKey
==
nil
||
group
==
nil
{
return
apiKey
}
cloned
:=
*
apiKey
groupID
:=
group
.
ID
cloned
.
GroupID
=
&
groupID
cloned
.
Group
=
group
return
&
cloned
}
// Usage handles getting account balance and usage statistics for CC Switch integration
// GET /v1/usage
func
(
h
*
GatewayHandler
)
Usage
(
c
*
gin
.
Context
)
{
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
dd96ada3
...
...
@@ -14,6 +14,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
...
...
@@ -335,10 +336,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 5) forward (根据平台分流)
var
result
*
service
.
ForwardResult
requestCtx
:=
c
.
Request
.
Context
()
if
switchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
switchCount
)
}
if
account
.
Platform
==
service
.
PlatformAntigravity
{
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
c
.
R
equest
.
Context
()
,
c
,
account
,
modelName
,
action
,
stream
,
body
)
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
r
equest
Ctx
,
c
,
account
,
modelName
,
action
,
stream
,
body
)
}
else
{
result
,
err
=
h
.
geminiCompatService
.
ForwardNative
(
c
.
R
equest
.
Context
()
,
c
,
account
,
modelName
,
action
,
stream
,
body
)
result
,
err
=
h
.
geminiCompatService
.
ForwardNative
(
r
equest
Ctx
,
c
,
account
,
modelName
,
action
,
stream
,
body
)
}
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
...
...
@@ -381,6 +386,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
IPAddress
:
ip
,
LongContextThreshold
:
200000
,
// Gemini 200K 阈值
LongContextMultiplier
:
2.0
,
// 超出部分双倍计费
APIKeyService
:
h
.
apiKeyService
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
dd96ada3
...
...
@@ -24,6 +24,7 @@ import (
type
OpenAIGatewayHandler
struct
{
gatewayService
*
service
.
OpenAIGatewayService
billingCacheService
*
service
.
BillingCacheService
apiKeyService
*
service
.
APIKeyService
concurrencyHelper
*
ConcurrencyHelper
maxAccountSwitches
int
}
...
...
@@ -33,6 +34,7 @@ func NewOpenAIGatewayHandler(
gatewayService
*
service
.
OpenAIGatewayService
,
concurrencyService
*
service
.
ConcurrencyService
,
billingCacheService
*
service
.
BillingCacheService
,
apiKeyService
*
service
.
APIKeyService
,
cfg
*
config
.
Config
,
)
*
OpenAIGatewayHandler
{
pingInterval
:=
time
.
Duration
(
0
)
...
...
@@ -46,6 +48,7 @@ func NewOpenAIGatewayHandler(
return
&
OpenAIGatewayHandler
{
gatewayService
:
gatewayService
,
billingCacheService
:
billingCacheService
,
apiKeyService
:
apiKeyService
,
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatComment
,
pingInterval
),
maxAccountSwitches
:
maxAccountSwitches
,
}
...
...
@@ -306,6 +309,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
Subscription
:
subscription
,
UserAgent
:
ua
,
IPAddress
:
ip
,
APIKeyService
:
h
.
apiKeyService
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
...
...
backend/internal/pkg/antigravity/oauth.go
View file @
dd96ada3
...
...
@@ -40,17 +40,48 @@ const (
// URL 可用性 TTL(不可用 URL 的恢复时间)
URLAvailabilityTTL
=
5
*
time
.
Minute
// Antigravity API 端点
antigravityProdBaseURL
=
"https://cloudcode-pa.googleapis.com"
antigravityDailyBaseURL
=
"https://daily-cloudcode-pa.sandbox.googleapis.com"
)
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
var
BaseURLs
=
[]
string
{
"https://cloudcode-pa.googleapis.com"
,
// prod (优先)
"https://daily-cloudcode-pa.sandbox.googleapis.com"
,
// daily sandbox (备用)
antigravityProdBaseURL
,
// prod (优先)
antigravityDailyBaseURL
,
// daily sandbox (备用)
}
// BaseURL 默认 URL(保持向后兼容)
var
BaseURL
=
BaseURLs
[
0
]
// ForwardBaseURLs 返回 API 转发用的 URL 顺序(daily 优先)
func
ForwardBaseURLs
()
[]
string
{
if
len
(
BaseURLs
)
==
0
{
return
nil
}
urls
:=
append
([]
string
(
nil
),
BaseURLs
...
)
dailyIndex
:=
-
1
for
i
,
url
:=
range
urls
{
if
url
==
antigravityDailyBaseURL
{
dailyIndex
=
i
break
}
}
if
dailyIndex
<=
0
{
return
urls
}
reordered
:=
make
([]
string
,
0
,
len
(
urls
))
reordered
=
append
(
reordered
,
urls
[
dailyIndex
])
for
i
,
url
:=
range
urls
{
if
i
==
dailyIndex
{
continue
}
reordered
=
append
(
reordered
,
url
)
}
return
reordered
}
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级)
type
URLAvailability
struct
{
mu
sync
.
RWMutex
...
...
@@ -100,22 +131,37 @@ func (u *URLAvailability) IsAvailable(url string) bool {
// GetAvailableURLs 返回可用的 URL 列表
// 最近成功的 URL 优先,其他按默认顺序
func
(
u
*
URLAvailability
)
GetAvailableURLs
()
[]
string
{
return
u
.
GetAvailableURLsWithBase
(
BaseURLs
)
}
// GetAvailableURLsWithBase 返回可用的 URL 列表(使用自定义顺序)
// 最近成功的 URL 优先,其他按传入顺序
func
(
u
*
URLAvailability
)
GetAvailableURLsWithBase
(
baseURLs
[]
string
)
[]
string
{
u
.
mu
.
RLock
()
defer
u
.
mu
.
RUnlock
()
now
:=
time
.
Now
()
result
:=
make
([]
string
,
0
,
len
(
B
aseURLs
))
result
:=
make
([]
string
,
0
,
len
(
b
aseURLs
))
// 如果有最近成功的 URL 且可用,放在最前面
if
u
.
lastSuccess
!=
""
{
found
:=
false
for
_
,
url
:=
range
baseURLs
{
if
url
==
u
.
lastSuccess
{
found
=
true
break
}
}
if
found
{
expiry
,
exists
:=
u
.
unavailable
[
u
.
lastSuccess
]
if
!
exists
||
now
.
After
(
expiry
)
{
result
=
append
(
result
,
u
.
lastSuccess
)
}
}
}
// 添加其他可用的 URL(按
默认
顺序)
for
_
,
url
:=
range
B
aseURLs
{
// 添加其他可用的 URL(按
传入
顺序)
for
_
,
url
:=
range
b
aseURLs
{
// 跳过已添加的 lastSuccess
if
url
==
u
.
lastSuccess
{
continue
...
...
backend/internal/pkg/antigravity/request_transformer.go
View file @
dd96ada3
...
...
@@ -44,11 +44,13 @@ type TransformOptions struct {
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
// 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。
IdentityPatch
string
EnableMCPXML
bool
}
func
DefaultTransformOptions
()
TransformOptions
{
return
TransformOptions
{
EnableIdentityPatch
:
true
,
EnableMCPXML
:
true
,
}
}
...
...
@@ -257,8 +259,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
// 添加用户的 system prompt
parts
=
append
(
parts
,
userSystemParts
...
)
// 检测是否有 MCP 工具,如有则注入 XML 调用协议
if
hasMCPTools
(
tools
)
{
// 检测是否有 MCP 工具,如有
且启用了 MCP XML 注入
则注入 XML 调用协议
if
opts
.
EnableMCPXML
&&
hasMCPTools
(
tools
)
{
parts
=
append
(
parts
,
GeminiPart
{
Text
:
mcpXMLProtocol
})
}
...
...
@@ -312,7 +314,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
parts
=
append
([]
GeminiPart
{{
Text
:
"Thinking..."
,
Thought
:
true
,
ThoughtSignature
:
d
ummyThoughtSignature
,
ThoughtSignature
:
D
ummyThoughtSignature
,
}},
parts
...
)
}
}
...
...
@@ -330,9 +332,10 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
return
contents
,
strippedThinking
,
nil
}
//
d
ummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证
//
D
ummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const
dummyThoughtSignature
=
"skip_thought_signature_validator"
// 导出供跨包使用(如 gemini_native_signature_cleaner 跨账号修复)
const
DummyThoughtSignature
=
"skip_thought_signature_validator"
// buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
...
...
@@ -370,7 +373,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
// signature 处理:
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
if
block
.
Signature
!=
""
&&
(
allowDummyThought
||
block
.
Signature
!=
d
ummyThoughtSignature
)
{
if
block
.
Signature
!=
""
&&
(
allowDummyThought
||
block
.
Signature
!=
D
ummyThoughtSignature
)
{
part
.
ThoughtSignature
=
block
.
Signature
}
else
if
!
allowDummyThought
{
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
...
...
@@ -381,7 +384,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
continue
}
else
{
// Gemini 模型使用 dummy signature
part
.
ThoughtSignature
=
d
ummyThoughtSignature
part
.
ThoughtSignature
=
D
ummyThoughtSignature
}
parts
=
append
(
parts
,
part
)
...
...
@@ -411,10 +414,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
// tool_use 的 signature 处理:
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
if
block
.
Signature
!=
""
&&
(
allowDummyThought
||
block
.
Signature
!=
d
ummyThoughtSignature
)
{
if
block
.
Signature
!=
""
&&
(
allowDummyThought
||
block
.
Signature
!=
D
ummyThoughtSignature
)
{
part
.
ThoughtSignature
=
block
.
Signature
}
else
if
allowDummyThought
{
part
.
ThoughtSignature
=
d
ummyThoughtSignature
part
.
ThoughtSignature
=
D
ummyThoughtSignature
}
parts
=
append
(
parts
,
part
)
...
...
@@ -492,9 +495,23 @@ func parseToolResultContent(content json.RawMessage, isError bool) string {
}
// buildGenerationConfig 构建 generationConfig
const
(
defaultMaxOutputTokens
=
64000
maxOutputTokensUpperBound
=
65000
maxOutputTokensClaude
=
64000
)
func
maxOutputTokensLimit
(
model
string
)
int
{
if
strings
.
HasPrefix
(
model
,
"claude-"
)
{
return
maxOutputTokensClaude
}
return
maxOutputTokensUpperBound
}
func
buildGenerationConfig
(
req
*
ClaudeRequest
)
*
GeminiGenerationConfig
{
maxLimit
:=
maxOutputTokensLimit
(
req
.
Model
)
config
:=
&
GeminiGenerationConfig
{
MaxOutputTokens
:
64000
,
// 默认最大输出
MaxOutputTokens
:
defaultMaxOutputTokens
,
// 默认最大输出
StopSequences
:
DefaultStopSequences
,
}
...
...
@@ -518,6 +535,10 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
}
}
if
config
.
MaxOutputTokens
>
maxLimit
{
config
.
MaxOutputTokens
=
maxLimit
}
// 其他参数
if
req
.
Temperature
!=
nil
{
config
.
Temperature
=
req
.
Temperature
...
...
backend/internal/pkg/antigravity/request_transformer_test.go
View file @
dd96ada3
...
...
@@ -86,7 +86,7 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
if
len
(
parts
)
!=
3
{
t
.
Fatalf
(
"expected 3 parts, got %d"
,
len
(
parts
))
}
if
!
parts
[
1
]
.
Thought
||
parts
[
1
]
.
ThoughtSignature
!=
d
ummyThoughtSignature
{
if
!
parts
[
1
]
.
Thought
||
parts
[
1
]
.
ThoughtSignature
!=
D
ummyThoughtSignature
{
t
.
Fatalf
(
"expected dummy thought signature, got thought=%v signature=%q"
,
parts
[
1
]
.
Thought
,
parts
[
1
]
.
ThoughtSignature
)
}
...
...
@@ -126,8 +126,8 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if
len
(
parts
)
!=
1
||
parts
[
0
]
.
FunctionCall
==
nil
{
t
.
Fatalf
(
"expected 1 functionCall part, got %+v"
,
parts
)
}
if
parts
[
0
]
.
ThoughtSignature
!=
d
ummyThoughtSignature
{
t
.
Fatalf
(
"expected dummy tool signature %q, got %q"
,
d
ummyThoughtSignature
,
parts
[
0
]
.
ThoughtSignature
)
if
parts
[
0
]
.
ThoughtSignature
!=
D
ummyThoughtSignature
{
t
.
Fatalf
(
"expected dummy tool signature %q, got %q"
,
D
ummyThoughtSignature
,
parts
[
0
]
.
ThoughtSignature
)
}
})
...
...
backend/internal/pkg/ctxkey/ctxkey.go
View file @
dd96ada3
...
...
@@ -14,6 +14,9 @@ const (
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
RetryCount
Key
=
"ctx_retry_count"
// AccountSwitchCount 表示请求过程中发生的账号切换次数
AccountSwitchCount
Key
=
"ctx_account_switch_count"
// IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端
IsClaudeCodeClient
Key
=
"ctx_is_claude_code_client"
// Group 认证后的分组信息,由 API Key 认证中间件设置
...
...
backend/internal/repository/api_key_repo.go
View file @
dd96ada3
...
...
@@ -33,7 +33,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
SetKey
(
key
.
Key
)
.
SetName
(
key
.
Name
)
.
SetStatus
(
key
.
Status
)
.
SetNillableGroupID
(
key
.
GroupID
)
SetNillableGroupID
(
key
.
GroupID
)
.
SetQuota
(
key
.
Quota
)
.
SetQuotaUsed
(
key
.
QuotaUsed
)
.
SetNillableExpiresAt
(
key
.
ExpiresAt
)
if
len
(
key
.
IPWhitelist
)
>
0
{
builder
.
SetIPWhitelist
(
key
.
IPWhitelist
)
...
...
@@ -110,6 +113,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
apikey
.
FieldStatus
,
apikey
.
FieldIPWhitelist
,
apikey
.
FieldIPBlacklist
,
apikey
.
FieldQuota
,
apikey
.
FieldQuotaUsed
,
apikey
.
FieldExpiresAt
,
)
.
WithUser
(
func
(
q
*
dbent
.
UserQuery
)
{
q
.
Select
(
...
...
@@ -136,8 +142,11 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group
.
FieldImagePrice4k
,
group
.
FieldClaudeCodeOnly
,
group
.
FieldFallbackGroupID
,
group
.
FieldFallbackGroupIDOnInvalidRequest
,
group
.
FieldModelRoutingEnabled
,
group
.
FieldModelRouting
,
group
.
FieldMcpXMLInject
,
group
.
FieldSupportedModelScopes
,
)
})
.
Only
(
ctx
)
...
...
@@ -161,6 +170,8 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
Where
(
apikey
.
IDEQ
(
key
.
ID
),
apikey
.
DeletedAtIsNil
())
.
SetName
(
key
.
Name
)
.
SetStatus
(
key
.
Status
)
.
SetQuota
(
key
.
Quota
)
.
SetQuotaUsed
(
key
.
QuotaUsed
)
.
SetUpdatedAt
(
now
)
if
key
.
GroupID
!=
nil
{
builder
.
SetGroupID
(
*
key
.
GroupID
)
...
...
@@ -168,6 +179,13 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder
.
ClearGroupID
()
}
// Expiration time
if
key
.
ExpiresAt
!=
nil
{
builder
.
SetExpiresAt
(
*
key
.
ExpiresAt
)
}
else
{
builder
.
ClearExpiresAt
()
}
// IP 限制字段
if
len
(
key
.
IPWhitelist
)
>
0
{
builder
.
SetIPWhitelist
(
key
.
IPWhitelist
)
...
...
@@ -357,6 +375,38 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
return
keys
,
nil
}
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
func
(
r
*
apiKeyRepository
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
// Use raw SQL for atomic increment to avoid race conditions
// First get current value
m
,
err
:=
r
.
activeQuery
()
.
Where
(
apikey
.
IDEQ
(
id
))
.
Select
(
apikey
.
FieldQuotaUsed
)
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
0
,
service
.
ErrAPIKeyNotFound
}
return
0
,
err
}
newValue
:=
m
.
QuotaUsed
+
amount
// Update with new value
affected
,
err
:=
r
.
client
.
APIKey
.
Update
()
.
Where
(
apikey
.
IDEQ
(
id
),
apikey
.
DeletedAtIsNil
())
.
SetQuotaUsed
(
newValue
)
.
Save
(
ctx
)
if
err
!=
nil
{
return
0
,
err
}
if
affected
==
0
{
return
0
,
service
.
ErrAPIKeyNotFound
}
return
newValue
,
nil
}
func
apiKeyEntityToService
(
m
*
dbent
.
APIKey
)
*
service
.
APIKey
{
if
m
==
nil
{
return
nil
...
...
@@ -372,6 +422,9 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
CreatedAt
:
m
.
CreatedAt
,
UpdatedAt
:
m
.
UpdatedAt
,
GroupID
:
m
.
GroupID
,
Quota
:
m
.
Quota
,
QuotaUsed
:
m
.
QuotaUsed
,
ExpiresAt
:
m
.
ExpiresAt
,
}
if
m
.
Edges
.
User
!=
nil
{
out
.
User
=
userEntityToService
(
m
.
Edges
.
User
)
...
...
@@ -427,8 +480,11 @@ func groupEntityToService(g *dbent.Group) *service.Group {
DefaultValidityDays
:
g
.
DefaultValidityDays
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
FallbackGroupID
:
g
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
g
.
FallbackGroupIDOnInvalidRequest
,
ModelRouting
:
g
.
ModelRouting
,
ModelRoutingEnabled
:
g
.
ModelRoutingEnabled
,
MCPXMLInject
:
g
.
McpXMLInject
,
SupportedModelScopes
:
g
.
SupportedModelScopes
,
CreatedAt
:
g
.
CreatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
}
...
...
backend/internal/repository/group_repo.go
View file @
dd96ada3
...
...
@@ -50,13 +50,18 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
.
SetNillableFallbackGroupID
(
groupIn
.
FallbackGroupID
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
SetNillableFallbackGroupIDOnInvalidRequest
(
groupIn
.
FallbackGroupIDOnInvalidRequest
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
.
SetMcpXMLInject
(
groupIn
.
MCPXMLInject
)
// 设置模型路由配置
if
groupIn
.
ModelRouting
!=
nil
{
builder
=
builder
.
SetModelRouting
(
groupIn
.
ModelRouting
)
}
// 设置支持的模型系列(始终设置,空数组表示不限制)
builder
=
builder
.
SetSupportedModelScopes
(
groupIn
.
SupportedModelScopes
)
created
,
err
:=
builder
.
Save
(
ctx
)
if
err
==
nil
{
groupIn
.
ID
=
created
.
ID
...
...
@@ -87,7 +92,6 @@ func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.G
if
err
!=
nil
{
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
nil
)
}
return
groupEntityToService
(
m
),
nil
}
...
...
@@ -108,7 +112,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice4k
(
groupIn
.
ImagePrice4K
)
.
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
SetClaudeCodeOnly
(
groupIn
.
ClaudeCodeOnly
)
.
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
SetModelRoutingEnabled
(
groupIn
.
ModelRoutingEnabled
)
.
SetMcpXMLInject
(
groupIn
.
MCPXMLInject
)
// 处理 FallbackGroupID:nil 时清除,否则设置
if
groupIn
.
FallbackGroupID
!=
nil
{
...
...
@@ -116,6 +121,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
}
else
{
builder
=
builder
.
ClearFallbackGroupID
()
}
// 处理 FallbackGroupIDOnInvalidRequest:nil 时清除,否则设置
if
groupIn
.
FallbackGroupIDOnInvalidRequest
!=
nil
{
builder
=
builder
.
SetFallbackGroupIDOnInvalidRequest
(
*
groupIn
.
FallbackGroupIDOnInvalidRequest
)
}
else
{
builder
=
builder
.
ClearFallbackGroupIDOnInvalidRequest
()
}
// 处理 ModelRouting:nil 时清除,否则设置
if
groupIn
.
ModelRouting
!=
nil
{
...
...
@@ -124,6 +135,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
builder
=
builder
.
ClearModelRouting
()
}
// 处理 SupportedModelScopes(始终设置,空数组表示不限制)
builder
=
builder
.
SetSupportedModelScopes
(
groupIn
.
SupportedModelScopes
)
updated
,
err
:=
builder
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
service
.
ErrGroupExists
)
...
...
backend/internal/repository/ops_repo_metrics.go
View file @
dd96ada3
...
...
@@ -43,6 +43,7 @@ INSERT INTO ops_system_metrics (
upstream_529_count,
token_consumed,
account_switch_count,
qps,
tps,
...
...
@@ -81,14 +82,14 @@ INSERT INTO ops_system_metrics (
$1,$2,$3,$4,
$5,$6,$7,$8,
$9,$10,$11,
$12,$13,$14,
$15,
$16,$17,$18,$19,$20,
$21,
$22,$23,$24,$25,$26,
$27,
$28,$29,$30,
$3
1
,$3
2
,
$3
3
,$3
4
,
$35,
$36,$37,
$3
8
,$
39
$12,$13,$14,
$15,
$16,$17,$18,$19,$20,
$21,
$22,$23,$24,$25,$26,
$27,
$28,$29,$30,
$31,
$3
2
,$3
3
,
$3
4
,$3
5
,
$36,$37,
$38,
$3
9
,$
40
)`
_
,
err
:=
r
.
db
.
ExecContext
(
...
...
@@ -109,6 +110,7 @@ INSERT INTO ops_system_metrics (
input
.
Upstream529Count
,
input
.
TokenConsumed
,
input
.
AccountSwitchCount
,
opsNullFloat64
(
input
.
QPS
),
opsNullFloat64
(
input
.
TPS
),
...
...
@@ -177,7 +179,8 @@ SELECT
db_conn_waiting,
goroutine_count,
concurrency_queue_depth
concurrency_queue_depth,
account_switch_count
FROM ops_system_metrics
WHERE window_minutes = $1
AND platform IS NULL
...
...
@@ -199,6 +202,7 @@ LIMIT 1`
var
dbWaiting
sql
.
NullInt64
var
goroutines
sql
.
NullInt64
var
queueDepth
sql
.
NullInt64
var
accountSwitchCount
sql
.
NullInt64
if
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
q
,
windowMinutes
)
.
Scan
(
&
out
.
ID
,
...
...
@@ -217,6 +221,7 @@ LIMIT 1`
&
dbWaiting
,
&
goroutines
,
&
queueDepth
,
&
accountSwitchCount
,
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -273,6 +278,10 @@ LIMIT 1`
v
:=
int
(
queueDepth
.
Int64
)
out
.
ConcurrencyQueueDepth
=
&
v
}
if
accountSwitchCount
.
Valid
{
v
:=
accountSwitchCount
.
Int64
out
.
AccountSwitchCount
=
&
v
}
return
&
out
,
nil
}
...
...
backend/internal/repository/ops_repo_trends.go
View file @
dd96ada3
...
...
@@ -56,18 +56,44 @@ error_buckets AS (
AND COALESCE(status_code, 0) >= 400
GROUP BY 1
),
switch_buckets AS (
SELECT `
+
errorBucketExpr
+
` AS bucket,
COALESCE(SUM(CASE
WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1
ELSE 0
END), 0) AS switch_count
FROM ops_error_logs
CROSS JOIN LATERAL jsonb_array_elements(
COALESCE(NULLIF(upstream_errors, 'null'::jsonb), '[]'::jsonb)
) AS ev
`
+
errorWhere
+
`
AND upstream_errors IS NOT NULL
GROUP BY 1
),
combined AS (
SELECT COALESCE(u.bucket, e.bucket) AS bucket,
COALESCE(u.success_count, 0) AS success_count,
COALESCE(e.error_count, 0) AS error_count,
COALESCE(u.token_consumed, 0) AS token_consumed
FROM usage_buckets u
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
SELECT
bucket,
SUM(success_count) AS success_count,
SUM(error_count) AS error_count,
SUM(token_consumed) AS token_consumed,
SUM(switch_count) AS switch_count
FROM (
SELECT bucket, success_count, 0 AS error_count, token_consumed, 0 AS switch_count
FROM usage_buckets
UNION ALL
SELECT bucket, 0, error_count, 0, 0
FROM error_buckets
UNION ALL
SELECT bucket, 0, 0, 0, switch_count
FROM switch_buckets
) t
GROUP BY bucket
)
SELECT
bucket,
(success_count + error_count) AS request_count,
token_consumed
token_consumed,
switch_count
FROM combined
ORDER BY bucket ASC`
...
...
@@ -84,13 +110,18 @@ ORDER BY bucket ASC`
var
bucket
time
.
Time
var
requests
int64
var
tokens
sql
.
NullInt64
if
err
:=
rows
.
Scan
(
&
bucket
,
&
requests
,
&
tokens
);
err
!=
nil
{
var
switches
sql
.
NullInt64
if
err
:=
rows
.
Scan
(
&
bucket
,
&
requests
,
&
tokens
,
&
switches
);
err
!=
nil
{
return
nil
,
err
}
tokenConsumed
:=
int64
(
0
)
if
tokens
.
Valid
{
tokenConsumed
=
tokens
.
Int64
}
switchCount
:=
int64
(
0
)
if
switches
.
Valid
{
switchCount
=
switches
.
Int64
}
denom
:=
float64
(
bucketSeconds
)
if
denom
<=
0
{
...
...
@@ -103,6 +134,7 @@ ORDER BY bucket ASC`
BucketStart
:
bucket
.
UTC
(),
RequestCount
:
requests
,
TokenConsumed
:
tokenConsumed
,
SwitchCount
:
switchCount
,
QPS
:
qps
,
TPS
:
tps
,
})
...
...
@@ -385,6 +417,7 @@ func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points []
BucketStart
:
cursor
,
RequestCount
:
0
,
TokenConsumed
:
0
,
SwitchCount
:
0
,
QPS
:
0
,
TPS
:
0
,
})
...
...
backend/internal/server/api_contract_test.go
View file @
dd96ada3
...
...
@@ -83,6 +83,9 @@ func TestAPIContracts(t *testing.T) {
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"quota": 0,
"quota_used": 0,
"expires_at": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
...
...
@@ -119,6 +122,9 @@ func TestAPIContracts(t *testing.T) {
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"quota": 0,
"quota_used": 0,
"expires_at": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
...
...
@@ -180,6 +186,7 @@ func TestAPIContracts(t *testing.T) {
"image_price_4k": null,
"claude_code_only": false,
"fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
...
...
@@ -601,7 +608,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingService
:=
service
.
NewSettingService
(
settingRepo
,
cfg
)
adminService
:=
service
.
NewAdminService
(
userRepo
,
groupRepo
,
&
accountRepo
,
proxyRepo
,
apiKeyRepo
,
redeemRepo
,
nil
,
nil
,
nil
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
,
nil
,
nil
)
authHandler
:=
handler
.
NewAuthHandler
(
cfg
,
nil
,
userService
,
settingService
,
nil
,
redeemService
,
nil
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
,
nil
)
...
...
@@ -1442,6 +1449,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) (
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubApiKeyRepo
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
type
stubUsageLogRepo
struct
{
userLogs
map
[
int64
][]
service
.
UsageLog
}
...
...
backend/internal/server/middleware/api_key_auth.go
View file @
dd96ada3
...
...
@@ -70,7 +70,27 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 检查API key是否激活
if
!
apiKey
.
IsActive
()
{
// Provide more specific error message based on status
switch
apiKey
.
Status
{
case
service
.
StatusAPIKeyQuotaExhausted
:
AbortWithError
(
c
,
429
,
"API_KEY_QUOTA_EXHAUSTED"
,
"API key 额度已用完"
)
case
service
.
StatusAPIKeyExpired
:
AbortWithError
(
c
,
403
,
"API_KEY_EXPIRED"
,
"API key 已过期"
)
default
:
AbortWithError
(
c
,
401
,
"API_KEY_DISABLED"
,
"API key is disabled"
)
}
return
}
// 检查API Key是否过期(即使状态是active,也要检查时间)
if
apiKey
.
IsExpired
()
{
AbortWithError
(
c
,
403
,
"API_KEY_EXPIRED"
,
"API key 已过期"
)
return
}
// 检查API Key配额是否耗尽
if
apiKey
.
IsQuotaExhausted
()
{
AbortWithError
(
c
,
429
,
"API_KEY_QUOTA_EXHAUSTED"
,
"API key 额度已用完"
)
return
}
...
...
backend/internal/server/middleware/api_key_auth_google.go
View file @
dd96ada3
...
...
@@ -26,7 +26,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
abortWithGoogleError
(
c
,
400
,
"Query parameter api_key is deprecated. Use Authorization header or key instead."
)
return
}
apiKeyString
:=
extractAPIKeyF
romRequest
(
c
)
apiKeyString
:=
extractAPIKeyF
orGoogle
(
c
)
if
apiKeyString
==
""
{
abortWithGoogleError
(
c
,
401
,
"API key is required"
)
return
...
...
@@ -108,25 +108,38 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
}
}
func
extractAPIKeyFromRequest
(
c
*
gin
.
Context
)
string
{
authHeader
:=
c
.
GetHeader
(
"Authorization"
)
if
authHeader
!=
""
{
parts
:=
strings
.
SplitN
(
authHeader
,
" "
,
2
)
if
len
(
parts
)
==
2
&&
parts
[
0
]
==
"Bearer"
&&
strings
.
TrimSpace
(
parts
[
1
])
!=
""
{
return
strings
.
TrimSpace
(
parts
[
1
])
// extractAPIKeyForGoogle extracts API key for Google/Gemini endpoints.
// Priority: x-goog-api-key > Authorization: Bearer > x-api-key > query key
// This allows OpenClaw and other clients using Bearer auth to work with Gemini endpoints.
func
extractAPIKeyForGoogle
(
c
*
gin
.
Context
)
string
{
// 1) preferred: Gemini native header
if
k
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"x-goog-api-key"
));
k
!=
""
{
return
k
}
// 2) fallback: Authorization: Bearer <key>
auth
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"Authorization"
))
if
auth
!=
""
{
parts
:=
strings
.
SplitN
(
auth
,
" "
,
2
)
if
len
(
parts
)
==
2
&&
strings
.
EqualFold
(
parts
[
0
],
"Bearer"
)
{
if
k
:=
strings
.
TrimSpace
(
parts
[
1
]);
k
!=
""
{
return
k
}
if
v
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"x-api-key"
));
v
!=
""
{
return
v
}
if
v
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"x-goog-api-key"
));
v
!=
""
{
return
v
}
// 3) x-api-key header (backward compatibility)
if
k
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"x-api-key"
));
k
!=
""
{
return
k
}
// 4) query parameter key (for specific paths)
if
allowGoogleQueryKey
(
c
.
Request
.
URL
.
Path
)
{
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"key"
));
v
!=
""
{
return
v
}
}
return
""
}
...
...
backend/internal/server/middleware/api_key_auth_google_test.go
View file @
dd96ada3
...
...
@@ -75,6 +75,9 @@ func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]s
func
(
f
fakeAPIKeyRepo
)
ListKeysByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
([]
string
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeAPIKeyRepo
)
IncrementQuotaUsed
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
(
float64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
type
googleErrorResponse
struct
{
Error
struct
{
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment