Unverified Commit dd96ada3 authored by 程序猿MT's avatar 程序猿MT Committed by GitHub
Browse files

Merge branch 'Wei-Shaw:main' into main

parents 31fe0178 8f397548
...@@ -84,7 +84,7 @@ type CreateAccountRequest struct { ...@@ -84,7 +84,7 @@ type CreateAccountRequest struct {
Name string `json:"name" binding:"required"` Name string `json:"name" binding:"required"`
Notes *string `json:"notes"` Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"` Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"` Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
Credentials map[string]any `json:"credentials" binding:"required"` Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"` Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
...@@ -102,7 +102,7 @@ type CreateAccountRequest struct { ...@@ -102,7 +102,7 @@ type CreateAccountRequest struct {
type UpdateAccountRequest struct { type UpdateAccountRequest struct {
Name string `json:"name"` Name string `json:"name"`
Notes *string `json:"notes"` Notes *string `json:"notes"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"` Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
Credentials map[string]any `json:"credentials"` Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"` Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
......
...@@ -40,9 +40,13 @@ type CreateGroupRequest struct { ...@@ -40,9 +40,13 @@ type CreateGroupRequest struct {
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"` ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"` ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
// 从指定分组复制账号(创建后自动绑定) // 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
} }
...@@ -65,9 +69,13 @@ type UpdateGroupRequest struct { ...@@ -65,9 +69,13 @@ type UpdateGroupRequest struct {
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly *bool `json:"claude_code_only"` ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"` ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"` ModelRoutingEnabled *bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
} }
...@@ -173,8 +181,11 @@ func (h *GroupHandler) Create(c *gin.Context) { ...@@ -173,8 +181,11 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly, ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID, FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
ModelRouting: req.ModelRouting, ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled, ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
}) })
if err != nil { if err != nil {
...@@ -216,8 +227,11 @@ func (h *GroupHandler) Update(c *gin.Context) { ...@@ -216,8 +227,11 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
ClaudeCodeOnly: req.ClaudeCodeOnly, ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID, FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
ModelRouting: req.ModelRouting, ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled, ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
}) })
if err != nil { if err != nil {
......
...@@ -3,6 +3,7 @@ package handler ...@@ -3,6 +3,7 @@ package handler
import ( import (
"strconv" "strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
...@@ -32,6 +33,8 @@ type CreateAPIKeyRequest struct { ...@@ -32,6 +33,8 @@ type CreateAPIKeyRequest struct {
CustomKey *string `json:"custom_key"` // 可选的自定义key CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
Quota *float64 `json:"quota"` // 配额限制 (USD)
ExpiresInDays *int `json:"expires_in_days"` // 过期天数
} }
// UpdateAPIKeyRequest represents the update API key request payload // UpdateAPIKeyRequest represents the update API key request payload
...@@ -41,6 +44,9 @@ type UpdateAPIKeyRequest struct { ...@@ -41,6 +44,9 @@ type UpdateAPIKeyRequest struct {
Status string `json:"status" binding:"omitempty,oneof=active inactive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"`
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
Quota *float64 `json:"quota"` // 配额限制 (USD), 0=无限制
ExpiresAt *string `json:"expires_at"` // 过期时间 (ISO 8601)
ResetQuota *bool `json:"reset_quota"` // 重置已用配额
} }
// List handles listing user's API keys with pagination // List handles listing user's API keys with pagination
...@@ -119,6 +125,10 @@ func (h *APIKeyHandler) Create(c *gin.Context) { ...@@ -119,6 +125,10 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
CustomKey: req.CustomKey, CustomKey: req.CustomKey,
IPWhitelist: req.IPWhitelist, IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist, IPBlacklist: req.IPBlacklist,
ExpiresInDays: req.ExpiresInDays,
}
if req.Quota != nil {
svcReq.Quota = *req.Quota
} }
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq) key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
...@@ -153,6 +163,8 @@ func (h *APIKeyHandler) Update(c *gin.Context) { ...@@ -153,6 +163,8 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
svcReq := service.UpdateAPIKeyRequest{ svcReq := service.UpdateAPIKeyRequest{
IPWhitelist: req.IPWhitelist, IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist, IPBlacklist: req.IPBlacklist,
Quota: req.Quota,
ResetQuota: req.ResetQuota,
} }
if req.Name != "" { if req.Name != "" {
svcReq.Name = &req.Name svcReq.Name = &req.Name
...@@ -161,6 +173,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) { ...@@ -161,6 +173,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
if req.Status != "" { if req.Status != "" {
svcReq.Status = &req.Status svcReq.Status = &req.Status
} }
// Parse expires_at if provided
if req.ExpiresAt != nil {
if *req.ExpiresAt == "" {
// Empty string means clear expiration
svcReq.ExpiresAt = nil
svcReq.ClearExpiration = true
} else {
t, err := time.Parse(time.RFC3339, *req.ExpiresAt)
if err != nil {
response.BadRequest(c, "Invalid expires_at format: "+err.Error())
return
}
svcReq.ExpiresAt = &t
}
}
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq) key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq)
if err != nil { if err != nil {
......
...@@ -76,6 +76,9 @@ func APIKeyFromService(k *service.APIKey) *APIKey { ...@@ -76,6 +76,9 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
Status: k.Status, Status: k.Status,
IPWhitelist: k.IPWhitelist, IPWhitelist: k.IPWhitelist,
IPBlacklist: k.IPBlacklist, IPBlacklist: k.IPBlacklist,
Quota: k.Quota,
QuotaUsed: k.QuotaUsed,
ExpiresAt: k.ExpiresAt,
CreatedAt: k.CreatedAt, CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt, UpdatedAt: k.UpdatedAt,
User: UserFromServiceShallow(k.User), User: UserFromServiceShallow(k.User),
...@@ -108,6 +111,8 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { ...@@ -108,6 +111,8 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
Group: groupFromServiceBase(g), Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting, ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled, ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount, AccountCount: g.AccountCount,
} }
if len(g.AccountGroups) > 0 { if len(g.AccountGroups) > 0 {
...@@ -138,6 +143,8 @@ func groupFromServiceBase(g *service.Group) Group { ...@@ -138,6 +143,8 @@ func groupFromServiceBase(g *service.Group) Group {
ImagePrice4K: g.ImagePrice4K, ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,
// 无效请求兜底分组
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
} }
......
...@@ -40,6 +40,9 @@ type APIKey struct { ...@@ -40,6 +40,9 @@ type APIKey struct {
Status string `json:"status"` Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"` IPWhitelist []string `json:"ip_whitelist"`
IPBlacklist []string `json:"ip_blacklist"` IPBlacklist []string `json:"ip_blacklist"`
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires)
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
...@@ -69,6 +72,8 @@ type Group struct { ...@@ -69,6 +72,8 @@ type Group struct {
// Claude Code 客户端限制 // Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
// 无效请求兜底分组
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
...@@ -83,6 +88,11 @@ type AdminGroup struct { ...@@ -83,6 +88,11 @@ type AdminGroup struct {
ModelRouting map[string][]int64 `json:"model_routing"` ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"` ModelRoutingEnabled bool `json:"model_routing_enabled"`
// MCP XML 协议注入(仅 antigravity 平台使用)
MCPXMLInject bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"` AccountCount int64 `json:"account_count,omitempty"`
} }
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
...@@ -31,6 +32,7 @@ type GatewayHandler struct { ...@@ -31,6 +32,7 @@ type GatewayHandler struct {
userService *service.UserService userService *service.UserService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
usageService *service.UsageService usageService *service.UsageService
apiKeyService *service.APIKeyService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
maxAccountSwitchesGemini int maxAccountSwitchesGemini int
...@@ -45,6 +47,7 @@ func NewGatewayHandler( ...@@ -45,6 +47,7 @@ func NewGatewayHandler(
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
usageService *service.UsageService, usageService *service.UsageService,
apiKeyService *service.APIKeyService,
cfg *config.Config, cfg *config.Config,
) *GatewayHandler { ) *GatewayHandler {
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
...@@ -66,6 +69,7 @@ func NewGatewayHandler( ...@@ -66,6 +69,7 @@ func NewGatewayHandler(
userService: userService, userService: userService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
usageService: usageService, usageService: usageService,
apiKeyService: apiKeyService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini, maxAccountSwitchesGemini: maxAccountSwitchesGemini,
...@@ -281,10 +285,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -281,10 +285,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
var result *service.ForwardResult var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body) result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body)
} else { } else {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) result, err = h.geminiCompatService.Forward(requestCtx, c, account, body)
} }
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
...@@ -323,6 +331,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -323,6 +331,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: ua,
IPAddress: clientIP, IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
...@@ -331,14 +340,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -331,14 +340,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
} }
currentAPIKey := apiKey
currentSubscription := subscription
var fallbackGroupID *int64
if apiKey.Group != nil {
fallbackGroupID = apiKey.Group.FallbackGroupIDOnInvalidRequest
}
fallbackUsed := false
for {
maxAccountSwitches := h.maxAccountSwitches maxAccountSwitches := h.maxAccountSwitches
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0
retryWithFallback := false
for { for {
// 选择支持该模型的账号 // 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
...@@ -408,7 +427,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -408,7 +427,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false accountWaitCounted = false
} }
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
} }
...@@ -417,15 +436,54 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -417,15 +436,54 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
var result *service.ForwardResult var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body)
} else { } else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq) result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
} }
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
if err != nil { if err != nil {
var promptTooLongErr *service.PromptTooLongError
if errors.As(err, &promptTooLongErr) {
log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed)
if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 {
fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID)
if err != nil {
log.Printf("Resolve fallback group failed: %v", err)
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
if fallbackGroup.Platform != service.PlatformAnthropic ||
fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription ||
fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType)
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
// 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "")
c.Request = c.Request.WithContext(ctx)
currentAPIKey = fallbackAPIKey
currentSubscription = nil
fallbackUsed = true
retryWithFallback = true
break
}
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{} failedAccountIDs[account.ID] = struct{}{}
...@@ -453,18 +511,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -453,18 +511,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: currentAPIKey,
User: apiKey.User, User: currentAPIKey.User,
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: currentSubscription,
UserAgent: ua, UserAgent: ua,
IPAddress: clientIP, IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}(result, account, userAgent, clientIP) }(result, account, userAgent, clientIP)
return return
} }
if !retryWithFallback {
return
}
}
} }
// Models handles listing available models // Models handles listing available models
...@@ -527,6 +590,17 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) { ...@@ -527,6 +590,17 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
}) })
} }
func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service.APIKey {
if apiKey == nil || group == nil {
return apiKey
}
cloned := *apiKey
groupID := group.ID
cloned.GroupID = &groupID
cloned.Group = group
return &cloned
}
// Usage handles getting account balance and usage statistics for CC Switch integration // Usage handles getting account balance and usage statistics for CC Switch integration
// GET /v1/usage // GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) { func (h *GatewayHandler) Usage(c *gin.Context) {
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
...@@ -335,10 +336,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -335,10 +336,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 5) forward (根据平台分流) // 5) forward (根据平台分流)
var result *service.ForwardResult var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body) result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body)
} else { } else {
result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body)
} }
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
...@@ -381,6 +386,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -381,6 +386,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
IPAddress: ip, IPAddress: ip,
LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费 LongContextMultiplier: 2.0, // 超出部分双倍计费
APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
......
...@@ -24,6 +24,7 @@ import ( ...@@ -24,6 +24,7 @@ import (
type OpenAIGatewayHandler struct { type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
} }
...@@ -33,6 +34,7 @@ func NewOpenAIGatewayHandler( ...@@ -33,6 +34,7 @@ func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService, gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
apiKeyService *service.APIKeyService,
cfg *config.Config, cfg *config.Config,
) *OpenAIGatewayHandler { ) *OpenAIGatewayHandler {
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
...@@ -46,6 +48,7 @@ func NewOpenAIGatewayHandler( ...@@ -46,6 +48,7 @@ func NewOpenAIGatewayHandler(
return &OpenAIGatewayHandler{ return &OpenAIGatewayHandler{
gatewayService: gatewayService, gatewayService: gatewayService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
apiKeyService: apiKeyService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
} }
...@@ -306,6 +309,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -306,6 +309,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: ua,
IPAddress: ip, IPAddress: ip,
APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
......
...@@ -40,17 +40,48 @@ const ( ...@@ -40,17 +40,48 @@ const (
// URL 可用性 TTL(不可用 URL 的恢复时间) // URL 可用性 TTL(不可用 URL 的恢复时间)
URLAvailabilityTTL = 5 * time.Minute URLAvailabilityTTL = 5 * time.Minute
// Antigravity API 端点
antigravityProdBaseURL = "https://cloudcode-pa.googleapis.com"
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
) )
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致) // BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
var BaseURLs = []string{ var BaseURLs = []string{
"https://cloudcode-pa.googleapis.com", // prod (优先) antigravityProdBaseURL, // prod (优先)
"https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用) antigravityDailyBaseURL, // daily sandbox (备用)
} }
// BaseURL 默认 URL(保持向后兼容) // BaseURL 默认 URL(保持向后兼容)
var BaseURL = BaseURLs[0] var BaseURL = BaseURLs[0]
// ForwardBaseURLs 返回 API 转发用的 URL 顺序(daily 优先)
func ForwardBaseURLs() []string {
if len(BaseURLs) == 0 {
return nil
}
urls := append([]string(nil), BaseURLs...)
dailyIndex := -1
for i, url := range urls {
if url == antigravityDailyBaseURL {
dailyIndex = i
break
}
}
if dailyIndex <= 0 {
return urls
}
reordered := make([]string, 0, len(urls))
reordered = append(reordered, urls[dailyIndex])
for i, url := range urls {
if i == dailyIndex {
continue
}
reordered = append(reordered, url)
}
return reordered
}
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级) // URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级)
type URLAvailability struct { type URLAvailability struct {
mu sync.RWMutex mu sync.RWMutex
...@@ -100,22 +131,37 @@ func (u *URLAvailability) IsAvailable(url string) bool { ...@@ -100,22 +131,37 @@ func (u *URLAvailability) IsAvailable(url string) bool {
// GetAvailableURLs 返回可用的 URL 列表 // GetAvailableURLs 返回可用的 URL 列表
// 最近成功的 URL 优先,其他按默认顺序 // 最近成功的 URL 优先,其他按默认顺序
func (u *URLAvailability) GetAvailableURLs() []string { func (u *URLAvailability) GetAvailableURLs() []string {
return u.GetAvailableURLsWithBase(BaseURLs)
}
// GetAvailableURLsWithBase 返回可用的 URL 列表(使用自定义顺序)
// 最近成功的 URL 优先,其他按传入顺序
func (u *URLAvailability) GetAvailableURLsWithBase(baseURLs []string) []string {
u.mu.RLock() u.mu.RLock()
defer u.mu.RUnlock() defer u.mu.RUnlock()
now := time.Now() now := time.Now()
result := make([]string, 0, len(BaseURLs)) result := make([]string, 0, len(baseURLs))
// 如果有最近成功的 URL 且可用,放在最前面 // 如果有最近成功的 URL 且可用,放在最前面
if u.lastSuccess != "" { if u.lastSuccess != "" {
found := false
for _, url := range baseURLs {
if url == u.lastSuccess {
found = true
break
}
}
if found {
expiry, exists := u.unavailable[u.lastSuccess] expiry, exists := u.unavailable[u.lastSuccess]
if !exists || now.After(expiry) { if !exists || now.After(expiry) {
result = append(result, u.lastSuccess) result = append(result, u.lastSuccess)
} }
} }
}
// 添加其他可用的 URL(按默认顺序) // 添加其他可用的 URL(按传入顺序)
for _, url := range BaseURLs { for _, url := range baseURLs {
// 跳过已添加的 lastSuccess // 跳过已添加的 lastSuccess
if url == u.lastSuccess { if url == u.lastSuccess {
continue continue
......
...@@ -44,11 +44,13 @@ type TransformOptions struct { ...@@ -44,11 +44,13 @@ type TransformOptions struct {
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词; // IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
// 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。 // 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。
IdentityPatch string IdentityPatch string
EnableMCPXML bool
} }
func DefaultTransformOptions() TransformOptions { func DefaultTransformOptions() TransformOptions {
return TransformOptions{ return TransformOptions{
EnableIdentityPatch: true, EnableIdentityPatch: true,
EnableMCPXML: true,
} }
} }
...@@ -257,8 +259,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans ...@@ -257,8 +259,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
// 添加用户的 system prompt // 添加用户的 system prompt
parts = append(parts, userSystemParts...) parts = append(parts, userSystemParts...)
// 检测是否有 MCP 工具,如有则注入 XML 调用协议 // 检测是否有 MCP 工具,如有且启用了 MCP XML 注入则注入 XML 调用协议
if hasMCPTools(tools) { if opts.EnableMCPXML && hasMCPTools(tools) {
parts = append(parts, GeminiPart{Text: mcpXMLProtocol}) parts = append(parts, GeminiPart{Text: mcpXMLProtocol})
} }
...@@ -312,7 +314,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT ...@@ -312,7 +314,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
parts = append([]GeminiPart{{ parts = append([]GeminiPart{{
Text: "Thinking...", Text: "Thinking...",
Thought: true, Thought: true,
ThoughtSignature: dummyThoughtSignature, ThoughtSignature: DummyThoughtSignature,
}}, parts...) }}, parts...)
} }
} }
...@@ -330,9 +332,10 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT ...@@ -330,9 +332,10 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
return contents, strippedThinking, nil return contents, strippedThinking, nil
} }
// dummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证 // DummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures // 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const dummyThoughtSignature = "skip_thought_signature_validator" // 导出供跨包使用(如 gemini_native_signature_cleaner 跨账号修复)
const DummyThoughtSignature = "skip_thought_signature_validator"
// buildParts 构建消息的 parts // buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature // allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
...@@ -370,7 +373,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -370,7 +373,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
// signature 处理: // signature 处理:
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失) // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature // - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) { if block.Signature != "" && (allowDummyThought || block.Signature != DummyThoughtSignature) {
part.ThoughtSignature = block.Signature part.ThoughtSignature = block.Signature
} else if !allowDummyThought { } else if !allowDummyThought {
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。 // Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
...@@ -381,7 +384,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -381,7 +384,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
continue continue
} else { } else {
// Gemini 模型使用 dummy signature // Gemini 模型使用 dummy signature
part.ThoughtSignature = dummyThoughtSignature part.ThoughtSignature = DummyThoughtSignature
} }
parts = append(parts, part) parts = append(parts, part)
...@@ -411,10 +414,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -411,10 +414,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
// tool_use 的 signature 处理: // tool_use 的 signature 处理:
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失) // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature // - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) { if block.Signature != "" && (allowDummyThought || block.Signature != DummyThoughtSignature) {
part.ThoughtSignature = block.Signature part.ThoughtSignature = block.Signature
} else if allowDummyThought { } else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature part.ThoughtSignature = DummyThoughtSignature
} }
parts = append(parts, part) parts = append(parts, part)
...@@ -492,9 +495,23 @@ func parseToolResultContent(content json.RawMessage, isError bool) string { ...@@ -492,9 +495,23 @@ func parseToolResultContent(content json.RawMessage, isError bool) string {
} }
// buildGenerationConfig 构建 generationConfig // buildGenerationConfig 构建 generationConfig
const (
defaultMaxOutputTokens = 64000
maxOutputTokensUpperBound = 65000
maxOutputTokensClaude = 64000
)
func maxOutputTokensLimit(model string) int {
if strings.HasPrefix(model, "claude-") {
return maxOutputTokensClaude
}
return maxOutputTokensUpperBound
}
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
maxLimit := maxOutputTokensLimit(req.Model)
config := &GeminiGenerationConfig{ config := &GeminiGenerationConfig{
MaxOutputTokens: 64000, // 默认最大输出 MaxOutputTokens: defaultMaxOutputTokens, // 默认最大输出
StopSequences: DefaultStopSequences, StopSequences: DefaultStopSequences,
} }
...@@ -518,6 +535,10 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { ...@@ -518,6 +535,10 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
} }
} }
if config.MaxOutputTokens > maxLimit {
config.MaxOutputTokens = maxLimit
}
// 其他参数 // 其他参数
if req.Temperature != nil { if req.Temperature != nil {
config.Temperature = req.Temperature config.Temperature = req.Temperature
......
...@@ -86,7 +86,7 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) { ...@@ -86,7 +86,7 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
if len(parts) != 3 { if len(parts) != 3 {
t.Fatalf("expected 3 parts, got %d", len(parts)) t.Fatalf("expected 3 parts, got %d", len(parts))
} }
if !parts[1].Thought || parts[1].ThoughtSignature != dummyThoughtSignature { if !parts[1].Thought || parts[1].ThoughtSignature != DummyThoughtSignature {
t.Fatalf("expected dummy thought signature, got thought=%v signature=%q", t.Fatalf("expected dummy thought signature, got thought=%v signature=%q",
parts[1].Thought, parts[1].ThoughtSignature) parts[1].Thought, parts[1].ThoughtSignature)
} }
...@@ -126,8 +126,8 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) { ...@@ -126,8 +126,8 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if len(parts) != 1 || parts[0].FunctionCall == nil { if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts) t.Fatalf("expected 1 functionCall part, got %+v", parts)
} }
if parts[0].ThoughtSignature != dummyThoughtSignature { if parts[0].ThoughtSignature != DummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature) t.Fatalf("expected dummy tool signature %q, got %q", DummyThoughtSignature, parts[0].ThoughtSignature)
} }
}) })
......
...@@ -14,6 +14,9 @@ const ( ...@@ -14,6 +14,9 @@ const (
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。 // RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
RetryCount Key = "ctx_retry_count" RetryCount Key = "ctx_retry_count"
// AccountSwitchCount 表示请求过程中发生的账号切换次数
AccountSwitchCount Key = "ctx_account_switch_count"
// IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端 // IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端
IsClaudeCodeClient Key = "ctx_is_claude_code_client" IsClaudeCodeClient Key = "ctx_is_claude_code_client"
// Group 认证后的分组信息,由 API Key 认证中间件设置 // Group 认证后的分组信息,由 API Key 认证中间件设置
......
...@@ -33,7 +33,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro ...@@ -33,7 +33,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
SetKey(key.Key). SetKey(key.Key).
SetName(key.Name). SetName(key.Name).
SetStatus(key.Status). SetStatus(key.Status).
SetNillableGroupID(key.GroupID) SetNillableGroupID(key.GroupID).
SetQuota(key.Quota).
SetQuotaUsed(key.QuotaUsed).
SetNillableExpiresAt(key.ExpiresAt)
if len(key.IPWhitelist) > 0 { if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist) builder.SetIPWhitelist(key.IPWhitelist)
...@@ -110,6 +113,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se ...@@ -110,6 +113,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
apikey.FieldStatus, apikey.FieldStatus,
apikey.FieldIPWhitelist, apikey.FieldIPWhitelist,
apikey.FieldIPBlacklist, apikey.FieldIPBlacklist,
apikey.FieldQuota,
apikey.FieldQuotaUsed,
apikey.FieldExpiresAt,
). ).
WithUser(func(q *dbent.UserQuery) { WithUser(func(q *dbent.UserQuery) {
q.Select( q.Select(
...@@ -136,8 +142,11 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se ...@@ -136,8 +142,11 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice4k, group.FieldImagePrice4k,
group.FieldClaudeCodeOnly, group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID, group.FieldFallbackGroupID,
group.FieldFallbackGroupIDOnInvalidRequest,
group.FieldModelRoutingEnabled, group.FieldModelRoutingEnabled,
group.FieldModelRouting, group.FieldModelRouting,
group.FieldMcpXMLInject,
group.FieldSupportedModelScopes,
) )
}). }).
Only(ctx) Only(ctx)
...@@ -161,6 +170,8 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro ...@@ -161,6 +170,8 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()). Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name). SetName(key.Name).
SetStatus(key.Status). SetStatus(key.Status).
SetQuota(key.Quota).
SetQuotaUsed(key.QuotaUsed).
SetUpdatedAt(now) SetUpdatedAt(now)
if key.GroupID != nil { if key.GroupID != nil {
builder.SetGroupID(*key.GroupID) builder.SetGroupID(*key.GroupID)
...@@ -168,6 +179,13 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro ...@@ -168,6 +179,13 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder.ClearGroupID() builder.ClearGroupID()
} }
// Expiration time
if key.ExpiresAt != nil {
builder.SetExpiresAt(*key.ExpiresAt)
} else {
builder.ClearExpiresAt()
}
// IP 限制字段 // IP 限制字段
if len(key.IPWhitelist) > 0 { if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist) builder.SetIPWhitelist(key.IPWhitelist)
...@@ -357,6 +375,38 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ...@@ -357,6 +375,38 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
return keys, nil return keys, nil
} }
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
// Use raw SQL for atomic increment to avoid race conditions
// First get current value
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
Select(apikey.FieldQuotaUsed).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrAPIKeyNotFound
}
return 0, err
}
newValue := m.QuotaUsed + amount
// Update with new value
affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetQuotaUsed(newValue).
Save(ctx)
if err != nil {
return 0, err
}
if affected == 0 {
return 0, service.ErrAPIKeyNotFound
}
return newValue, nil
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
if m == nil { if m == nil {
return nil return nil
...@@ -372,6 +422,9 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { ...@@ -372,6 +422,9 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
CreatedAt: m.CreatedAt, CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt, UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID, GroupID: m.GroupID,
Quota: m.Quota,
QuotaUsed: m.QuotaUsed,
ExpiresAt: m.ExpiresAt,
} }
if m.Edges.User != nil { if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User) out.User = userEntityToService(m.Edges.User)
...@@ -427,8 +480,11 @@ func groupEntityToService(g *dbent.Group) *service.Group { ...@@ -427,8 +480,11 @@ func groupEntityToService(g *dbent.Group) *service.Group {
DefaultValidityDays: g.DefaultValidityDays, DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
ModelRouting: g.ModelRouting, ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled, ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.McpXMLInject,
SupportedModelScopes: g.SupportedModelScopes,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
} }
......
...@@ -50,13 +50,18 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -50,13 +50,18 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID). SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject)
// 设置模型路由配置 // 设置模型路由配置
if groupIn.ModelRouting != nil { if groupIn.ModelRouting != nil {
builder = builder.SetModelRouting(groupIn.ModelRouting) builder = builder.SetModelRouting(groupIn.ModelRouting)
} }
// 设置支持的模型系列(始终设置,空数组表示不限制)
builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes)
created, err := builder.Save(ctx) created, err := builder.Save(ctx)
if err == nil { if err == nil {
groupIn.ID = created.ID groupIn.ID = created.ID
...@@ -87,7 +92,6 @@ func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.G ...@@ -87,7 +92,6 @@ func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.G
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil) return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
} }
return groupEntityToService(m), nil return groupEntityToService(m), nil
} }
...@@ -108,7 +112,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -108,7 +112,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject)
// 处理 FallbackGroupID:nil 时清除,否则设置 // 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil { if groupIn.FallbackGroupID != nil {
...@@ -116,6 +121,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -116,6 +121,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
} else { } else {
builder = builder.ClearFallbackGroupID() builder = builder.ClearFallbackGroupID()
} }
// 处理 FallbackGroupIDOnInvalidRequest:nil 时清除,否则设置
if groupIn.FallbackGroupIDOnInvalidRequest != nil {
builder = builder.SetFallbackGroupIDOnInvalidRequest(*groupIn.FallbackGroupIDOnInvalidRequest)
} else {
builder = builder.ClearFallbackGroupIDOnInvalidRequest()
}
// 处理 ModelRouting:nil 时清除,否则设置 // 处理 ModelRouting:nil 时清除,否则设置
if groupIn.ModelRouting != nil { if groupIn.ModelRouting != nil {
...@@ -124,6 +135,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -124,6 +135,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
builder = builder.ClearModelRouting() builder = builder.ClearModelRouting()
} }
// 处理 SupportedModelScopes(始终设置,空数组表示不限制)
builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes)
updated, err := builder.Save(ctx) updated, err := builder.Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists) return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
......
...@@ -43,6 +43,7 @@ INSERT INTO ops_system_metrics ( ...@@ -43,6 +43,7 @@ INSERT INTO ops_system_metrics (
upstream_529_count, upstream_529_count,
token_consumed, token_consumed,
account_switch_count,
qps, qps,
tps, tps,
...@@ -81,14 +82,14 @@ INSERT INTO ops_system_metrics ( ...@@ -81,14 +82,14 @@ INSERT INTO ops_system_metrics (
$1,$2,$3,$4, $1,$2,$3,$4,
$5,$6,$7,$8, $5,$6,$7,$8,
$9,$10,$11, $9,$10,$11,
$12,$13,$14, $12,$13,$14,$15,
$15,$16,$17,$18,$19,$20, $16,$17,$18,$19,$20,$21,
$21,$22,$23,$24,$25,$26, $22,$23,$24,$25,$26,$27,
$27,$28,$29,$30, $28,$29,$30,$31,
$31,$32, $32,$33,
$33,$34, $34,$35,
$35,$36,$37, $36,$37,$38,
$38,$39 $39,$40
)` )`
_, err := r.db.ExecContext( _, err := r.db.ExecContext(
...@@ -109,6 +110,7 @@ INSERT INTO ops_system_metrics ( ...@@ -109,6 +110,7 @@ INSERT INTO ops_system_metrics (
input.Upstream529Count, input.Upstream529Count,
input.TokenConsumed, input.TokenConsumed,
input.AccountSwitchCount,
opsNullFloat64(input.QPS), opsNullFloat64(input.QPS),
opsNullFloat64(input.TPS), opsNullFloat64(input.TPS),
...@@ -177,7 +179,8 @@ SELECT ...@@ -177,7 +179,8 @@ SELECT
db_conn_waiting, db_conn_waiting,
goroutine_count, goroutine_count,
concurrency_queue_depth concurrency_queue_depth,
account_switch_count
FROM ops_system_metrics FROM ops_system_metrics
WHERE window_minutes = $1 WHERE window_minutes = $1
AND platform IS NULL AND platform IS NULL
...@@ -199,6 +202,7 @@ LIMIT 1` ...@@ -199,6 +202,7 @@ LIMIT 1`
var dbWaiting sql.NullInt64 var dbWaiting sql.NullInt64
var goroutines sql.NullInt64 var goroutines sql.NullInt64
var queueDepth sql.NullInt64 var queueDepth sql.NullInt64
var accountSwitchCount sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan( if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan(
&out.ID, &out.ID,
...@@ -217,6 +221,7 @@ LIMIT 1` ...@@ -217,6 +221,7 @@ LIMIT 1`
&dbWaiting, &dbWaiting,
&goroutines, &goroutines,
&queueDepth, &queueDepth,
&accountSwitchCount,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
...@@ -273,6 +278,10 @@ LIMIT 1` ...@@ -273,6 +278,10 @@ LIMIT 1`
v := int(queueDepth.Int64) v := int(queueDepth.Int64)
out.ConcurrencyQueueDepth = &v out.ConcurrencyQueueDepth = &v
} }
if accountSwitchCount.Valid {
v := accountSwitchCount.Int64
out.AccountSwitchCount = &v
}
return &out, nil return &out, nil
} }
......
...@@ -56,18 +56,44 @@ error_buckets AS ( ...@@ -56,18 +56,44 @@ error_buckets AS (
AND COALESCE(status_code, 0) >= 400 AND COALESCE(status_code, 0) >= 400
GROUP BY 1 GROUP BY 1
), ),
switch_buckets AS (
SELECT ` + errorBucketExpr + ` AS bucket,
COALESCE(SUM(CASE
WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1
ELSE 0
END), 0) AS switch_count
FROM ops_error_logs
CROSS JOIN LATERAL jsonb_array_elements(
COALESCE(NULLIF(upstream_errors, 'null'::jsonb), '[]'::jsonb)
) AS ev
` + errorWhere + `
AND upstream_errors IS NOT NULL
GROUP BY 1
),
combined AS ( combined AS (
SELECT COALESCE(u.bucket, e.bucket) AS bucket, SELECT
COALESCE(u.success_count, 0) AS success_count, bucket,
COALESCE(e.error_count, 0) AS error_count, SUM(success_count) AS success_count,
COALESCE(u.token_consumed, 0) AS token_consumed SUM(error_count) AS error_count,
FROM usage_buckets u SUM(token_consumed) AS token_consumed,
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket SUM(switch_count) AS switch_count
FROM (
SELECT bucket, success_count, 0 AS error_count, token_consumed, 0 AS switch_count
FROM usage_buckets
UNION ALL
SELECT bucket, 0, error_count, 0, 0
FROM error_buckets
UNION ALL
SELECT bucket, 0, 0, 0, switch_count
FROM switch_buckets
) t
GROUP BY bucket
) )
SELECT SELECT
bucket, bucket,
(success_count + error_count) AS request_count, (success_count + error_count) AS request_count,
token_consumed token_consumed,
switch_count
FROM combined FROM combined
ORDER BY bucket ASC` ORDER BY bucket ASC`
...@@ -84,13 +110,18 @@ ORDER BY bucket ASC` ...@@ -84,13 +110,18 @@ ORDER BY bucket ASC`
var bucket time.Time var bucket time.Time
var requests int64 var requests int64
var tokens sql.NullInt64 var tokens sql.NullInt64
if err := rows.Scan(&bucket, &requests, &tokens); err != nil { var switches sql.NullInt64
if err := rows.Scan(&bucket, &requests, &tokens, &switches); err != nil {
return nil, err return nil, err
} }
tokenConsumed := int64(0) tokenConsumed := int64(0)
if tokens.Valid { if tokens.Valid {
tokenConsumed = tokens.Int64 tokenConsumed = tokens.Int64
} }
switchCount := int64(0)
if switches.Valid {
switchCount = switches.Int64
}
denom := float64(bucketSeconds) denom := float64(bucketSeconds)
if denom <= 0 { if denom <= 0 {
...@@ -103,6 +134,7 @@ ORDER BY bucket ASC` ...@@ -103,6 +134,7 @@ ORDER BY bucket ASC`
BucketStart: bucket.UTC(), BucketStart: bucket.UTC(),
RequestCount: requests, RequestCount: requests,
TokenConsumed: tokenConsumed, TokenConsumed: tokenConsumed,
SwitchCount: switchCount,
QPS: qps, QPS: qps,
TPS: tps, TPS: tps,
}) })
...@@ -385,6 +417,7 @@ func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points [] ...@@ -385,6 +417,7 @@ func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points []
BucketStart: cursor, BucketStart: cursor,
RequestCount: 0, RequestCount: 0,
TokenConsumed: 0, TokenConsumed: 0,
SwitchCount: 0,
QPS: 0, QPS: 0,
TPS: 0, TPS: 0,
}) })
......
...@@ -83,6 +83,9 @@ func TestAPIContracts(t *testing.T) { ...@@ -83,6 +83,9 @@ func TestAPIContracts(t *testing.T) {
"status": "active", "status": "active",
"ip_whitelist": null, "ip_whitelist": null,
"ip_blacklist": null, "ip_blacklist": null,
"quota": 0,
"quota_used": 0,
"expires_at": null,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z" "updated_at": "2025-01-02T03:04:05Z"
} }
...@@ -119,6 +122,9 @@ func TestAPIContracts(t *testing.T) { ...@@ -119,6 +122,9 @@ func TestAPIContracts(t *testing.T) {
"status": "active", "status": "active",
"ip_whitelist": null, "ip_whitelist": null,
"ip_blacklist": null, "ip_blacklist": null,
"quota": 0,
"quota_used": 0,
"expires_at": null,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z" "updated_at": "2025-01-02T03:04:05Z"
} }
...@@ -180,6 +186,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -180,6 +186,7 @@ func TestAPIContracts(t *testing.T) {
"image_price_4k": null, "image_price_4k": null,
"claude_code_only": false, "claude_code_only": false,
"fallback_group_id": null, "fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z" "updated_at": "2025-01-02T03:04:05Z"
} }
...@@ -601,7 +608,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -601,7 +608,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
...@@ -1442,6 +1449,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ( ...@@ -1442,6 +1449,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) (
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
return 0, errors.New("not implemented")
}
type stubUsageLogRepo struct { type stubUsageLogRepo struct {
userLogs map[int64][]service.UsageLog userLogs map[int64][]service.UsageLog
} }
......
...@@ -70,7 +70,27 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -70,7 +70,27 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 检查API key是否激活 // 检查API key是否激活
if !apiKey.IsActive() { if !apiKey.IsActive() {
// Provide more specific error message based on status
switch apiKey.Status {
case service.StatusAPIKeyQuotaExhausted:
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
case service.StatusAPIKeyExpired:
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
default:
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled") AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
}
return
}
// 检查API Key是否过期(即使状态是active,也要检查时间)
if apiKey.IsExpired() {
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
return
}
// 检查API Key配额是否耗尽
if apiKey.IsQuotaExhausted() {
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
return return
} }
......
...@@ -26,7 +26,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs ...@@ -26,7 +26,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.") abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.")
return return
} }
apiKeyString := extractAPIKeyFromRequest(c) apiKeyString := extractAPIKeyForGoogle(c)
if apiKeyString == "" { if apiKeyString == "" {
abortWithGoogleError(c, 401, "API key is required") abortWithGoogleError(c, 401, "API key is required")
return return
...@@ -108,25 +108,38 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs ...@@ -108,25 +108,38 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
} }
} }
func extractAPIKeyFromRequest(c *gin.Context) string { // extractAPIKeyForGoogle extracts API key for Google/Gemini endpoints.
authHeader := c.GetHeader("Authorization") // Priority: x-goog-api-key > Authorization: Bearer > x-api-key > query key
if authHeader != "" { // This allows OpenClaw and other clients using Bearer auth to work with Gemini endpoints.
parts := strings.SplitN(authHeader, " ", 2) func extractAPIKeyForGoogle(c *gin.Context) string {
if len(parts) == 2 && parts[0] == "Bearer" && strings.TrimSpace(parts[1]) != "" { // 1) preferred: Gemini native header
return strings.TrimSpace(parts[1]) if k := strings.TrimSpace(c.GetHeader("x-goog-api-key")); k != "" {
return k
} }
// 2) fallback: Authorization: Bearer <key>
auth := strings.TrimSpace(c.GetHeader("Authorization"))
if auth != "" {
parts := strings.SplitN(auth, " ", 2)
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
if k := strings.TrimSpace(parts[1]); k != "" {
return k
} }
if v := strings.TrimSpace(c.GetHeader("x-api-key")); v != "" {
return v
} }
if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" {
return v
} }
// 3) x-api-key header (backward compatibility)
if k := strings.TrimSpace(c.GetHeader("x-api-key")); k != "" {
return k
}
// 4) query parameter key (for specific paths)
if allowGoogleQueryKey(c.Request.URL.Path) { if allowGoogleQueryKey(c.Request.URL.Path) {
if v := strings.TrimSpace(c.Query("key")); v != "" { if v := strings.TrimSpace(c.Query("key")); v != "" {
return v return v
} }
} }
return "" return ""
} }
......
...@@ -75,6 +75,9 @@ func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]s ...@@ -75,6 +75,9 @@ func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]s
func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (f fakeAPIKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
return 0, errors.New("not implemented")
}
type googleErrorResponse struct { type googleErrorResponse struct {
Error struct { Error struct {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment