"frontend/src/views/vscode:/vscode.git/clone" did not exist on "de2cab2f80da2d492cab450856a9bcd856a5f80f"
Commit a9880ee7 authored by erio's avatar erio
Browse files

fix: round-2 audit fixes — security, code quality, and UI improvements

Security (HIGH):
- Normalize all Redis cache keys to lowercase (verifyCode, passwordReset)
- Fix verify code TTL renewal on failed attempts: use remaining TTL via
  ExpiresAt field instead of resetting to full 15-minute window
- Add 3 missing fields to diffSettings audit log (promo_code, invitation_code,
  custom_endpoints)

Code quality (MEDIUM):
- Extract filterVerifiedEmails shared helper (balance_notify_service.go)
- Add Pricing array non-empty validation for channel pricing rules
- Add platform token semantics comment in gateway_service.go
- Complete validatePlanPatch test coverage (+10 test cases)
- Replace string types with QuotaThresholdType/QuotaResetMode across frontend
- Remove duplicate getPlatformTextColor/getRateBadgeClass in ChannelsView
- Return EMAIL_NOT_FOUND error on RemoveNotifyEmail miss

UI improvements:
- Reorder cost tooltip: user billing above separator, account billing below
- Add NaN guard to accountBilled function
- Move timezone selector inline into reset-mode row (no longer standalone)
parent 74f8a30f
...@@ -357,6 +357,11 @@ func (h *ChannelHandler) Create(c *gin.Context) { ...@@ -357,6 +357,11 @@ func (h *ChannelHandler) Create(c *gin.Context) {
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1))) fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
return return
} }
if len(r.Pricing) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
return
}
rule := accountStatsPricingRuleRequestToService(r) rule := accountStatsPricingRuleRequestToService(r)
rule.SortOrder = i rule.SortOrder = i
statsRules = append(statsRules, rule) statsRules = append(statsRules, rule)
...@@ -420,6 +425,11 @@ func (h *ChannelHandler) Update(c *gin.Context) { ...@@ -420,6 +425,11 @@ func (h *ChannelHandler) Update(c *gin.Context) {
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1))) fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
return return
} }
if len(r.Pricing) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
return
}
rule := accountStatsPricingRuleRequestToService(r) rule := accountStatsPricingRuleRequestToService(r)
rule.SortOrder = i rule.SortOrder = i
statsRules = append(statsRules, rule) statsRules = append(statsRules, rule)
......
...@@ -1138,6 +1138,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -1138,6 +1138,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) { if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
changed = append(changed, "registration_email_suffix_whitelist") changed = append(changed, "registration_email_suffix_whitelist")
} }
if before.PromoCodeEnabled != after.PromoCodeEnabled {
changed = append(changed, "promo_code_enabled")
}
if before.InvitationCodeEnabled != after.InvitationCodeEnabled {
changed = append(changed, "invitation_code_enabled")
}
if before.PasswordResetEnabled != after.PasswordResetEnabled { if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled") changed = append(changed, "password_reset_enabled")
} }
...@@ -1348,6 +1354,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -1348,6 +1354,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.CustomMenuItems != after.CustomMenuItems { if before.CustomMenuItems != after.CustomMenuItems {
changed = append(changed, "custom_menu_items") changed = append(changed, "custom_menu_items")
} }
if before.CustomEndpoints != after.CustomEndpoints {
changed = append(changed, "custom_endpoints")
}
if before.EnableFingerprintUnification != after.EnableFingerprintUnification { if before.EnableFingerprintUnification != after.EnableFingerprintUnification {
changed = append(changed, "enable_fingerprint_unification") changed = append(changed, "enable_fingerprint_unification")
} }
......
...@@ -20,8 +20,9 @@ const ( ...@@ -20,8 +20,9 @@ const (
) )
// verifyCodeKey generates the Redis key for email verification code. // verifyCodeKey generates the Redis key for email verification code.
// Email is lowercased for case-insensitive consistency.
func verifyCodeKey(email string) string { func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email return verifyCodeKeyPrefix + strings.ToLower(email)
} }
// notifyVerifyKey generates the Redis key for notify email verification code. // notifyVerifyKey generates the Redis key for notify email verification code.
...@@ -33,12 +34,12 @@ func notifyVerifyKey(email string) string { ...@@ -33,12 +34,12 @@ func notifyVerifyKey(email string) string {
// passwordResetKey generates the Redis key for password reset token. // passwordResetKey generates the Redis key for password reset token.
func passwordResetKey(email string) string { func passwordResetKey(email string) string {
return passwordResetKeyPrefix + email return passwordResetKeyPrefix + strings.ToLower(email)
} }
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp. // passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
func passwordResetSentAtKey(email string) string { func passwordResetSentAtKey(email string) string {
return passwordResetSentAtKeyPrefix + email return passwordResetSentAtKeyPrefix + strings.ToLower(email)
} }
type emailCache struct { type emailCache struct {
......
...@@ -283,24 +283,7 @@ func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context) ...@@ -283,24 +283,7 @@ func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context)
return nil return nil
} }
var recipients []string return filterVerifiedEmails(entries)
seen := make(map[string]bool)
for _, entry := range entries {
if entry.Disabled || !entry.Verified {
continue
}
email := strings.TrimSpace(entry.Email)
if email == "" {
continue
}
lower := strings.ToLower(email)
if seen[lower] {
continue
}
seen[lower] = true
recipients = append(recipients, email)
}
return recipients
} }
// getSiteName reads site name from settings with fallback. // getSiteName reads site name from settings with fallback.
...@@ -312,13 +295,11 @@ func (s *BalanceNotifyService) getSiteName(ctx context.Context) string { ...@@ -312,13 +295,11 @@ func (s *BalanceNotifyService) getSiteName(ctx context.Context) string {
return name return name
} }
// collectBalanceNotifyRecipients returns verified, non-disabled email recipients. // filterVerifiedEmails returns deduplicated, non-disabled, verified emails.
// Only emails with verified=true and disabled=false are included. func filterVerifiedEmails(entries []NotifyEmailEntry) []string {
func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string {
var recipients []string var recipients []string
seen := make(map[string]bool) seen := make(map[string]bool)
for _, entry := range entries {
for _, entry := range user.BalanceNotifyExtraEmails {
if entry.Disabled || !entry.Verified { if entry.Disabled || !entry.Verified {
continue continue
} }
...@@ -333,10 +314,15 @@ func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []stri ...@@ -333,10 +314,15 @@ func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []stri
seen[lower] = true seen[lower] = true
recipients = append(recipients, email) recipients = append(recipients, email)
} }
return recipients return recipients
} }
// collectBalanceNotifyRecipients returns verified, non-disabled email recipients.
// Only emails with verified=true and disabled=false are included.
func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string {
return filterVerifiedEmails(user.BalanceNotifyExtraEmails)
}
// sendEmails sends an email to all recipients with shared timeout and error logging. // sendEmails sends an email to all recipients with shared timeout and error logging.
func (s *BalanceNotifyService) sendEmails(recipients []string, subject, body string, logAttrs ...any) { func (s *BalanceNotifyService) sendEmails(recipients []string, subject, body string, logAttrs ...any) {
if len(recipients) == 0 { if len(recipients) == 0 {
......
...@@ -55,6 +55,7 @@ type VerificationCodeData struct { ...@@ -55,6 +55,7 @@ type VerificationCodeData struct {
Code string Code string
Attempts int Attempts int
CreatedAt time.Time CreatedAt time.Time
ExpiresAt time.Time // absolute expiry; used to preserve remaining TTL when updating attempts
} }
// PasswordResetTokenData represents password reset token data // PasswordResetTokenData represents password reset token data
...@@ -263,6 +264,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin ...@@ -263,6 +264,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
Code: code, Code: code,
Attempts: 0, Attempts: 0,
CreatedAt: time.Now(), CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(verifyCodeTTL),
} }
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err) return fmt.Errorf("save verify code: %w", err)
...@@ -295,7 +297,11 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error ...@@ -295,7 +297,11 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配 (constant-time comparison to prevent timing attacks) // 验证码不匹配 (constant-time comparison to prevent timing attacks)
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 { if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++ data.Attempts++
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { remaining := time.Until(data.ExpiresAt)
if remaining <= 0 {
return ErrInvalidVerifyCode
}
if err := s.cache.SetVerificationCode(ctx, email, data, remaining); err != nil {
slog.Error("failed to update verification attempt count", "email", email, "error", err) slog.Error("failed to update verification attempt count", "email", email, "error", err)
} }
if data.Attempts >= maxVerifyCodeAttempts { if data.Attempts >= maxVerifyCodeAttempts {
......
...@@ -1194,12 +1194,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -1194,12 +1194,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度 // 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err != nil {
return nil, err
}
return s.hydrateSelectedAccount(ctx, account)
} }
// antigravity 分组、强制平台模式或无分组使用单平台选择 // antigravity 分组、强制平台模式或无分组使用单平台选择
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询 // 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err != nil {
return nil, err
}
return s.hydrateSelectedAccount(ctx, account)
} }
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
...@@ -1275,11 +1283,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1275,11 +1283,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
localExcluded[account.ID] = struct{}{} // 排除此账号 localExcluded[account.ID] = struct{}{} // 排除此账号
continue // 重新选择 continue // 重新选择
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
// 对于等待计划的情况,也需要先检查会话限制 // 对于等待计划的情况,也需要先检查会话限制
...@@ -1291,26 +1295,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1291,26 +1295,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: account.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: account.ID, Timeout: cfg.StickySessionWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.StickySessionMaxWaiting,
Timeout: cfg.StickySessionWaitTimeout, })
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: account.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: account.ID, Timeout: cfg.FallbackWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.FallbackMaxWaiting,
Timeout: cfg.FallbackWaitTimeout, })
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
} }
} }
...@@ -1433,53 +1431,76 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1433,53 +1431,76 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
// 粘性账号在路由列表中,优先使用 // 粘性账号在路由列表中,优先使用
if stickyAccount, ok := accountByID[stickyAccountID]; ok { if stickyAccount, ok := accountByID[stickyAccountID]; ok {
if s.isAccountSchedulableForSelection(stickyAccount) && var stickyCacheMissReason string
gatePass := s.isAccountSchedulableForSelection(stickyAccount) &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
s.isAccountSchedulableForQuota(stickyAccount) && s.isAccountSchedulableForQuota(stickyAccount) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true)
rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true)
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 if rpmPass { // 粘性会话窗口费用+RPM 检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
result.ReleaseFunc() // 释放槽位 result.ReleaseFunc() // 释放槽位
stickyCacheMissReason = "session_limit"
// 继续到负载感知选择 // 继续到负载感知选择
} else { } else {
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil)
Account: stickyAccount,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
} }
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) if stickyCacheMissReason == "" {
if waitingCount < cfg.StickySessionMaxWaiting { waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
// 会话数量限制检查(等待计划也需要占用会话配额) if waitingCount < cfg.StickySessionMaxWaiting {
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { // 会话数量限制检查(等待计划也需要占用会话配额)
// 会话限制已满,继续到负载感知选择 if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
stickyCacheMissReason = "session_limit"
// 会话限制已满,继续到负载感知选择
} else {
return &AccountSelectionResult{
Account: stickyAccount,
WaitPlan: &AccountWaitPlan{
AccountID: stickyAccountID,
MaxConcurrency: stickyAccount.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
} else { } else {
return &AccountSelectionResult{ stickyCacheMissReason = "wait_queue_full"
Account: stickyAccount,
WaitPlan: &AccountWaitPlan{
AccountID: stickyAccountID,
MaxConcurrency: stickyAccount.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
} }
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择 // 粘性账号槽位满且等待队列已满,继续使用负载感知选择
} else if !gatePass {
stickyCacheMissReason = "gate_check"
} else {
stickyCacheMissReason = "rpm_red"
}
// 记录粘性缓存未命中的结构化日志
if stickyCacheMissReason != "" {
baseRPM := stickyAccount.GetBaseRPM()
var currentRPM int
if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok {
currentRPM = count
}
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d",
stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM)
} }
} else { } else {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0",
stickyAccountID, shortSessionHash(sessionHash))
} }
} }
} }
...@@ -1544,11 +1565,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1544,11 +1565,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil)
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
} }
...@@ -1561,15 +1578,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1561,15 +1578,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{
Account: item.account, AccountID: item.account.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: item.account.Concurrency,
AccountID: item.account.ID, Timeout: cfg.StickySessionWaitTimeout,
MaxConcurrency: item.account.Concurrency, MaxWaiting: cfg.StickySessionMaxWaiting,
Timeout: cfg.StickySessionWaitTimeout, })
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
// 所有路由账号会话限制都已满,继续到 Layer 2 回退 // 所有路由账号会话限制都已满,继续到 Layer 2 回退
} }
...@@ -1603,11 +1617,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1603,11 +1617,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, account, sessionHash) { if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2 result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else { } else {
return &AccountSelectionResult{ if s.cache != nil {
Account: account, _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
Acquired: true, }
ReleaseFunc: result.ReleaseFunc, return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
}, nil
} }
} }
...@@ -1617,15 +1630,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1617,15 +1630,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, account, sessionHash) { if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2 // 会话限制已满,继续到 Layer 2
} else { } else {
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: accountID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: accountID, Timeout: cfg.StickySessionWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.StickySessionMaxWaiting,
Timeout: cfg.StickySessionWaitTimeout, })
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
} }
} }
...@@ -1684,7 +1694,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1684,7 +1694,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil { if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil {
return nil, legacyErr
} else if ok {
return result, nil return result, nil
} }
} else { } else {
...@@ -1723,11 +1735,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1723,11 +1735,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil)
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
} }
...@@ -1750,20 +1758,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1750,20 +1758,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, acc, sessionHash) { if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
continue // 会话限制已满,尝试下一个账号 continue // 会话限制已满,尝试下一个账号
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{
Account: acc, AccountID: acc.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: acc.Concurrency,
AccountID: acc.ID, Timeout: cfg.FallbackWaitTimeout,
MaxConcurrency: acc.Concurrency, MaxWaiting: cfg.FallbackMaxWaiting,
Timeout: cfg.FallbackWaitTimeout, })
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
} }
return nil, ErrNoAvailableAccounts return nil, ErrNoAvailableAccounts
} }
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) {
ordered := append([]*Account(nil), candidates...) ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
...@@ -1778,15 +1783,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates ...@@ -1778,15 +1783,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
} }
return &AccountSelectionResult{ selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil)
Account: acc, if err != nil {
Acquired: true, return nil, false, err
ReleaseFunc: result.ReleaseFunc, }
}, true return selection, true, nil
} }
} }
return nil, false return nil, false, nil
} }
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
...@@ -2401,6 +2406,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in ...@@ -2401,6 +2406,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
return s.accountRepo.GetByID(ctx, accountID) return s.accountRepo.GetByID(ctx, accountID)
} }
func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) {
if account == nil || s.schedulerSnapshot == nil {
return account, nil
}
hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID)
if err != nil {
return nil, err
}
if hydrated == nil {
return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID)
}
return hydrated, nil
}
func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) {
hydrated, err := s.hydrateSelectedAccount(ctx, account)
if err != nil {
return nil, err
}
return &AccountSelectionResult{
Account: hydrated,
Acquired: acquired,
ReleaseFunc: release,
WaitPlan: waitPlan,
}, nil
}
// filterByMinPriority 过滤出优先级最小的账号集合 // filterByMinPriority 过滤出优先级最小的账号集合
func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 { if len(accounts) == 0 {
...@@ -2676,6 +2708,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -2676,6 +2708,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
preferOAuth := platform == PlatformGemini preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
// require_privacy_set: 获取分组信息
var schedGroup *Group
if groupID != nil && s.groupRepo != nil {
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
}
var accounts []Account var accounts []Account
accountsLoaded := false accountsLoaded := false
...@@ -2747,6 +2785,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -2747,6 +2785,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForSelection(acc) { if !s.isAccountSchedulableForSelection(acc) {
continue continue
} }
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue continue
} }
...@@ -2852,6 +2896,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -2852,6 +2896,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForSelection(acc) { if !s.isAccountSchedulableForSelection(acc) {
continue continue
} }
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue continue
} }
...@@ -2918,6 +2968,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -2918,6 +2968,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
preferOAuth := nativePlatform == PlatformGemini preferOAuth := nativePlatform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform) routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
// require_privacy_set: 获取分组信息
var schedGroup *Group
if groupID != nil && s.groupRepo != nil {
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
}
var accounts []Account var accounts []Account
accountsLoaded := false accountsLoaded := false
...@@ -2985,6 +3041,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -2985,6 +3041,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForSelection(acc) { if !s.isAccountSchedulableForSelection(acc) {
continue continue
} }
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度 // 过滤:原生平台直接通过,antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue continue
...@@ -3078,6 +3140,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -3078,6 +3140,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
ctx = s.withRPMPrefetch(ctx, accounts) ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account var selected *Account
for i := range accounts { for i := range accounts {
...@@ -3090,6 +3153,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -3090,6 +3153,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForSelection(acc) { if !s.isAccountSchedulableForSelection(acc) {
continue continue
} }
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度 // 过滤:原生平台直接通过,antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue continue
...@@ -3257,8 +3326,7 @@ func (s *GatewayService) diagnoseSelectionFailure( ...@@ -3257,8 +3326,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "excluded"} return selectionFailureDiagnosis{Category: "excluded"}
} }
if !s.isAccountSchedulableForSelection(acc) { if !s.isAccountSchedulableForSelection(acc) {
detail := "generic_unschedulable" return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
} }
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
return selectionFailureDiagnosis{ return selectionFailureDiagnosis{
...@@ -3282,7 +3350,6 @@ func (s *GatewayService) diagnoseSelectionFailure( ...@@ -3282,7 +3350,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "eligible"} return selectionFailureDiagnosis{Category: "eligible"}
} }
// GetAccessToken 获取账号凭证
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
if acc == nil { if acc == nil {
return true return true
...@@ -3653,6 +3720,86 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { ...@@ -3653,6 +3720,86 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
return result return result
} }
// rewriteSystemForNonClaudeCode 将非 Claude Code 客户端的 system prompt 迁移至 messages,
// system 字段仅保留 Claude Code 标识提示词。
// Anthropic 基于 system 参数内容检测第三方应用,仅前置追加 Claude Code 提示词
// 无法通过检测,因为后续内容仍为非 Claude Code 格式。
// 策略:将原始 system prompt 提取并注入为 user/assistant 消息对,system 仅保留 Claude Code 标识。
func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
system = normalizeSystemParam(system)
// 1. 提取原始 system prompt 文本
var originalSystemText string
switch v := system.(type) {
case string:
originalSystemText = strings.TrimSpace(v)
case []any:
var parts []string
for _, item := range v {
if m, ok := item.(map[string]any); ok {
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" {
parts = append(parts, text)
}
}
}
originalSystemText = strings.Join(parts, "\n\n")
}
// 2. 将 system 替换为 Claude Code 标准提示词(array 格式,与真实 Claude Code 一致)
// 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。
// 使用 string 格式会被 Anthropic 检测为第三方应用。
claudeCodeSystemBlock := []map[string]any{
{
"type": "text",
"text": claudeCodeSystemPrompt,
"cache_control": map[string]string{"type": "ephemeral"},
},
}
out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock)
if !ok {
logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt")
return body
}
// 3. 将原始 system prompt 作为 user/assistant 消息对注入到 messages 开头
// 模型仍通过 messages 接收完整指令,保留客户端功能
ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt)
if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) {
instrMsg, err1 := json.Marshal(map[string]any{
"role": "user",
"content": []map[string]any{
{"type": "text", "text": "[System Instructions]\n" + originalSystemText},
},
})
ackMsg, err2 := json.Marshal(map[string]any{
"role": "assistant",
"content": []map[string]any{
{"type": "text", "text": "Understood. I will follow these instructions."},
},
})
if err1 != nil || err2 != nil {
logger.LegacyPrintf("service.gateway", "Warning: failed to marshal system-to-messages injection")
return out
}
// 重建 messages 数组:[instruction, ack, ...originalMessages]
items := [][]byte{instrMsg, ackMsg}
messagesResult := gjson.GetBytes(out, "messages")
if messagesResult.IsArray() {
messagesResult.ForEach(func(_, msg gjson.Result) bool {
items = append(items, []byte(msg.Raw))
return true
})
}
if next, setOk := setJSONRawBytes(out, "messages", buildJSONArrayRaw(items)); setOk {
out = next
}
}
return out
}
type cacheControlPath struct { type cacheControlPath struct {
path string path string
log string log string
...@@ -3819,7 +3966,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -3819,7 +3966,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest. // Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
// Always overwrite the cache to prevent stale values from a previous retry with a different account. // Always overwrite the cache to prevent stale values from a previous retry with a different account.
if account.Platform == PlatformAnthropic && c != nil { if account.Platform == PlatformAnthropic && c != nil {
policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account) policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account, parsed.Model)
if policy.blockErr != nil { if policy.blockErr != nil {
return nil, policy.blockErr return nil, policy.blockErr
} }
...@@ -3849,19 +3996,24 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -3849,19 +3996,24 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode { if shouldMimicClaudeCode {
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) // Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
systemRewritten := false
if !strings.Contains(strings.ToLower(reqModel), "haiku") && if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) { !systemIncludesClaudeCodePrompt(parsed.System) {
body = injectClaudeCodePrompt(body, parsed.System) body = rewriteSystemForNonClaudeCode(body, parsed.System)
systemRewritten = true
} }
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} // system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为);
// 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。
// 两种情况下 enforceCacheControlLimit 都会兜底处理上限。
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
if s.identityService != nil { if s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil && fp != nil { if err == nil && fp != nil {
// metadata 透传开启时跳过 metadata 注入 // metadata 透传开启时跳过 metadata 注入
_, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx) _, mimicMPT, _ := s.settingService.GetGatewayForwardingSettings(ctx)
if !mimicMPT { if !mimicMPT {
if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" { if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
normalizeOpts.injectMetadata = true normalizeOpts.injectMetadata = true
...@@ -5407,9 +5559,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -5407,9 +5559,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// OAuth账号:应用统一指纹和metadata重写(受设置开关控制) // OAuth账号:应用统一指纹和metadata重写(受设置开关控制)
var fingerprint *Fingerprint var fingerprint *Fingerprint
enableFP, enableMPT := true, false enableFP, enableMPT, enableCCH := true, false, false
if s.settingService != nil { if s.settingService != nil {
enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx) enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
} }
if account.IsOAuth() && s.identityService != nil { if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹(包含随机生成的ClientID) // 1. 获取或创建指纹(包含随机生成的ClientID)
...@@ -5436,6 +5588,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -5436,6 +5588,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
} }
} }
// 同步 billing header cc_version 与实际发送的 User-Agent 版本
if fingerprint != nil {
body = syncBillingHeaderVersion(body, fingerprint.UserAgent)
}
// CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后)
if enableCCH {
body = signBillingHeaderCCH(body)
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -5476,9 +5637,8 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -5476,9 +5637,8 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
} }
// Build effective drop set: merge static defaults with dynamic beta policy filter rules // Build effective drop set: merge static defaults with dynamic beta policy filter rules
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account) policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID)
effectiveDropSet := mergeDropSets(policyFilterSet) effectiveDropSet := mergeDropSets(policyFilterSet)
effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode)
// 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta) // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta)
if tokenType == "oauth" { if tokenType == "oauth" {
...@@ -5489,11 +5649,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -5489,11 +5649,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
applyClaudeCodeMimicHeaders(req, reqStream) applyClaudeCodeMimicHeaders(req, reqStream)
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta") incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
// Match real Claude CLI traffic (per mitmproxy reports): // Claude Code OAuth credentials are scoped to Claude Code.
// messages requests typically use only oauth + interleaved-thinking. // Non-haiku models MUST include claude-code beta for Anthropic to recognize
// Also drop claude-code beta if a downstream client added it. // this as a legitimate Claude Code request; without it, the request is
// rejected as third-party ("out of extra usage").
// Haiku models are exempt from third-party detection and don't need it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet)) if !strings.Contains(strings.ToLower(modelID), "haiku") {
requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking}
}
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
} else { } else {
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
...@@ -5716,7 +5881,7 @@ type betaPolicyResult struct { ...@@ -5716,7 +5881,7 @@ type betaPolicyResult struct {
} }
// evaluateBetaPolicy loads settings once and evaluates all rules against the given request. // evaluateBetaPolicy loads settings once and evaluates all rules against the given request.
func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult { func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account, model string) betaPolicyResult {
if s.settingService == nil { if s.settingService == nil {
return betaPolicyResult{} return betaPolicyResult{}
} }
...@@ -5731,10 +5896,11 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri ...@@ -5731,10 +5896,11 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue continue
} }
switch rule.Action { effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model)
switch effectiveAction {
case BetaPolicyActionBlock: case BetaPolicyActionBlock:
if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) { if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) {
msg := rule.ErrorMessage msg := effectiveErrMsg
if msg == "" { if msg == "" {
msg = "beta feature " + rule.BetaToken + " is not allowed" msg = "beta feature " + rule.BetaToken + " is not allowed"
} }
...@@ -5776,7 +5942,7 @@ const betaPolicyFilterSetKey = "betaPolicyFilterSet" ...@@ -5776,7 +5942,7 @@ const betaPolicyFilterSetKey = "betaPolicyFilterSet"
// In the /v1/messages path, Forward() evaluates the policy first and caches the result; // In the /v1/messages path, Forward() evaluates the policy first and caches the result;
// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this // buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this
// evaluates on demand (one DB call). // evaluates on demand (one DB call).
func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} { func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account, model string) map[string]struct{} {
if c != nil { if c != nil {
if v, ok := c.Get(betaPolicyFilterSetKey); ok { if v, ok := c.Get(betaPolicyFilterSetKey); ok {
if fs, ok := v.(map[string]struct{}); ok { if fs, ok := v.(map[string]struct{}); ok {
...@@ -5784,7 +5950,7 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont ...@@ -5784,7 +5950,7 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont
} }
} }
} }
return s.evaluateBetaPolicy(ctx, "", account).filterSet return s.evaluateBetaPolicy(ctx, "", account, model).filterSet
} }
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type. // betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
...@@ -5803,6 +5969,33 @@ func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool { ...@@ -5803,6 +5969,33 @@ func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool {
} }
} }
// matchModelWhitelist checks if a model matches any pattern in the whitelist.
// Reuses matchModelPattern from group.go which supports exact and wildcard prefix matching.
func matchModelWhitelist(model string, whitelist []string) bool {
for _, pattern := range whitelist {
if matchModelPattern(pattern, model) {
return true
}
}
return false
}
// resolveRuleAction determines the effective action and error message for a rule given the request model.
// When ModelWhitelist is empty, the rule's primary Action/ErrorMessage applies unconditionally.
// When non-empty, Action applies to matching models; FallbackAction/FallbackErrorMessage applies to others.
func resolveRuleAction(rule BetaPolicyRule, model string) (action, errorMessage string) {
if len(rule.ModelWhitelist) == 0 {
return rule.Action, rule.ErrorMessage
}
if matchModelWhitelist(model, rule.ModelWhitelist) {
return rule.Action, rule.ErrorMessage
}
if rule.FallbackAction != "" {
return rule.FallbackAction, rule.FallbackErrorMessage
}
return BetaPolicyActionPass, "" // default fallback: pass (fail-open)
}
// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. // droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens.
func droppedBetaSet(extra ...string) map[string]struct{} { func droppedBetaSet(extra ...string) map[string]struct{} {
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra))
...@@ -5849,7 +6042,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( ...@@ -5849,7 +6042,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
modelID string, modelID string,
) ([]string, error) { ) ([]string, error) {
// 1. 对原始 header 中的 beta token 做 block 检查(快速失败) // 1. 对原始 header 中的 beta token 做 block 检查(快速失败)
policy := s.evaluateBetaPolicy(ctx, betaHeader, account) policy := s.evaluateBetaPolicy(ctx, betaHeader, account, modelID)
if policy.blockErr != nil { if policy.blockErr != nil {
return nil, policy.blockErr return nil, policy.blockErr
} }
...@@ -5861,7 +6054,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( ...@@ -5861,7 +6054,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
// 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token, // 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token,
// 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 → // 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 →
// 如果不做此检查,block 规则会被绕过。 // 如果不做此检查,block 规则会被绕过。
if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil { if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account, modelID); blockErr != nil {
return nil, blockErr return nil, blockErr
} }
...@@ -5870,7 +6063,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( ...@@ -5870,7 +6063,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
// checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。 // checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。
// 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。 // 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。
func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError { func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account, model string) *BetaBlockedError {
if s.settingService == nil || len(tokens) == 0 { if s.settingService == nil || len(tokens) == 0 {
return nil return nil
} }
...@@ -5882,14 +6075,15 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke ...@@ -5882,14 +6075,15 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke
isBedrock := account.IsBedrock() isBedrock := account.IsBedrock()
tokenSet := buildBetaTokenSet(tokens) tokenSet := buildBetaTokenSet(tokens)
for _, rule := range settings.Rules { for _, rule := range settings.Rules {
if rule.Action != BetaPolicyActionBlock { effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model)
if effectiveAction != BetaPolicyActionBlock {
continue continue
} }
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue continue
} }
if _, present := tokenSet[rule.BetaToken]; present { if _, present := tokenSet[rule.BetaToken]; present {
msg := rule.ErrorMessage msg := effectiveErrMsg
if msg == "" { if msg == "" {
msg = "beta feature " + rule.BetaToken + " is not allowed" msg = "beta feature " + rule.BetaToken + " is not allowed"
} }
...@@ -7146,49 +7340,41 @@ func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool { ...@@ -7146,49 +7340,41 @@ func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
} }
// postUsageBilling 统一处理使用量记录后的扣费逻辑: // postUsageBilling is the legacy fallback billing path used when the unified
// - 订阅/余额扣费 // billing repo is unavailable (nil). Production uses applyUsageBilling → repo.Apply
// - API Key 配额更新 // for atomic billing. This path only runs in tests or degraded mode.
// - API Key 限速用量更新
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
billingCtx, cancel := detachedBillingContext(ctx) billingCtx, cancel := detachedBillingContext(ctx)
defer cancel() defer cancel()
cost := p.Cost cost := p.Cost
// 1. 订阅 / 余额扣费
if p.IsSubscriptionBill { if p.IsSubscriptionBill {
if cost.TotalCost > 0 { if cost.TotalCost > 0 {
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
} }
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
} }
} else { } else {
if cost.ActualCost > 0 { if cost.ActualCost > 0 {
if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil { if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil {
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
} }
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
} }
} }
// 2. API Key 配额
if p.shouldDeductAPIKeyQuota() { if p.shouldDeductAPIKeyQuota() {
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
} }
} }
// 3. API Key 限速用量
if p.shouldUpdateRateLimits() { if p.shouldUpdateRateLimits() {
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
} }
} }
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
if p.shouldUpdateAccountQuota() { if p.shouldUpdateAccountQuota() {
accountCost := cost.TotalCost * p.AccountRateMultiplier accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
...@@ -7196,7 +7382,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill ...@@ -7196,7 +7382,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
} }
} }
finalizePostUsageBilling(p, deps) // NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing
// cache updates. The legacy path does DB writes directly; the finalize path
// does cache queue + notifications. Notifications are dispatched separately
// by the caller after recording the usage log.
} }
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string { func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
...@@ -7250,9 +7439,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage ...@@ -7250,9 +7439,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
cmd.CacheCreationTokens = usageLog.CacheCreationTokens cmd.CacheCreationTokens = usageLog.CacheCreationTokens
cmd.CacheReadTokens = usageLog.CacheReadTokens cmd.CacheReadTokens = usageLog.CacheReadTokens
cmd.ImageCount = usageLog.ImageCount cmd.ImageCount = usageLog.ImageCount
if usageLog.MediaType != nil {
cmd.MediaType = *usageLog.MediaType
}
if usageLog.ServiceTier != nil { if usageLog.ServiceTier != nil {
cmd.ServiceTier = *usageLog.ServiceTier cmd.ServiceTier = *usageLog.ServiceTier
} }
...@@ -7315,11 +7501,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog ...@@ -7315,11 +7501,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog
} }
} }
finalizePostUsageBilling(p, deps) finalizePostUsageBilling(p, deps, result)
return true, nil return true, nil
} }
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) { func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
if p == nil || p.Cost == nil || deps == nil { if p == nil || p.Cost == nil || deps == nil {
return return
} }
...@@ -7338,22 +7524,82 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) { ...@@ -7338,22 +7524,82 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
// Balance low notification — use real-time balance from billing cache (not stale snapshot) // Notification checks run async — all parameters are already captured,
if !p.IsSubscriptionBill && p.Cost.ActualCost > 0 && p.User != nil && deps.balanceNotifyService != nil { // no dependency on the request context or upstream connection.
oldBalance := p.User.Balance // fallback to snapshot go notifyBalanceLow(p, deps, result)
if deps.billingCacheService != nil { go notifyAccountQuota(p, deps, result)
if realBalance, err := deps.billingCacheService.GetUserBalance(context.Background(), p.User.ID); err == nil { }
oldBalance = realBalance + p.Cost.ActualCost // DB already deducted, reconstruct pre-deduction balance
} // notifyBalanceLow sends balance low notification after deduction.
// When result.NewBalance is available (from DB transaction RETURNING), it is used directly
// to reconstruct oldBalance, avoiding stale Redis reads and concurrent-deduction races.
func notifyBalanceLow(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in notifyBalanceLow", "recover", r)
} }
deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost) }()
if p.IsSubscriptionBill || p.Cost.ActualCost <= 0 || p.User == nil || deps.balanceNotifyService == nil {
slog.Debug("notifyBalanceLow: skipped",
"is_subscription", p.IsSubscriptionBill,
"actual_cost", p.Cost.ActualCost,
"user_nil", p.User == nil,
"service_nil", deps.balanceNotifyService == nil,
)
return
} }
// Account quota notification (use same cost formula as postUsageBilling) oldBalance := resolveOldBalance(p, result)
if p.Cost.TotalCost > 0 && p.Account != nil && p.Account.IsAPIKeyOrBedrock() && deps.balanceNotifyService != nil { slog.Debug("notifyBalanceLow: calling CheckBalanceAfterDeduction",
accountCost := p.Cost.TotalCost * p.AccountRateMultiplier "user_id", p.User.ID,
deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost) "old_balance", oldBalance,
"cost", p.Cost.ActualCost,
"notify_enabled", p.User.BalanceNotifyEnabled,
"threshold", p.User.BalanceNotifyThreshold,
"result_has_new_balance", result != nil && result.NewBalance != nil,
)
deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost)
}
// resolveOldBalance returns the pre-deduction balance.
// Prefers the DB transaction result (newBalance + cost) over snapshot.
func resolveOldBalance(p *postUsageBillingParams, result *UsageBillingApplyResult) float64 {
if result != nil && result.NewBalance != nil {
return *result.NewBalance + p.Cost.ActualCost
} }
// Legacy fallback: snapshot balance from request context
return p.User.Balance
}
// notifyAccountQuota sends account quota threshold notification after increment.
// When result.QuotaState is available (from DB transaction RETURNING), it is passed directly
// to avoid a separate DB read that may see stale or concurrently-modified data.
func notifyAccountQuota(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in notifyAccountQuota", "recover", r)
}
}()
if p.Cost.TotalCost <= 0 || p.Account == nil || !p.Account.IsAPIKeyOrBedrock() || deps.balanceNotifyService == nil {
slog.Debug("notifyAccountQuota: skipped",
"total_cost", p.Cost.TotalCost,
"account_nil", p.Account == nil,
"is_apikey_or_bedrock", p.Account != nil && p.Account.IsAPIKeyOrBedrock(),
"service_nil", deps.balanceNotifyService == nil,
)
return
}
accountCost := p.Cost.TotalCost * p.AccountRateMultiplier
var quotaState *AccountQuotaState
if result != nil {
quotaState = result.QuotaState
}
slog.Debug("notifyAccountQuota: calling CheckAccountQuotaAfterIncrement",
"account_id", p.Account.ID,
"account_cost", accountCost,
"has_quota_state", quotaState != nil,
)
deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost, quotaState)
} }
func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) { func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
...@@ -7422,11 +7668,11 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage ...@@ -7422,11 +7668,11 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。 // recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
type recordUsageOpts struct { type recordUsageOpts struct {
// ParsedRequest(可选,仅 Claude 路径传入) // Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入)
ParsedRequest *ParsedRequest ParsedRequest *ParsedRequest
// EnableClaudePath 启用 Claude 路径特有逻辑: // EnableClaudePath 启用 Claude 路径特有逻辑:
// - MediaType 字段写入使用日志 // - Claude Max 缓存计费策略
EnableClaudePath bool EnableClaudePath bool
// 长上下文计费(仅 Gemini 路径需要) // 长上下文计费(仅 Gemini 路径需要)
...@@ -7451,7 +7697,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7451,7 +7697,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
APIKeyService: input.APIKeyService, APIKeyService: input.APIKeyService,
ChannelUsageFields: input.ChannelUsageFields, ChannelUsageFields: input.ChannelUsageFields,
}, &recordUsageOpts{ }, &recordUsageOpts{
ParsedRequest: input.ParsedRequest,
EnableClaudePath: true, EnableClaudePath: true,
}) })
} }
...@@ -7517,6 +7762,7 @@ type recordUsageCoreInput struct { ...@@ -7517,6 +7762,7 @@ type recordUsageCoreInput struct {
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。 // recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
// opts 中的字段控制两者之间的差异行为: // opts 中的字段控制两者之间的差异行为:
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext // - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error { func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
result := input.Result result := input.Result
...@@ -7583,13 +7829,10 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage ...@@ -7583,13 +7829,10 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则) // 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if apiKey.GroupID != nil { if apiKey.GroupID != nil {
upstreamModel := result.UpstreamModel applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService,
if upstreamModel == "" { account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model,
upstreamModel = result.Model // Anthropic's input_tokens excludes cache_read and cache_creation (billed separately);
} // OpenAI gateway uses actualInputTokens which also excludes cache_read for the same reason.
usageLog.AccountStatsCost = resolveAccountStatsCost(
ctx, s.channelService, s.billingService,
account.ID, *apiKey.GroupID, upstreamModel,
UsageTokens{ UsageTokens{
InputTokens: result.Usage.InputTokens, InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
...@@ -7597,7 +7840,6 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage ...@@ -7597,7 +7840,6 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens, ImageOutputTokens: result.Usage.ImageOutputTokens,
}, },
1, // requestCount
cost.TotalCost, cost.TotalCost,
) )
} }
...@@ -7796,13 +8038,12 @@ func (s *GatewayService) buildRecordUsageLog( ...@@ -7796,13 +8038,12 @@ func (s *GatewayService) buildRecordUsageLog(
RateMultiplier: multiplier, RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier, AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType, BillingType: billingType,
BillingMode: resolveBillingMode(opts, result, cost), BillingMode: resolveBillingMode(result, cost),
Stream: result.Stream, Stream: result.Stream,
DurationMs: &durationMs, DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs, FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount, ImageCount: result.ImageCount,
ImageSize: optionalTrimmedStringPtr(result.ImageSize), ImageSize: optionalTrimmedStringPtr(result.ImageSize),
MediaType: resolveMediaType(opts, result),
CacheTTLOverridden: cacheTTLOverridden, CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID), ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
...@@ -7826,7 +8067,7 @@ func (s *GatewayService) buildRecordUsageLog( ...@@ -7826,7 +8067,7 @@ func (s *GatewayService) buildRecordUsageLog(
} }
// resolveBillingMode 根据计费结果和请求类型确定计费模式。 // resolveBillingMode 根据计费结果和请求类型确定计费模式。
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string { func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
var mode string var mode string
switch { switch {
case cost != nil && cost.BillingMode != "": case cost != nil && cost.BillingMode != "":
...@@ -7839,10 +8080,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost ...@@ -7839,10 +8080,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
return &mode return &mode
} }
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
return nil
}
func optionalSubscriptionID(subscription *UserSubscription) *int64 { func optionalSubscriptionID(subscription *UserSubscription) *int64 {
if subscription != nil { if subscription != nil {
return &subscription.ID return &subscription.ID
...@@ -8349,9 +8586,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -8349,9 +8586,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用统一指纹和重写 userID(受设置开关控制) // OAuth 账号:应用统一指纹和重写 userID(受设置开关控制)
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
ctEnableFP, ctEnableMPT := true, false ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false
if s.settingService != nil { if s.settingService != nil {
ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx) ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
} }
var ctFingerprint *Fingerprint var ctFingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil { if account.IsOAuth() && s.identityService != nil {
...@@ -8369,6 +8606,14 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -8369,6 +8606,14 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
} }
} }
// 同步 billing header cc_version 与实际发送的 User-Agent 版本
if ctFingerprint != nil && ctEnableFP {
body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent)
}
if ctEnableCCH {
body = signBillingHeaderCCH(body)
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -8409,7 +8654,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -8409,7 +8654,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
} }
// Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules // Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account)) ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID))
// OAuth 账号:处理 anthropic-beta header // OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" { if tokenType == "oauth" {
......
...@@ -128,3 +128,66 @@ func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) { ...@@ -128,3 +128,66 @@ func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: nil}) err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: nil})
require.NoError(t, err) require.NoError(t, err)
} }
// --- validatePlanPatch: other fields ---
func ptrStr(s string) *string { return &s }
func ptrInt(i int) *int { return &i }
func ptrInt64(i int64) *int64 { return &i }
func ptrFloat(f float64) *float64 { return &f }
func TestValidatePlanPatch_EmptyName(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("")})
require.Error(t, err)
require.Contains(t, err.Error(), "plan name")
}
func TestValidatePlanPatch_ValidName(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("Basic")})
require.NoError(t, err)
}
func TestValidatePlanPatch_ZeroGroupID(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{GroupID: ptrInt64(0)})
require.Error(t, err)
require.Contains(t, err.Error(), "group")
}
func TestValidatePlanPatch_NegativePrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(-1)})
require.Error(t, err)
require.Contains(t, err.Error(), "price")
}
func TestValidatePlanPatch_ZeroPrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(0)})
require.Error(t, err)
require.Contains(t, err.Error(), "price")
}
func TestValidatePlanPatch_ValidPrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(9.99)})
require.NoError(t, err)
}
func TestValidatePlanPatch_ZeroValidityDays(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{ValidityDays: ptrInt(0)})
require.Error(t, err)
require.Contains(t, err.Error(), "validity days")
}
func TestValidatePlanPatch_EmptyValidityUnit(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("")})
require.Error(t, err)
require.Contains(t, err.Error(), "validity unit")
}
func TestValidatePlanPatch_ValidValidityUnit(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("days")})
require.NoError(t, err)
}
func TestValidatePlanPatch_AllNil(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{})
require.NoError(t, err)
}
...@@ -330,6 +330,7 @@ func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code str ...@@ -330,6 +330,7 @@ func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code str
Code: code, Code: code,
Attempts: 0, Attempts: 0,
CreatedAt: time.Now(), CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(verifyCodeTTL),
} }
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil { if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err) return fmt.Errorf("save verify code: %w", err)
...@@ -370,7 +371,11 @@ func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string) ...@@ -370,7 +371,11 @@ func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string)
} }
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 { if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++ data.Attempts++
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil { remaining := time.Until(data.ExpiresAt)
if remaining <= 0 {
return ErrInvalidVerifyCode
}
if err := cache.SetNotifyVerifyCode(ctx, email, data, remaining); err != nil {
slog.Error("failed to update notify verify code attempts", "email", email, "error", err) slog.Error("failed to update notify verify code attempts", "email", email, "error", err)
} }
if data.Attempts >= maxVerifyCodeAttempts { if data.Attempts >= maxVerifyCodeAttempts {
...@@ -418,11 +423,17 @@ func (s *UserService) RemoveNotifyEmail(ctx context.Context, userID int64, email ...@@ -418,11 +423,17 @@ func (s *UserService) RemoveNotifyEmail(ctx context.Context, userID int64, email
} }
filtered := make([]NotifyEmailEntry, 0, len(user.BalanceNotifyExtraEmails)) filtered := make([]NotifyEmailEntry, 0, len(user.BalanceNotifyExtraEmails))
found := false
for _, e := range user.BalanceNotifyExtraEmails { for _, e := range user.BalanceNotifyExtraEmails {
if !strings.EqualFold(e.Email, email) { if strings.EqualFold(e.Email, email) {
found = true
} else {
filtered = append(filtered, e) filtered = append(filtered, e)
} }
} }
if !found {
return infraerrors.BadRequest("EMAIL_NOT_FOUND", "notification email not found")
}
user.BalanceNotifyExtraEmails = filtered user.BalanceNotifyExtraEmails = filtered
return s.userRepo.Update(ctx, user) return s.userRepo.Update(ctx, user)
} }
......
<script setup lang="ts"> <script setup lang="ts">
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import QuotaNotifyToggle from './QuotaNotifyToggle.vue' import QuotaNotifyToggle from './QuotaNotifyToggle.vue'
import type { QuotaThresholdType, QuotaResetMode } from '@/constants/account'
const { t } = useI18n() const { t } = useI18n()
...@@ -11,9 +12,9 @@ const props = defineProps<{ ...@@ -11,9 +12,9 @@ const props = defineProps<{
quotaNotifyGlobalEnabled: boolean quotaNotifyGlobalEnabled: boolean
notifyEnabled: boolean | null notifyEnabled: boolean | null
notifyThreshold: number | null notifyThreshold: number | null
notifyThresholdType: string | null notifyThresholdType: QuotaThresholdType | null
// Reset mode (only for daily/weekly, null for total) // Reset mode (only for daily/weekly, null for total)
resetMode: 'rolling' | 'fixed' | null resetMode: QuotaResetMode | null
resetHour: number | null resetHour: number | null
resetDay: number | null // weekly only resetDay: number | null // weekly only
resetTimezone: string | null resetTimezone: string | null
...@@ -22,14 +23,15 @@ const props = defineProps<{ ...@@ -22,14 +23,15 @@ const props = defineProps<{
// Shared options passed from parent // Shared options passed from parent
hourOptions: number[] hourOptions: number[]
dayOptions: { value: number; key: string }[] dayOptions: { value: number; key: string }[]
timezoneOptions?: string[]
}>() }>()
const emit = defineEmits<{ const emit = defineEmits<{
'update:limit': [value: number | null] 'update:limit': [value: number | null]
'update:notifyEnabled': [value: boolean | null] 'update:notifyEnabled': [value: boolean | null]
'update:notifyThreshold': [value: number | null] 'update:notifyThreshold': [value: number | null]
'update:notifyThresholdType': [value: string | null] 'update:notifyThresholdType': [value: QuotaThresholdType | null]
'update:resetMode': [value: 'rolling' | 'fixed' | null] 'update:resetMode': [value: QuotaResetMode | null]
'update:resetHour': [value: number | null] 'update:resetHour': [value: number | null]
'update:resetDay': [value: number | null] 'update:resetDay': [value: number | null]
'update:resetTimezone': [value: string | null] 'update:resetTimezone': [value: string | null]
...@@ -43,7 +45,7 @@ const onLimitInput = (e: Event) => { ...@@ -43,7 +45,7 @@ const onLimitInput = (e: Event) => {
} }
const onModeChange = (e: Event) => { const onModeChange = (e: Event) => {
const val = (e.target as HTMLSelectElement).value as 'rolling' | 'fixed' const val = (e.target as HTMLSelectElement).value as QuotaResetMode
emit('update:resetMode', val) emit('update:resetMode', val)
if (val === 'fixed') { if (val === 'fixed') {
if (props.resetHour == null) emit('update:resetHour', 0) if (props.resetHour == null) emit('update:resetHour', 0)
...@@ -51,6 +53,17 @@ const onModeChange = (e: Event) => { ...@@ -51,6 +53,17 @@ const onModeChange = (e: Event) => {
if (!props.resetTimezone) emit('update:resetTimezone', 'UTC') if (!props.resetTimezone) emit('update:resetTimezone', 'UTC')
} }
} }
function getTimezoneOffsetLabel(tz: string): string {
try {
const dtf = new Intl.DateTimeFormat('en-US', { timeZone: tz, timeZoneName: 'shortOffset' })
const parts = dtf.formatToParts(new Date())
const tzPart = parts.find(p => p.type === 'timeZoneName')
return tzPart ? (tzPart.value === 'GMT' ? 'GMT+0' : tzPart.value) : ''
} catch {
return ''
}
}
</script> </script>
<template> <template>
...@@ -95,6 +108,11 @@ const onModeChange = (e: Event) => { ...@@ -95,6 +108,11 @@ const onModeChange = (e: Event) => {
<select :value="resetHour ?? 0" @change="emit('update:resetHour', Number(($event.target as HTMLSelectElement).value))" class="input py-1 text-xs w-24"> <select :value="resetHour ?? 0" @change="emit('update:resetHour', Number(($event.target as HTMLSelectElement).value))" class="input py-1 text-xs w-24">
<option v-for="h in hourOptions" :key="h" :value="h">{{ String(h).padStart(2, '0') }}:00</option> <option v-for="h in hourOptions" :key="h" :value="h">{{ String(h).padStart(2, '0') }}:00</option>
</select> </select>
<template v-if="timezoneOptions && timezoneOptions.length > 0">
<select :value="resetTimezone || 'UTC'" @change="emit('update:resetTimezone', ($event.target as HTMLSelectElement).value)" class="input py-1 text-xs w-auto">
<option v-for="tz in timezoneOptions" :key="tz" :value="tz">{{ tz }} ({{ getTimezoneOffsetLabel(tz) }})</option>
</select>
</template>
</template> </template>
<span class="text-[11px] text-gray-500 dark:text-gray-400"> <span class="text-[11px] text-gray-500 dark:text-gray-400">
<template v-if="resetMode === 'fixed'">{{ hintFixed }}</template> <template v-if="resetMode === 'fixed'">{{ hintFixed }}</template>
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import { ref, watch, computed } from 'vue' import { ref, watch, computed } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import QuotaDimensionRow from './QuotaDimensionRow.vue' import QuotaDimensionRow from './QuotaDimensionRow.vue'
import type { QuotaThresholdType, QuotaResetMode } from '@/constants/account'
const { t } = useI18n() const { t } = useI18n()
...@@ -9,22 +10,22 @@ const props = withDefaults(defineProps<{ ...@@ -9,22 +10,22 @@ const props = withDefaults(defineProps<{
totalLimit: number | null totalLimit: number | null
dailyLimit: number | null dailyLimit: number | null
weeklyLimit: number | null weeklyLimit: number | null
dailyResetMode: 'rolling' | 'fixed' | null dailyResetMode: QuotaResetMode | null
dailyResetHour: number | null dailyResetHour: number | null
weeklyResetMode: 'rolling' | 'fixed' | null weeklyResetMode: QuotaResetMode | null
weeklyResetDay: number | null weeklyResetDay: number | null
weeklyResetHour: number | null weeklyResetHour: number | null
resetTimezone: string | null resetTimezone: string | null
quotaNotifyGlobalEnabled?: boolean quotaNotifyGlobalEnabled?: boolean
quotaNotifyDailyEnabled?: boolean | null quotaNotifyDailyEnabled?: boolean | null
quotaNotifyDailyThreshold?: number | null quotaNotifyDailyThreshold?: number | null
quotaNotifyDailyThresholdType?: string | null quotaNotifyDailyThresholdType?: QuotaThresholdType | null
quotaNotifyWeeklyEnabled?: boolean | null quotaNotifyWeeklyEnabled?: boolean | null
quotaNotifyWeeklyThreshold?: number | null quotaNotifyWeeklyThreshold?: number | null
quotaNotifyWeeklyThresholdType?: string | null quotaNotifyWeeklyThresholdType?: QuotaThresholdType | null
quotaNotifyTotalEnabled?: boolean | null quotaNotifyTotalEnabled?: boolean | null
quotaNotifyTotalThreshold?: number | null quotaNotifyTotalThreshold?: number | null
quotaNotifyTotalThresholdType?: string | null quotaNotifyTotalThresholdType?: QuotaThresholdType | null
}>(), { }>(), {
quotaNotifyGlobalEnabled: false, quotaNotifyGlobalEnabled: false,
quotaNotifyDailyEnabled: null, quotaNotifyDailyEnabled: null,
...@@ -42,21 +43,21 @@ const emit = defineEmits<{ ...@@ -42,21 +43,21 @@ const emit = defineEmits<{
'update:totalLimit': [value: number | null] 'update:totalLimit': [value: number | null]
'update:dailyLimit': [value: number | null] 'update:dailyLimit': [value: number | null]
'update:weeklyLimit': [value: number | null] 'update:weeklyLimit': [value: number | null]
'update:dailyResetMode': [value: 'rolling' | 'fixed' | null] 'update:dailyResetMode': [value: QuotaResetMode | null]
'update:dailyResetHour': [value: number | null] 'update:dailyResetHour': [value: number | null]
'update:weeklyResetMode': [value: 'rolling' | 'fixed' | null] 'update:weeklyResetMode': [value: QuotaResetMode | null]
'update:weeklyResetDay': [value: number | null] 'update:weeklyResetDay': [value: number | null]
'update:weeklyResetHour': [value: number | null] 'update:weeklyResetHour': [value: number | null]
'update:resetTimezone': [value: string | null] 'update:resetTimezone': [value: string | null]
'update:quotaNotifyDailyEnabled': [value: boolean | null] 'update:quotaNotifyDailyEnabled': [value: boolean | null]
'update:quotaNotifyDailyThreshold': [value: number | null] 'update:quotaNotifyDailyThreshold': [value: number | null]
'update:quotaNotifyDailyThresholdType': [value: string | null] 'update:quotaNotifyDailyThresholdType': [value: QuotaThresholdType | null]
'update:quotaNotifyWeeklyEnabled': [value: boolean | null] 'update:quotaNotifyWeeklyEnabled': [value: boolean | null]
'update:quotaNotifyWeeklyThreshold': [value: number | null] 'update:quotaNotifyWeeklyThreshold': [value: number | null]
'update:quotaNotifyWeeklyThresholdType': [value: string | null] 'update:quotaNotifyWeeklyThresholdType': [value: QuotaThresholdType | null]
'update:quotaNotifyTotalEnabled': [value: boolean | null] 'update:quotaNotifyTotalEnabled': [value: boolean | null]
'update:quotaNotifyTotalThreshold': [value: number | null] 'update:quotaNotifyTotalThreshold': [value: number | null]
'update:quotaNotifyTotalThresholdType': [value: string | null] 'update:quotaNotifyTotalThresholdType': [value: QuotaThresholdType | null]
}>() }>()
const enabled = computed(() => const enabled = computed(() =>
...@@ -89,11 +90,6 @@ watch(localEnabled, (val) => { ...@@ -89,11 +90,6 @@ watch(localEnabled, (val) => {
} }
}) })
// Whether any fixed mode is active (to show timezone selector)
const hasFixedMode = computed(() =>
props.dailyResetMode === 'fixed' || props.weeklyResetMode === 'fixed'
)
// Common timezone options // Common timezone options
const timezoneOptions = [ const timezoneOptions = [
'UTC', 'Asia/Shanghai', 'Asia/Tokyo', 'Asia/Seoul', 'Asia/Singapore', 'Asia/Kolkata', 'UTC', 'Asia/Shanghai', 'Asia/Tokyo', 'Asia/Seoul', 'Asia/Singapore', 'Asia/Kolkata',
...@@ -102,18 +98,6 @@ const timezoneOptions = [ ...@@ -102,18 +98,6 @@ const timezoneOptions = [
'America/Sao_Paulo', 'Australia/Sydney', 'Pacific/Auckland', 'America/Sao_Paulo', 'Australia/Sydney', 'Pacific/Auckland',
] ]
// Compute GMT offset label (e.g. "GMT+8", "GMT-5") for a given IANA timezone.
function getTimezoneOffsetLabel(tz: string): string {
try {
const dtf = new Intl.DateTimeFormat('en-US', { timeZone: tz, timeZoneName: 'shortOffset' })
const parts = dtf.formatToParts(new Date())
const tzPart = parts.find(p => p.type === 'timeZoneName')
return tzPart ? (tzPart.value === 'GMT' ? 'GMT+0' : tzPart.value) : ''
} catch {
return ''
}
}
// Hours for dropdown (0-23) // Hours for dropdown (0-23)
const hourOptions = Array.from({ length: 24 }, (_, i) => i) const hourOptions = Array.from({ length: 24 }, (_, i) => i)
...@@ -197,6 +181,7 @@ const dailyFixedHint = computed(() => ...@@ -197,6 +181,7 @@ const dailyFixedHint = computed(() =>
:hint-fixed="dailyFixedHint" :hint-fixed="dailyFixedHint"
:hour-options="hourOptions" :hour-options="hourOptions"
:day-options="dayOptions" :day-options="dayOptions"
:timezone-options="timezoneOptions"
@update:limit="emit('update:dailyLimit', $event)" @update:limit="emit('update:dailyLimit', $event)"
@update:notify-enabled="emit('update:quotaNotifyDailyEnabled', $event)" @update:notify-enabled="emit('update:quotaNotifyDailyEnabled', $event)"
@update:notify-threshold="emit('update:quotaNotifyDailyThreshold', $event)" @update:notify-threshold="emit('update:quotaNotifyDailyThreshold', $event)"
...@@ -223,6 +208,7 @@ const dailyFixedHint = computed(() => ...@@ -223,6 +208,7 @@ const dailyFixedHint = computed(() =>
:hint-fixed="weeklyFixedHint" :hint-fixed="weeklyFixedHint"
:hour-options="hourOptions" :hour-options="hourOptions"
:day-options="dayOptions" :day-options="dayOptions"
:timezone-options="timezoneOptions"
@update:limit="emit('update:weeklyLimit', $event)" @update:limit="emit('update:weeklyLimit', $event)"
@update:notify-enabled="emit('update:quotaNotifyWeeklyEnabled', $event)" @update:notify-enabled="emit('update:quotaNotifyWeeklyEnabled', $event)"
@update:notify-threshold="emit('update:quotaNotifyWeeklyThreshold', $event)" @update:notify-threshold="emit('update:quotaNotifyWeeklyThreshold', $event)"
...@@ -233,14 +219,6 @@ const dailyFixedHint = computed(() => ...@@ -233,14 +219,6 @@ const dailyFixedHint = computed(() =>
@update:reset-timezone="emit('update:resetTimezone', $event)" @update:reset-timezone="emit('update:resetTimezone', $event)"
/> />
<!-- Timezone selector (shared by daily/weekly when fixed mode is active) -->
<div v-if="hasFixedMode">
<label class="input-label">{{ t('admin.accounts.quotaResetTimezone') }}</label>
<select :value="resetTimezone || 'UTC'" @change="emit('update:resetTimezone', ($event.target as HTMLSelectElement).value)" class="input text-sm">
<option v-for="tz in timezoneOptions" :key="tz" :value="tz">{{ tz }} ({{ getTimezoneOffsetLabel(tz) }})</option>
</select>
</div>
<!-- Total quota --> <!-- Total quota -->
<QuotaDimensionRow <QuotaDimensionRow
dim="total" dim="total"
......
<script setup lang="ts"> <script setup lang="ts">
import { QUOTA_THRESHOLD_TYPE_FIXED, QUOTA_THRESHOLD_TYPE_PERCENTAGE } from '@/constants/account' import { QUOTA_THRESHOLD_TYPE_FIXED, QUOTA_THRESHOLD_TYPE_PERCENTAGE, type QuotaThresholdType } from '@/constants/account'
defineProps<{ defineProps<{
enabled: boolean | null enabled: boolean | null
threshold: number | null threshold: number | null
thresholdType: string | null // "fixed" (default) or "percentage" thresholdType: QuotaThresholdType | null
}>() }>()
const emit = defineEmits<{ const emit = defineEmits<{
'update:enabled': [value: boolean | null] 'update:enabled': [value: boolean | null]
'update:threshold': [value: number | null] 'update:threshold': [value: number | null]
'update:thresholdType': [value: string | null] 'update:thresholdType': [value: QuotaThresholdType | null]
}>() }>()
</script> </script>
...@@ -43,7 +43,7 @@ const emit = defineEmits<{ ...@@ -43,7 +43,7 @@ const emit = defineEmits<{
/> />
<select <select
:value="thresholdType || QUOTA_THRESHOLD_TYPE_FIXED" :value="thresholdType || QUOTA_THRESHOLD_TYPE_FIXED"
@change="emit('update:thresholdType', ($event.target as HTMLSelectElement).value)" @change="emit('update:thresholdType', ($event.target as HTMLSelectElement).value as QuotaThresholdType)"
class="input py-1 text-xs w-[4.5rem] flex-shrink-0 text-center" class="input py-1 text-xs w-[4.5rem] flex-shrink-0 text-center"
> >
<option :value="QUOTA_THRESHOLD_TYPE_FIXED">$</option> <option :value="QUOTA_THRESHOLD_TYPE_FIXED">$</option>
......
...@@ -313,10 +313,6 @@ ...@@ -313,10 +313,6 @@
<span class="text-gray-400">{{ t('usage.rate') }}</span> <span class="text-gray-400">{{ t('usage.rate') }}</span>
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.rate_multiplier || 1) }}x</span> <span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.rate_multiplier || 1) }}x</span>
</div> </div>
<div class="flex items-center justify-between gap-6">
<span class="text-gray-400">{{ t('usage.accountMultiplier') }}</span>
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.account_rate_multiplier ?? 1) }}x</span>
</div>
<div class="flex items-center justify-between gap-6"> <div class="flex items-center justify-between gap-6">
<span class="text-gray-400">{{ t('usage.original') }}</span> <span class="text-gray-400">{{ t('usage.original') }}</span>
<span class="font-medium text-white">${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}</span> <span class="font-medium text-white">${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}</span>
...@@ -325,7 +321,12 @@ ...@@ -325,7 +321,12 @@
<span class="text-gray-400">{{ t('usage.userBilled') }}</span> <span class="text-gray-400">{{ t('usage.userBilled') }}</span>
<span class="font-semibold text-green-400">${{ tooltipData?.actual_cost?.toFixed(6) || '0.000000' }}</span> <span class="font-semibold text-green-400">${{ tooltipData?.actual_cost?.toFixed(6) || '0.000000' }}</span>
</div> </div>
<!-- Account billing (separated from user billing) -->
<div class="flex items-center justify-between gap-6 border-t border-gray-700 pt-1.5"> <div class="flex items-center justify-between gap-6 border-t border-gray-700 pt-1.5">
<span class="text-gray-400">{{ t('usage.accountMultiplier') }}</span>
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.account_rate_multiplier ?? 1) }}x</span>
</div>
<div class="flex items-center justify-between gap-6">
<span class="text-gray-400">{{ t('usage.accountBilled') }}</span> <span class="text-gray-400">{{ t('usage.accountBilled') }}</span>
<span class="font-semibold text-green-400"> <span class="font-semibold text-green-400">
${{ accountBilled({ ${{ accountBilled({
...@@ -355,7 +356,8 @@ import { getBillingModeLabel, getBillingModeBadgeClass, BILLING_MODE_TOKEN, BILL ...@@ -355,7 +356,8 @@ import { getBillingModeLabel, getBillingModeBadgeClass, BILLING_MODE_TOKEN, BILL
/** Compute the account-billed cost for display: (account_stats_cost ?? total_cost) * rate_multiplier */ /** Compute the account-billed cost for display: (account_stats_cost ?? total_cost) * rate_multiplier */
function accountBilled(row: { total_cost?: number | null; account_stats_cost?: number | null; account_rate_multiplier?: number | null }): number { function accountBilled(row: { total_cost?: number | null; account_stats_cost?: number | null; account_rate_multiplier?: number | null }): number {
const base = row.account_stats_cost != null ? row.account_stats_cost : (row.total_cost ?? 0) const base = row.account_stats_cost != null ? row.account_stats_cost : (row.total_cost ?? 0)
return base * (row.account_rate_multiplier ?? 1) const result = base * (row.account_rate_multiplier ?? 1)
return Number.isNaN(result) ? 0 : result
} }
import DataTable from '@/components/common/DataTable.vue' import DataTable from '@/components/common/DataTable.vue'
......
import { reactive, ref } from 'vue' import { reactive, ref } from 'vue'
import { adminAPI } from '@/api/admin' import { adminAPI } from '@/api/admin'
import { QUOTA_THRESHOLD_TYPE_FIXED } from '@/constants/account' import { QUOTA_THRESHOLD_TYPE_FIXED, type QuotaThresholdType } from '@/constants/account'
export const QUOTA_NOTIFY_DIMS = ['daily', 'weekly', 'total'] as const export const QUOTA_NOTIFY_DIMS = ['daily', 'weekly', 'total'] as const
export type QuotaNotifyDim = (typeof QUOTA_NOTIFY_DIMS)[number] export type QuotaNotifyDim = (typeof QUOTA_NOTIFY_DIMS)[number]
...@@ -8,7 +8,7 @@ export type QuotaNotifyDim = (typeof QUOTA_NOTIFY_DIMS)[number] ...@@ -8,7 +8,7 @@ export type QuotaNotifyDim = (typeof QUOTA_NOTIFY_DIMS)[number]
interface DimState { interface DimState {
enabled: boolean | null enabled: boolean | null
threshold: number | null threshold: number | null
thresholdType: string | null thresholdType: QuotaThresholdType | null
} }
export function useQuotaNotifyState() { export function useQuotaNotifyState() {
...@@ -34,7 +34,7 @@ export function useQuotaNotifyState() { ...@@ -34,7 +34,7 @@ export function useQuotaNotifyState() {
for (const d of QUOTA_NOTIFY_DIMS) { for (const d of QUOTA_NOTIFY_DIMS) {
state[d].enabled = (extra?.[`quota_notify_${d}_enabled`] as boolean) ?? null state[d].enabled = (extra?.[`quota_notify_${d}_enabled`] as boolean) ?? null
state[d].threshold = (extra?.[`quota_notify_${d}_threshold`] as number) ?? null state[d].threshold = (extra?.[`quota_notify_${d}_threshold`] as number) ?? null
state[d].thresholdType = (extra?.[`quota_notify_${d}_threshold_type`] as string) ?? null state[d].thresholdType = (extra?.[`quota_notify_${d}_threshold_type`] as QuotaThresholdType) ?? null
} }
} }
......
...@@ -8,3 +8,8 @@ export type WebSearchMode = typeof WEB_SEARCH_MODE_DEFAULT | typeof WEB_SEARCH_M ...@@ -8,3 +8,8 @@ export type WebSearchMode = typeof WEB_SEARCH_MODE_DEFAULT | typeof WEB_SEARCH_M
export const QUOTA_THRESHOLD_TYPE_FIXED = 'fixed' as const export const QUOTA_THRESHOLD_TYPE_FIXED = 'fixed' as const
export const QUOTA_THRESHOLD_TYPE_PERCENTAGE = 'percentage' as const export const QUOTA_THRESHOLD_TYPE_PERCENTAGE = 'percentage' as const
export type QuotaThresholdType = typeof QUOTA_THRESHOLD_TYPE_FIXED | typeof QUOTA_THRESHOLD_TYPE_PERCENTAGE export type QuotaThresholdType = typeof QUOTA_THRESHOLD_TYPE_FIXED | typeof QUOTA_THRESHOLD_TYPE_PERCENTAGE
/** Quota reset mode values */
export const QUOTA_RESET_MODE_ROLLING = 'rolling' as const
export const QUOTA_RESET_MODE_FIXED = 'fixed' as const
export type QuotaResetMode = typeof QUOTA_RESET_MODE_ROLLING | typeof QUOTA_RESET_MODE_FIXED
...@@ -166,8 +166,8 @@ ...@@ -166,8 +166,8 @@
class="channel-tab group" class="channel-tab group"
:class="activeTab === section.platform ? 'channel-tab-active' : 'channel-tab-inactive'" :class="activeTab === section.platform ? 'channel-tab-active' : 'channel-tab-inactive'"
> >
<PlatformIcon :platform="section.platform" size="xs" :class="getPlatformTextColor(section.platform)" /> <PlatformIcon :platform="section.platform" size="xs" :class="platformTextClass(section.platform)" />
<span :class="getPlatformTextColor(section.platform)">{{ t('admin.groups.platforms.' + section.platform, section.platform) }}</span> <span :class="platformTextClass(section.platform)">{{ t('admin.groups.platforms.' + section.platform, section.platform) }}</span>
</button> </button>
</div> </div>
...@@ -246,8 +246,8 @@ ...@@ -246,8 +246,8 @@
class="h-3.5 w-3.5 rounded border-gray-300 text-primary-600 focus:ring-primary-500" class="h-3.5 w-3.5 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
@change="togglePlatform(p)" @change="togglePlatform(p)"
/> />
<PlatformIcon :platform="p" size="xs" :class="getPlatformTextColor(p)" /> <PlatformIcon :platform="p" size="xs" :class="platformTextClass(p)" />
<span :class="getPlatformTextColor(p)">{{ t('admin.groups.platforms.' + p, p) }}</span> <span :class="platformTextClass(p)">{{ t('admin.groups.platforms.' + p, p) }}</span>
</label> </label>
</div> </div>
</div> </div>
...@@ -310,9 +310,9 @@ ...@@ -310,9 +310,9 @@
class="h-3 w-3 rounded border-gray-300 text-primary-600 focus:ring-primary-500" class="h-3 w-3 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
@change="toggleGroupInSection(sIdx, group.id)" @change="toggleGroupInSection(sIdx, group.id)"
/> />
<span :class="['font-medium', getPlatformTextColor(group.platform)]">{{ group.name }}</span> <span :class="['font-medium', platformTextClass(group.platform)]">{{ group.name }}</span>
<span <span
:class="['rounded-full px-1 py-0 text-[10px]', getRateBadgeClass(group.platform)]" :class="['rounded-full px-1 py-0 text-[10px]', platformBadgeLightClass(group.platform)]"
>{{ group.rate_multiplier }}x</span> >{{ group.rate_multiplier }}x</span>
<span class="text-[10px] text-gray-400">{{ group.account_count || 0 }}</span> <span class="text-[10px] text-gray-400">{{ group.account_count || 0 }}</span>
<span <span
...@@ -363,7 +363,7 @@ ...@@ -363,7 +363,7 @@
:value="srcModel" :value="srcModel"
type="text" type="text"
class="input flex-1 text-xs" class="input flex-1 text-xs"
:class="getPlatformTextColor(section.platform)" :class="platformTextClass(section.platform)"
:placeholder="t('admin.channels.form.mappingSource', 'Source model')" :placeholder="t('admin.channels.form.mappingSource', 'Source model')"
@change="renameMappingKey(sIdx, srcModel, ($event.target as HTMLInputElement).value)" @change="renameMappingKey(sIdx, srcModel, ($event.target as HTMLInputElement).value)"
/> />
...@@ -372,7 +372,7 @@ ...@@ -372,7 +372,7 @@
:value="section.model_mapping[srcModel]" :value="section.model_mapping[srcModel]"
type="text" type="text"
class="input flex-1 text-xs" class="input flex-1 text-xs"
:class="getPlatformTextColor(section.platform)" :class="platformTextClass(section.platform)"
:placeholder="t('admin.channels.form.mappingTarget', 'Target model')" :placeholder="t('admin.channels.form.mappingTarget', 'Target model')"
@input="section.model_mapping[srcModel] = ($event.target as HTMLInputElement).value" @input="section.model_mapping[srcModel] = ($event.target as HTMLInputElement).value"
/> />
...@@ -464,7 +464,7 @@ ...@@ -464,7 +464,7 @@
: 'border-gray-200 hover:bg-gray-50 dark:border-dark-600 dark:hover:bg-dark-700'" : 'border-gray-200 hover:bg-gray-50 dark:border-dark-600 dark:hover:bg-dark-700'"
> >
<input type="checkbox" :checked="rule.group_ids.includes(gid)" class="h-3 w-3 rounded border-gray-300 text-primary-600 focus:ring-primary-500" @change="rule.group_ids.includes(gid) ? rule.group_ids.splice(rule.group_ids.indexOf(gid), 1) : rule.group_ids.push(gid)" /> <input type="checkbox" :checked="rule.group_ids.includes(gid)" class="h-3 w-3 rounded border-gray-300 text-primary-600 focus:ring-primary-500" @change="rule.group_ids.includes(gid) ? rule.group_ids.splice(rule.group_ids.indexOf(gid), 1) : rule.group_ids.push(gid)" />
<span>{{ getGroupNameById(gid) }}</span> <span :class="['font-medium', platformTextClass(section.platform)]">{{ getGroupNameById(gid) }}</span>
</label> </label>
</div> </div>
<p v-if="section.group_ids.length === 0" class="mt-1 text-xs text-gray-400"> <p v-if="section.group_ids.length === 0" class="mt-1 text-xs text-gray-400">
...@@ -481,7 +481,7 @@ ...@@ -481,7 +481,7 @@
:key="accountId" :key="accountId"
class="inline-flex items-center gap-1 rounded-md border border-primary-300 bg-primary-50 px-2 py-0.5 text-xs dark:border-primary-700 dark:bg-primary-900/20" class="inline-flex items-center gap-1 rounded-md border border-primary-300 bg-primary-50 px-2 py-0.5 text-xs dark:border-primary-700 dark:bg-primary-900/20"
> >
<span>{{ getRuleAccountLabel(accountId) }}</span> <span :class="['font-medium', platformTextClass(section.platform)]">{{ getRuleAccountLabel(accountId) }}</span>
<button type="button" @click="removeRuleAccount(rule, accountId)" class="text-gray-400 hover:text-red-500"> <button type="button" @click="removeRuleAccount(rule, accountId)" class="text-gray-400 hover:text-red-500">
<Icon name="x" size="xs" /> <Icon name="x" size="xs" />
</button> </button>
...@@ -595,7 +595,7 @@ import type { PricingFormEntry } from '@/components/admin/channel/types' ...@@ -595,7 +595,7 @@ import type { PricingFormEntry } from '@/components/admin/channel/types'
import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types' import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types'
import type { AdminGroup, GroupPlatform } from '@/types' import type { AdminGroup, GroupPlatform } from '@/types'
import type { Column } from '@/components/common/types' import type { Column } from '@/components/common/types'
import { platformTextClass } from '@/utils/platformColors' import { platformTextClass, platformBadgeLightClass } from '@/utils/platformColors'
import AppLayout from '@/components/layout/AppLayout.vue' import AppLayout from '@/components/layout/AppLayout.vue'
import TablePageLayout from '@/components/layout/TablePageLayout.vue' import TablePageLayout from '@/components/layout/TablePageLayout.vue'
import DataTable from '@/components/common/DataTable.vue' import DataTable from '@/components/common/DataTable.vue'
...@@ -720,26 +720,6 @@ let abortController: AbortController | null = null ...@@ -720,26 +720,6 @@ let abortController: AbortController | null = null
// ── Platform config ── // ── Platform config ──
const platformOrder: GroupPlatform[] = ['anthropic', 'openai', 'gemini', 'antigravity'] const platformOrder: GroupPlatform[] = ['anthropic', 'openai', 'gemini', 'antigravity']
function getPlatformTextColor(platform: string): string {
switch (platform) {
case 'anthropic': return 'text-orange-600 dark:text-orange-400'
case 'openai': return 'text-emerald-600 dark:text-emerald-400'
case 'gemini': return 'text-blue-600 dark:text-blue-400'
case 'antigravity': return 'text-purple-600 dark:text-purple-400'
default: return 'text-gray-600 dark:text-gray-400'
}
}
function getRateBadgeClass(platform: string): string {
switch (platform) {
case 'anthropic': return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
case 'openai': return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
case 'gemini': return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400'
case 'antigravity': return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
default: return 'bg-gray-100 text-gray-700 dark:bg-gray-900/30 dark:text-gray-400'
}
}
// ── Helpers ── // ── Helpers ──
function formatDate(value: string): string { function formatDate(value: string): string {
if (!value) return '-' if (!value) return '-'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment