Commit 2220fd18 authored by song's avatar song
Browse files

merge upstream main

parents 11ff73b5 df4c0adf
...@@ -33,7 +33,7 @@ func main() { ...@@ -33,7 +33,7 @@ func main() {
}() }()
userRepo := repository.NewUserRepository(client, sqlDB) userRepo := repository.NewUserRepository(client, sqlDB)
authService := service.NewAuthService(userRepo, cfg, nil, nil, nil, nil, nil) authService := service.NewAuthService(userRepo, nil, cfg, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
......
...@@ -43,6 +43,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -43,6 +43,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
return nil, err return nil, err
} }
userRepository := repository.NewUserRepository(client, db) userRepository := repository.NewUserRepository(client, db)
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
settingRepository := repository.NewSettingRepository(client) settingRepository := repository.NewSettingRepository(client)
settingService := service.NewSettingService(settingRepository, configConfig) settingService := service.NewSettingService(settingRepository, configConfig)
redisClient := repository.ProvideRedis(configConfig) redisClient := repository.ProvideRedis(configConfig)
...@@ -61,24 +62,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -61,24 +62,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) authService := service.NewAuthService(userRepository, redeemCodeRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
secretEncryptor, err := repository.NewAESEncryptor(configConfig) secretEncryptor, err := repository.NewAESEncryptor(configConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
totpCache := repository.NewTotpCache(redisClient) totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService) totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, totpService) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
userHandler := handler.NewUserHandler(userService) userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db) usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
redeemHandler := handler.NewRedeemHandler(redeemService) redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
announcementRepository := repository.NewAnnouncementRepository(client) announcementRepository := repository.NewAnnouncementRepository(client)
......
...@@ -37,6 +37,7 @@ const ( ...@@ -37,6 +37,7 @@ const (
RedeemTypeBalance = "balance" RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency" RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription" RedeemTypeSubscription = "subscription"
RedeemTypeInvitation = "invitation"
) )
// PromoCode status constants // PromoCode status constants
......
...@@ -47,6 +47,8 @@ type CreateGroupRequest struct { ...@@ -47,6 +47,8 @@ type CreateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"` MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"` SupportedModelScopes []string `json:"supported_model_scopes"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
} }
// UpdateGroupRequest represents update group request // UpdateGroupRequest represents update group request
...@@ -74,6 +76,8 @@ type UpdateGroupRequest struct { ...@@ -74,6 +76,8 @@ type UpdateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"` MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"` SupportedModelScopes *[]string `json:"supported_model_scopes"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
} }
// List handles listing all groups with pagination // List handles listing all groups with pagination
...@@ -182,6 +186,7 @@ func (h *GroupHandler) Create(c *gin.Context) { ...@@ -182,6 +186,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled, ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject, MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes, SupportedModelScopes: req.SupportedModelScopes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
...@@ -227,6 +232,7 @@ func (h *GroupHandler) Update(c *gin.Context) { ...@@ -227,6 +232,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled, ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject, MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes, SupportedModelScopes: req.SupportedModelScopes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
......
...@@ -29,7 +29,7 @@ func NewRedeemHandler(adminService service.AdminService) *RedeemHandler { ...@@ -29,7 +29,7 @@ func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
// GenerateRedeemCodesRequest represents generate redeem codes request // GenerateRedeemCodesRequest represents generate redeem codes request
type GenerateRedeemCodesRequest struct { type GenerateRedeemCodesRequest struct {
Count int `json:"count" binding:"required,min=1,max=100"` Count int `json:"count" binding:"required,min=1,max=100"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription"` Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value" binding:"min=0"` Value float64 `json:"value" binding:"min=0"`
GroupID *int64 `json:"group_id"` // 订阅类型必填 GroupID *int64 `json:"group_id"` // 订阅类型必填
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年 ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年
......
...@@ -49,6 +49,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -49,6 +49,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled, PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled, PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled, TotpEnabled: settings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: settings.SMTPHost, SMTPHost: settings.SMTPHost,
...@@ -94,11 +95,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -94,11 +95,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// UpdateSettingsRequest 更新设置请求 // UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct { type UpdateSettingsRequest struct {
// 注册设置 // 注册设置
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"` PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
// 邮件服务设置 // 邮件服务设置
SMTPHost string `json:"smtp_host"` SMTPHost string `json:"smtp_host"`
...@@ -291,6 +293,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -291,6 +293,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EmailVerifyEnabled: req.EmailVerifyEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled, PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled, PasswordResetEnabled: req.PasswordResetEnabled,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled, TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost, SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort, SMTPPort: req.SMTPPort,
...@@ -370,6 +373,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -370,6 +373,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled, PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled, PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
TotpEnabled: updatedSettings.TotpEnabled, TotpEnabled: updatedSettings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: updatedSettings.SMTPHost, SMTPHost: updatedSettings.SMTPHost,
......
...@@ -15,23 +15,25 @@ import ( ...@@ -15,23 +15,25 @@ import (
// AuthHandler handles authentication-related requests // AuthHandler handles authentication-related requests
type AuthHandler struct { type AuthHandler struct {
cfg *config.Config cfg *config.Config
authService *service.AuthService authService *service.AuthService
userService *service.UserService userService *service.UserService
settingSvc *service.SettingService settingSvc *service.SettingService
promoService *service.PromoService promoService *service.PromoService
totpService *service.TotpService redeemService *service.RedeemService
totpService *service.TotpService
} }
// NewAuthHandler creates a new AuthHandler // NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler { func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler {
return &AuthHandler{ return &AuthHandler{
cfg: cfg, cfg: cfg,
authService: authService, authService: authService,
userService: userService, userService: userService,
settingSvc: settingService, settingSvc: settingService,
promoService: promoService, promoService: promoService,
totpService: totpService, redeemService: redeemService,
totpService: totpService,
} }
} }
...@@ -41,7 +43,8 @@ type RegisterRequest struct { ...@@ -41,7 +43,8 @@ type RegisterRequest struct {
Password string `json:"password" binding:"required,min=6"` Password string `json:"password" binding:"required,min=6"`
VerifyCode string `json:"verify_code"` VerifyCode string `json:"verify_code"`
TurnstileToken string `json:"turnstile_token"` TurnstileToken string `json:"turnstile_token"`
PromoCode string `json:"promo_code"` // 注册优惠码 PromoCode string `json:"promo_code"` // 注册优惠码
InvitationCode string `json:"invitation_code"` // 邀请码
} }
// SendVerifyCodeRequest 发送验证码请求 // SendVerifyCodeRequest 发送验证码请求
...@@ -87,7 +90,7 @@ func (h *AuthHandler) Register(c *gin.Context) { ...@@ -87,7 +90,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
} }
} }
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode) token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -346,6 +349,66 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) { ...@@ -346,6 +349,66 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
}) })
} }
// ValidateInvitationCodeRequest 验证邀请码请求
type ValidateInvitationCodeRequest struct {
Code string `json:"code" binding:"required"`
}
// ValidateInvitationCodeResponse 验证邀请码响应
type ValidateInvitationCodeResponse struct {
Valid bool `json:"valid"`
ErrorCode string `json:"error_code,omitempty"`
}
// ValidateInvitationCode 验证邀请码(公开接口,注册前调用)
// POST /api/v1/auth/validate-invitation-code
func (h *AuthHandler) ValidateInvitationCode(c *gin.Context) {
// 检查邀请码功能是否启用
if h.settingSvc == nil || !h.settingSvc.IsInvitationCodeEnabled(c.Request.Context()) {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_DISABLED",
})
return
}
var req ValidateInvitationCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 验证邀请码
redeemCode, err := h.redeemService.GetByCode(c.Request.Context(), req.Code)
if err != nil {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_NOT_FOUND",
})
return
}
// 检查类型和状态
if redeemCode.Type != service.RedeemTypeInvitation {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_INVALID",
})
return
}
if redeemCode.Status != service.StatusUnused {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_USED",
})
return
}
response.Success(c, ValidateInvitationCodeResponse{
Valid: true,
})
}
// ForgotPasswordRequest 忘记密码请求 // ForgotPasswordRequest 忘记密码请求
type ForgotPasswordRequest struct { type ForgotPasswordRequest struct {
Email string `json:"email" binding:"required,email"` Email string `json:"email" binding:"required,email"`
......
...@@ -381,6 +381,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { ...@@ -381,6 +381,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
AccountID: l.AccountID, AccountID: l.AccountID,
RequestID: l.RequestID, RequestID: l.RequestID,
Model: l.Model, Model: l.Model,
ReasoningEffort: l.ReasoningEffort,
GroupID: l.GroupID, GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID, SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens, InputTokens: l.InputTokens,
......
...@@ -6,6 +6,7 @@ type SystemSettings struct { ...@@ -6,6 +6,7 @@ type SystemSettings struct {
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"` PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置 TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
...@@ -63,6 +64,7 @@ type PublicSettings struct { ...@@ -63,6 +64,7 @@ type PublicSettings struct {
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"` PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TurnstileEnabled bool `json:"turnstile_enabled"` TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSiteKey string `json:"turnstile_site_key"`
......
...@@ -237,6 +237,9 @@ type UsageLog struct { ...@@ -237,6 +237,9 @@ type UsageLog struct {
AccountID int64 `json:"account_id"` AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"` RequestID string `json:"request_id"`
Model string `json:"model"` Model string `json:"model"`
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
// nil means not provided / not applicable.
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"` SubscriptionID *int64 `json:"subscription_id"`
......
...@@ -596,7 +596,6 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service ...@@ -596,7 +596,6 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service
cloned.Group = group cloned.Group = group
return &cloned return &cloned
} }
// Usage handles getting account balance and usage statistics for CC Switch integration // Usage handles getting account balance and usage statistics for CC Switch integration
// GET /v1/usage // GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) { func (h *GatewayHandler) Usage(c *gin.Context) {
...@@ -849,6 +848,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -849,6 +848,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return return
} }
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
setOpsRequestContext(c, "", false, body) setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body) parsedReq, err := service.ParseGatewayRequest(body)
......
...@@ -371,18 +371,21 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -371,18 +371,21 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
// 6) record usage async // 6) record usage async (Gemini 使用长上下文双倍计费)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) { go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
APIKey: apiKey, Result: result,
User: apiKey.User, APIKey: apiKey,
Account: usedAccount, User: apiKey.User,
Subscription: subscription, Account: usedAccount,
UserAgent: ua, Subscription: subscription,
IPAddress: ip, UserAgent: ua,
IPAddress: ip,
LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
......
...@@ -36,6 +36,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -36,6 +36,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled, PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled, PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled, TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled, TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey, TurnstileSiteKey: settings.TurnstileSiteKey,
......
...@@ -81,7 +81,6 @@ func ForwardBaseURLs() []string { ...@@ -81,7 +81,6 @@ func ForwardBaseURLs() []string {
} }
return reordered return reordered
} }
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级) // URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级)
type URLAvailability struct { type URLAvailability struct {
mu sync.RWMutex mu sync.RWMutex
......
...@@ -9,11 +9,26 @@ const ( ...@@ -9,11 +9,26 @@ const (
BetaClaudeCode = "claude-code-20250219" BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14" BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14" BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting = "token-counting-2024-11-01"
) )
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
//
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
// even if the request doesn't use tools, otherwise upstream may reject the
// request as a non-Claude-Code API request.
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta) // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
...@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking ...@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultHeaders 是 Claude Code 客户端默认请求头。 // DefaultHeaders 是 Claude Code 客户端默认请求头。
var DefaultHeaders = map[string]string{ var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)", // Keep these in sync with recent Claude CLI traffic to reduce the chance
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
"User-Agent": "claude-cli/2.1.22 (external, cli)",
"X-Stainless-Lang": "js", "X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.52.0", "X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux", "X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64", "X-Stainless-Arch": "arm64",
"X-Stainless-Runtime": "node", "X-Stainless-Runtime": "node",
"X-Stainless-Runtime-Version": "v22.14.0", "X-Stainless-Runtime-Version": "v24.13.0",
"X-Stainless-Retry-Count": "0", "X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "60", "X-Stainless-Timeout": "600",
"X-App": "cli", "X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true", "Anthropic-Dangerous-Direct-Browser-Access": "true",
} }
...@@ -79,3 +96,39 @@ func DefaultModelIDs() []string { ...@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
// DefaultTestModel 测试时使用的默认模型 // DefaultTestModel 测试时使用的默认模型
const DefaultTestModel = "claude-sonnet-4-5-20250929" const DefaultTestModel = "claude-sonnet-4-5-20250929"
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
var ModelIDOverrides = map[string]string{
"claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
"claude-opus-4-5": "claude-opus-4-5-20251101",
"claude-haiku-4-5": "claude-haiku-4-5-20251001",
}
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
var ModelIDReverseOverrides = map[string]string{
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-opus-4-5-20251101": "claude-opus-4-5",
"claude-haiku-4-5-20251001": "claude-haiku-4-5",
}
// NormalizeModelID 根据 Claude OAuth 规则映射模型
func NormalizeModelID(id string) string {
if id == "" {
return id
}
if mapped, ok := ModelIDOverrides[id]; ok {
return mapped
}
return id
}
// DenormalizeModelID 将上游模型 ID 转换为短名
func DenormalizeModelID(id string) string {
if id == "" {
return id
}
if mapped, ok := ModelIDReverseOverrides[id]; ok {
return mapped
}
return id
}
...@@ -439,3 +439,61 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6 ...@@ -439,3 +439,61 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
return counts, nil return counts, nil
} }
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
func (r *groupRepository) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
if len(groupIDs) == 0 {
return nil, nil
}
rows, err := r.sql.QueryContext(
ctx,
"SELECT DISTINCT account_id FROM account_groups WHERE group_id = ANY($1) ORDER BY account_id",
pq.Array(groupIDs),
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var accountIDs []int64
for rows.Next() {
var accountID int64
if err := rows.Scan(&accountID); err != nil {
return nil, err
}
accountIDs = append(accountIDs, accountID)
}
if err := rows.Err(); err != nil {
return nil, err
}
return accountIDs, nil
}
// BindAccountsToGroup 将多个账号绑定到指定分组(批量插入,忽略已存在的绑定)
func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
if len(accountIDs) == 0 {
return nil
}
// 使用 INSERT ... ON CONFLICT DO NOTHING 忽略已存在的绑定
_, err := r.sql.ExecContext(
ctx,
`INSERT INTO account_groups (account_id, group_id, priority, created_at)
SELECT unnest($1::bigint[]), $2, 50, NOW()
ON CONFLICT (account_id, group_id) DO NOTHING`,
pq.Array(accountIDs),
groupID,
)
if err != nil {
return err
}
// 发送调度器事件
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
}
return nil
}
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at"
type usageLogRepository struct { type usageLogRepository struct {
client *dbent.Client client *dbent.Client
...@@ -111,21 +111,22 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -111,21 +111,22 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration_ms, duration_ms,
first_token_ms, first_token_ms,
user_agent, user_agent,
ip_address, ip_address,
image_count, image_count,
image_size, image_size,
created_at reasoning_effort,
) VALUES ( created_at
$1, $2, $3, $4, $5, ) VALUES (
$6, $7, $1, $2, $3, $4, $5,
$8, $9, $10, $11, $6, $7,
$12, $13, $8, $9, $10, $11,
$14, $15, $16, $17, $18, $19, $12, $13,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30 $14, $15, $16, $17, $18, $19,
) $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31
ON CONFLICT (request_id, api_key_id) DO NOTHING )
RETURNING id, created_at ON CONFLICT (request_id, api_key_id) DO NOTHING
` RETURNING id, created_at
`
groupID := nullInt64(log.GroupID) groupID := nullInt64(log.GroupID)
subscriptionID := nullInt64(log.SubscriptionID) subscriptionID := nullInt64(log.SubscriptionID)
...@@ -134,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -134,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
userAgent := nullString(log.UserAgent) userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress) ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize) imageSize := nullString(log.ImageSize)
reasoningEffort := nullString(log.ReasoningEffort)
var requestIDArg any var requestIDArg any
if requestID != "" { if requestID != "" {
...@@ -170,6 +172,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -170,6 +172,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
ipAddress, ipAddress,
log.ImageCount, log.ImageCount,
imageSize, imageSize,
reasoningEffort,
createdAt, createdAt,
} }
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
...@@ -2090,6 +2093,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2090,6 +2093,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
ipAddress sql.NullString ipAddress sql.NullString
imageCount int imageCount int
imageSize sql.NullString imageSize sql.NullString
reasoningEffort sql.NullString
createdAt time.Time createdAt time.Time
) )
...@@ -2124,6 +2128,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2124,6 +2128,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&ipAddress, &ipAddress,
&imageCount, &imageCount,
&imageSize, &imageSize,
&reasoningEffort,
&createdAt, &createdAt,
); err != nil { ); err != nil {
return nil, err return nil, err
...@@ -2183,6 +2188,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2183,6 +2188,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if imageSize.Valid { if imageSize.Valid {
log.ImageSize = &imageSize.String log.ImageSize = &imageSize.String
} }
if reasoningEffort.Valid {
log.ReasoningEffort = &reasoningEffort.String
}
return log, nil return log, nil
} }
......
...@@ -488,6 +488,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -488,6 +488,7 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o", "fallback_model_openai": "gpt-4o",
"enable_identity_patch": true, "enable_identity_patch": true,
"identity_patch_prompt": "", "identity_patch_prompt": "",
"invitation_code_enabled": false,
"home_content": "", "home_content": "",
"hide_ccs_import_button": false, "hide_ccs_import_button": false,
"purchase_subscription_enabled": false, "purchase_subscription_enabled": false,
...@@ -599,8 +600,8 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -599,8 +600,8 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
...@@ -880,6 +881,14 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i ...@@ -880,6 +881,14 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return errors.New("not implemented")
}
func (stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
type stubAccountRepo struct { type stubAccountRepo struct {
bulkUpdateIDs []int64 bulkUpdateIDs []int64
} }
......
...@@ -32,6 +32,10 @@ func RegisterAuthRoutes( ...@@ -32,6 +32,10 @@ func RegisterAuthRoutes(
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{ auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose, FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidatePromoCode) }), h.Auth.ValidatePromoCode)
// 邀请码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth.POST("/validate-invitation-code", rateLimiter.LimitWithOptions("validate-invitation", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidateInvitationCode)
// 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close) // 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{ auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose, FailureMode: middleware.RateLimitFailClose,
......
...@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string { ...@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
return "" return ""
} }
func (a *Account) GetClaudeUserID() string {
if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" {
return v
}
return ""
}
func (a *Account) IsCustomErrorCodesEnabled() bool { func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil { if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false return false
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment