Commit c86d445c authored by IanShaw027
Browse files

fix(frontend): sync with main and finalize i18n & component optimizations

parents 6c036d7b e78c8646
package service package service
import "testing" import (
"context"
func TestInferGoogleOneTier(t *testing.T) { "net/url"
tests := []struct { "strings"
name string "testing"
storageBytes int64
expectedTier string "github.com/Wei-Shaw/sub2api/internal/config"
}{ "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
{"Negative storage", -1, TierGoogleOneUnknown}, )
{"Zero storage", 0, TierGoogleOneUnknown},
func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
// Free tier boundary (15GB) t.Parallel()
{"Below free tier", 10 * GB, TierGoogleOneUnknown},
{"Just below free tier", StorageTierFree - 1, TierGoogleOneUnknown}, type testCase struct {
{"Free tier (15GB)", StorageTierFree, TierFree}, name string
cfg *config.Config
// Basic tier boundary (100GB) oauthType string
{"Between free and basic", 50 * GB, TierFree}, projectID string
{"Just below basic tier", StorageTierBasic - 1, TierFree}, wantClientID string
{"Basic tier (100GB)", StorageTierBasic, TierGoogleOneBasic}, wantRedirect string
wantScope string
// Standard tier boundary (200GB) wantProjectID string
{"Between basic and standard", 150 * GB, TierGoogleOneBasic}, wantErrSubstr string
{"Just below standard tier", StorageTierStandard - 1, TierGoogleOneBasic}, }
{"Standard tier (200GB)", StorageTierStandard, TierGoogleOneStandard},
tests := []testCase{
// AI Premium tier boundary (2TB) {
{"Between standard and premium", 1 * TB, TierGoogleOneStandard}, name: "google_one uses built-in client when not configured and redirects to upstream",
{"Just below AI Premium tier", StorageTierAIPremium - 1, TierGoogleOneStandard}, cfg: &config.Config{
{"AI Premium tier (2TB)", StorageTierAIPremium, TierAIPremium}, Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{},
// Unlimited tier boundary (> 100TB) },
{"Between premium and unlimited", 50 * TB, TierAIPremium}, },
{"At unlimited threshold (100TB)", StorageTierUnlimited, TierAIPremium}, oauthType: "google_one",
{"Unlimited tier (100TB+)", StorageTierUnlimited + 1, TierGoogleOneUnlimited}, wantClientID: geminicli.GeminiCLIOAuthClientID,
{"Unlimited tier (101TB+)", 101 * TB, TierGoogleOneUnlimited}, wantRedirect: geminicli.GeminiCLIRedirectURI,
{"Very large storage", 1000 * TB, TierGoogleOneUnlimited}, wantScope: geminicli.DefaultCodeAssistScopes,
wantProjectID: "",
},
{
name: "google_one uses custom client when configured and redirects to localhost",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{
ClientID: "custom-client-id",
ClientSecret: "custom-client-secret",
},
},
},
oauthType: "google_one",
wantClientID: "custom-client-id",
wantRedirect: geminicli.AIStudioOAuthRedirectURI,
wantScope: geminicli.DefaultGoogleOneScopes,
wantProjectID: "",
},
{
name: "code_assist always forces built-in client even when custom client configured",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{
ClientID: "custom-client-id",
ClientSecret: "custom-client-secret",
},
},
},
oauthType: "code_assist",
projectID: "my-gcp-project",
wantClientID: geminicli.GeminiCLIOAuthClientID,
wantRedirect: geminicli.GeminiCLIRedirectURI,
wantScope: geminicli.DefaultCodeAssistScopes,
wantProjectID: "my-gcp-project",
},
{
name: "ai_studio requires custom client",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{},
},
},
oauthType: "ai_studio",
wantErrSubstr: "AI Studio OAuth requires a custom OAuth Client",
},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := inferGoogleOneTier(tt.storageBytes) t.Parallel()
if result != tt.expectedTier {
t.Errorf("inferGoogleOneTier(%d) = %s, want %s", svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg)
tt.storageBytes, result, tt.expectedTier) got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "")
if tt.wantErrSubstr != "" {
if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr)
}
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
t.Fatalf("expected error containing %q, got: %v", tt.wantErrSubstr, err)
}
return
}
if err != nil {
t.Fatalf("GenerateAuthURL returned error: %v", err)
}
parsed, err := url.Parse(got.AuthURL)
if err != nil {
t.Fatalf("failed to parse auth_url: %v", err)
}
q := parsed.Query()
if gotState := q.Get("state"); gotState != got.State {
t.Fatalf("state mismatch: query=%q result=%q", gotState, got.State)
}
if gotClientID := q.Get("client_id"); gotClientID != tt.wantClientID {
t.Fatalf("client_id mismatch: got=%q want=%q", gotClientID, tt.wantClientID)
}
if gotRedirect := q.Get("redirect_uri"); gotRedirect != tt.wantRedirect {
t.Fatalf("redirect_uri mismatch: got=%q want=%q", gotRedirect, tt.wantRedirect)
}
if gotScope := q.Get("scope"); gotScope != tt.wantScope {
t.Fatalf("scope mismatch: got=%q want=%q", gotScope, tt.wantScope)
}
if gotProjectID := q.Get("project_id"); gotProjectID != tt.wantProjectID {
t.Fatalf("project_id mismatch: got=%q want=%q", gotProjectID, tt.wantProjectID)
} }
}) })
} }
......
...@@ -20,13 +20,24 @@ const ( ...@@ -20,13 +20,24 @@ const (
geminiModelFlash geminiModelClass = "flash" geminiModelFlash geminiModelClass = "flash"
) )
type GeminiDailyQuota struct { type GeminiQuota struct {
ProRPD int64 // SharedRPD is a shared requests-per-day pool across models.
FlashRPD int64 // When SharedRPD > 0, callers should treat ProRPD/FlashRPD as not applicable for daily quota checks.
SharedRPD int64 `json:"shared_rpd,omitempty"`
// SharedRPM is a shared requests-per-minute pool across models.
// When SharedRPM > 0, callers should treat ProRPM/FlashRPM as not applicable for minute quota checks.
SharedRPM int64 `json:"shared_rpm,omitempty"`
// Per-model quotas (AI Studio / API key).
// A value of -1 means "unlimited" (pay-as-you-go).
ProRPD int64 `json:"pro_rpd,omitempty"`
ProRPM int64 `json:"pro_rpm,omitempty"`
FlashRPD int64 `json:"flash_rpd,omitempty"`
FlashRPM int64 `json:"flash_rpm,omitempty"`
} }
type GeminiTierPolicy struct { type GeminiTierPolicy struct {
Quota GeminiDailyQuota Quota GeminiQuota
Cooldown time.Duration Cooldown time.Duration
} }
...@@ -45,10 +56,27 @@ type GeminiUsageTotals struct { ...@@ -45,10 +56,27 @@ type GeminiUsageTotals struct {
const geminiQuotaCacheTTL = time.Minute const geminiQuotaCacheTTL = time.Minute
type geminiQuotaOverrides struct { type geminiQuotaOverridesV1 struct {
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"` Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
} }
// geminiQuotaOverridesV2 is the decoding target for quota policy JSON that
// carries a top-level "quota_rules" object. Each map key is a tier identifier
// and each value holds the optional overrides for that tier.
type geminiQuotaOverridesV2 struct {
QuotaRules map[string]geminiQuotaRuleOverride `json:"quota_rules"`
}
// geminiQuotaRuleOverride describes the optional overrides for a single tier
// inside a "quota_rules" policy document. Every field is a pointer so that a
// nil value means "leave the existing setting unchanged".
type geminiQuotaRuleOverride struct {
// SharedRPD overrides the shared requests-per-day pool.
SharedRPD *int64 `json:"shared_rpd,omitempty"`
// SharedRPM overrides the shared requests-per-minute pool.
// NOTE(review): the JSON key here is "rpm", whereas GeminiQuota.SharedRPM
// uses "shared_rpm" — confirm this is the intended wire name.
SharedRPM *int64 `json:"rpm,omitempty"`
// GeminiPro holds per-model overrides for the pro model class.
GeminiPro *geminiModelQuotaOverride `json:"gemini_pro,omitempty"`
// GeminiFlash holds per-model overrides for the flash model class.
GeminiFlash *geminiModelQuotaOverride `json:"gemini_flash,omitempty"`
// Desc is an optional free-form description carried in the policy JSON.
Desc *string `json:"desc,omitempty"`
}
// geminiModelQuotaOverride holds the optional per-model quota overrides.
// A nil field means the corresponding limit is not overridden.
type geminiModelQuotaOverride struct {
// RPD is the requests-per-day override.
RPD *int64 `json:"rpd,omitempty"`
// RPM is the requests-per-minute override.
RPM *int64 `json:"rpm,omitempty"`
}
type GeminiQuotaService struct { type GeminiQuotaService struct {
cfg *config.Config cfg *config.Config
settingRepo SettingRepository settingRepo SettingRepository
...@@ -82,11 +110,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy { ...@@ -82,11 +110,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if s.cfg != nil { if s.cfg != nil {
policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers) policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers)
if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" { if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" {
var overrides geminiQuotaOverrides raw := []byte(s.cfg.Gemini.Quota.Policy)
if err := json.Unmarshal([]byte(s.cfg.Gemini.Quota.Policy), &overrides); err != nil { var overridesV2 geminiQuotaOverridesV2
log.Printf("gemini quota: parse config policy failed: %v", err) if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
} else { } else {
policy.ApplyOverrides(overrides.Tiers) var overridesV1 geminiQuotaOverridesV1
if err := json.Unmarshal(raw, &overridesV1); err != nil {
log.Printf("gemini quota: parse config policy failed: %v", err)
} else {
policy.ApplyOverrides(overridesV1.Tiers)
}
} }
} }
} }
...@@ -96,11 +130,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy { ...@@ -96,11 +130,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if err != nil && !errors.Is(err, ErrSettingNotFound) { if err != nil && !errors.Is(err, ErrSettingNotFound) {
log.Printf("gemini quota: load setting failed: %v", err) log.Printf("gemini quota: load setting failed: %v", err)
} else if strings.TrimSpace(value) != "" { } else if strings.TrimSpace(value) != "" {
var overrides geminiQuotaOverrides raw := []byte(value)
if err := json.Unmarshal([]byte(value), &overrides); err != nil { var overridesV2 geminiQuotaOverridesV2
log.Printf("gemini quota: parse setting failed: %v", err) if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
} else { } else {
policy.ApplyOverrides(overrides.Tiers) var overridesV1 geminiQuotaOverridesV1
if err := json.Unmarshal(raw, &overridesV1); err != nil {
log.Printf("gemini quota: parse setting failed: %v", err)
} else {
policy.ApplyOverrides(overridesV1.Tiers)
}
} }
} }
} }
...@@ -113,12 +153,20 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy { ...@@ -113,12 +153,20 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
return policy return policy
} }
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiDailyQuota, bool) { func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiQuota, bool) {
if account == nil || !account.IsGeminiCodeAssist() { if account == nil || account.Platform != PlatformGemini {
return GeminiDailyQuota{}, false return GeminiQuota{}, false
}
// Map (oauth_type + tier_id) to a canonical policy tier key.
// This keeps the policy table stable even if upstream tier_id strings vary.
tierKey := geminiQuotaTierKeyForAccount(account)
if tierKey == "" {
return GeminiQuota{}, false
} }
policy := s.Policy(ctx) policy := s.Policy(ctx)
return policy.QuotaForTier(account.GeminiTierID()) return policy.QuotaForTier(tierKey)
} }
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration { func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
...@@ -126,12 +174,36 @@ func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) ...@@ -126,12 +174,36 @@ func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string)
return policy.CooldownForTier(tierID) return policy.CooldownForTier(tierID)
} }
// CooldownForAccount returns the cooldown duration that applies to account.
//
// The account is first mapped to a policy tier key via
// geminiQuotaTierKeyForAccount, and the lookup is then delegated to
// CooldownForTier. A fallback of five minutes is returned when the receiver
// or the account is nil, when the account is not on the Gemini platform, or
// when no tier key can be derived.
func (s *GeminiQuotaService) CooldownForAccount(ctx context.Context, account *Account) time.Duration {
	const fallback = 5 * time.Minute

	// Short-circuit evaluation keeps the Platform access behind the nil check.
	usable := s != nil && account != nil && account.Platform == PlatformGemini
	if !usable {
		return fallback
	}

	key := geminiQuotaTierKeyForAccount(account)
	if strings.TrimSpace(key) == "" {
		return fallback
	}
	return s.CooldownForTier(ctx, key)
}
func newGeminiQuotaPolicy() *GeminiQuotaPolicy { func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
return &GeminiQuotaPolicy{ return &GeminiQuotaPolicy{
tiers: map[string]GeminiTierPolicy{ tiers: map[string]GeminiTierPolicy{
"LEGACY": {Quota: GeminiDailyQuota{ProRPD: 50, FlashRPD: 1500}, Cooldown: 30 * time.Minute}, // --- AI Studio / API Key (per-model) ---
"PRO": {Quota: GeminiDailyQuota{ProRPD: 1500, FlashRPD: 4000}, Cooldown: 5 * time.Minute}, // aistudio_free:
"ULTRA": {Quota: GeminiDailyQuota{ProRPD: 2000, FlashRPD: 0}, Cooldown: 5 * time.Minute}, // - gemini_pro: 50 RPD / 2 RPM
// - gemini_flash: 1500 RPD / 15 RPM
GeminiTierAIStudioFree: {Quota: GeminiQuota{ProRPD: 50, ProRPM: 2, FlashRPD: 1500, FlashRPM: 15}, Cooldown: 30 * time.Minute},
// aistudio_paid: -1 means "unlimited/pay-as-you-go" for RPD.
GeminiTierAIStudioPaid: {Quota: GeminiQuota{ProRPD: -1, ProRPM: 1000, FlashRPD: -1, FlashRPM: 2000}, Cooldown: 5 * time.Minute},
// --- Google One (shared pool) ---
GeminiTierGoogleOneFree: {Quota: GeminiQuota{SharedRPD: 1000, SharedRPM: 60}, Cooldown: 30 * time.Minute},
GeminiTierGoogleAIPro: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
GeminiTierGoogleAIUltra: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
// --- GCP Code Assist (shared pool) ---
GeminiTierGCPStandard: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
GeminiTierGCPEnterprise: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
}, },
} }
} }
...@@ -149,11 +221,22 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo ...@@ -149,11 +221,22 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
if !ok { if !ok {
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute} policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
} }
// Backward-compatible overrides:
// - If the tier uses shared quota, interpret pro_rpd as shared_rpd.
// - Otherwise apply per-model overrides.
if override.ProRPD != nil { if override.ProRPD != nil {
policy.Quota.ProRPD = clampGeminiQuotaInt64(*override.ProRPD) if policy.Quota.SharedRPD > 0 {
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
} else {
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
}
} }
if override.FlashRPD != nil { if override.FlashRPD != nil {
policy.Quota.FlashRPD = clampGeminiQuotaInt64(*override.FlashRPD) if policy.Quota.SharedRPD > 0 {
// No separate flash RPD for shared tiers.
} else {
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.FlashRPD)
}
} }
if override.CooldownMinutes != nil { if override.CooldownMinutes != nil {
minutes := clampGeminiQuotaInt(*override.CooldownMinutes) minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
...@@ -163,10 +246,51 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo ...@@ -163,10 +246,51 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
} }
} }
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiDailyQuota, bool) { func (p *GeminiQuotaPolicy) ApplyQuotaRulesOverrides(rules map[string]geminiQuotaRuleOverride) {
if p == nil || len(rules) == 0 {
return
}
for rawID, override := range rules {
tierID := normalizeGeminiTierID(rawID)
if tierID == "" {
continue
}
policy, ok := p.tiers[tierID]
if !ok {
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
}
if override.SharedRPD != nil {
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.SharedRPD)
}
if override.SharedRPM != nil {
policy.Quota.SharedRPM = clampGeminiQuotaRPM(*override.SharedRPM)
}
if override.GeminiPro != nil {
if override.GeminiPro.RPD != nil {
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.GeminiPro.RPD)
}
if override.GeminiPro.RPM != nil {
policy.Quota.ProRPM = clampGeminiQuotaRPM(*override.GeminiPro.RPM)
}
}
if override.GeminiFlash != nil {
if override.GeminiFlash.RPD != nil {
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.GeminiFlash.RPD)
}
if override.GeminiFlash.RPM != nil {
policy.Quota.FlashRPM = clampGeminiQuotaRPM(*override.GeminiFlash.RPM)
}
}
p.tiers[tierID] = policy
}
}
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiQuota, bool) {
policy, ok := p.policyForTier(tierID) policy, ok := p.policyForTier(tierID)
if !ok { if !ok {
return GeminiDailyQuota{}, false return GeminiQuota{}, false
} }
return policy.Quota, true return policy.Quota, true
} }
...@@ -184,22 +308,43 @@ func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool ...@@ -184,22 +308,43 @@ func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool
return GeminiTierPolicy{}, false return GeminiTierPolicy{}, false
} }
normalized := normalizeGeminiTierID(tierID) normalized := normalizeGeminiTierID(tierID)
if normalized == "" {
normalized = "LEGACY"
}
if policy, ok := p.tiers[normalized]; ok { if policy, ok := p.tiers[normalized]; ok {
return policy, true return policy, true
} }
policy, ok := p.tiers["LEGACY"] return GeminiTierPolicy{}, false
return policy, ok
} }
func normalizeGeminiTierID(tierID string) string { func normalizeGeminiTierID(tierID string) string {
return strings.ToUpper(strings.TrimSpace(tierID)) tierID = strings.TrimSpace(tierID)
if tierID == "" {
return ""
}
// Prefer canonical mapping (handles legacy tier strings).
if canonical := canonicalGeminiTierID(tierID); canonical != "" {
return canonical
}
// Accept older policy keys that used uppercase names.
switch strings.ToUpper(tierID) {
case "AISTUDIO_FREE":
return GeminiTierAIStudioFree
case "AISTUDIO_PAID":
return GeminiTierAIStudioPaid
case "GOOGLE_ONE_FREE":
return GeminiTierGoogleOneFree
case "GOOGLE_AI_PRO":
return GeminiTierGoogleAIPro
case "GOOGLE_AI_ULTRA":
return GeminiTierGoogleAIUltra
case "GCP_STANDARD":
return GeminiTierGCPStandard
case "GCP_ENTERPRISE":
return GeminiTierGCPEnterprise
}
return strings.ToLower(tierID)
} }
func clampGeminiQuotaInt64(value int64) int64 { func clampGeminiQuotaInt64WithUnlimited(value int64) int64 {
if value < 0 { if value < -1 {
return 0 return 0
} }
return value return value
...@@ -212,11 +357,46 @@ func clampGeminiQuotaInt(value int) int { ...@@ -212,11 +357,46 @@ func clampGeminiQuotaInt(value int) int {
return value return value
} }
// clampGeminiQuotaRPM normalizes a requests-per-minute limit: negative inputs
// collapse to 0, while zero and positive inputs are returned unchanged.
func clampGeminiQuotaRPM(value int64) int64 {
	if value >= 0 {
		return value
	}
	return 0
}
func geminiCooldownForTier(tierID string) time.Duration { func geminiCooldownForTier(tierID string) time.Duration {
policy := newGeminiQuotaPolicy() policy := newGeminiQuotaPolicy()
return policy.CooldownForTier(tierID) return policy.CooldownForTier(tierID)
} }
// geminiQuotaTierKeyForAccount derives the policy tier key used for quota and
// cooldown lookups from an account's OAuth type and stored tier ID.
//
// It returns "" when account is nil or is not a Gemini-platform account.
// Otherwise the canonical tier resolved from the credentials wins, unless it
// is empty or GeminiTierGoogleOneUnknown, in which case a default is picked
// from the OAuth type.
func geminiQuotaTierKeyForAccount(account *Account) string {
	if account == nil {
		return ""
	}
	if account.Platform != PlatformGemini {
		return ""
	}

	// Note: GeminiOAuthType() already defaults legacy (project_id present) to code_assist.
	kind := strings.ToLower(strings.TrimSpace(account.GeminiOAuthType()))
	storedTier := strings.TrimSpace(account.GeminiTierID())

	// The canonical tier recorded in the credentials takes precedence.
	canonical := canonicalGeminiTierIDForOAuthType(kind, storedTier)
	if canonical != "" && canonical != GeminiTierGoogleOneUnknown {
		return canonical
	}

	// tier_id is missing or unknown: choose a default from the OAuth type.
	switch kind {
	case "google_one":
		return GeminiTierGoogleOneFree
	case "code_assist":
		return GeminiTierGCPStandard
	}
	// "ai_studio" and every other value land here. API Key accounts
	// (type=apikey) have an empty oauth_type and are treated as AI Studio.
	return GeminiTierAIStudioFree
}
func geminiModelClassFromName(model string) geminiModelClass { func geminiModelClassFromName(model string) geminiModelClass {
name := strings.ToLower(strings.TrimSpace(model)) name := strings.ToLower(strings.TrimSpace(model))
if strings.Contains(name, "flash") || strings.Contains(name, "lite") { if strings.Contains(name, "flash") || strings.Contains(name, "lite") {
......
...@@ -487,7 +487,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco ...@@ -487,7 +487,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
return "", "", errors.New("access_token not found in credentials") return "", "", errors.New("access_token not found in credentials")
} }
return accessToken, "oauth", nil return accessToken, "oauth", nil
case AccountTypeApiKey: case AccountTypeAPIKey:
apiKey := account.GetOpenAIApiKey() apiKey := account.GetOpenAIApiKey()
if apiKey == "" { if apiKey == "" {
return "", "", errors.New("api_key not found in credentials") return "", "", errors.New("api_key not found in credentials")
...@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. ...@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeOAuth: case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API // OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL targetURL = chatgptCodexURL
case AccountTypeApiKey: case AccountTypeAPIKey:
// API Key accounts use Platform API or custom base URL // API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL() baseURL := account.GetOpenAIBaseURL()
if baseURL != "" { if baseURL != "" {
...@@ -703,7 +703,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht ...@@ -703,7 +703,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
} }
// Handle upstream error (mark account status) // Handle upstream error (mark account status)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// Return appropriate error response // Return appropriate error response
var errType, errMsg string var errType, errMsg string
...@@ -940,7 +946,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel ...@@ -940,7 +946,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage // OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct { type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult Result *OpenAIForwardResult
ApiKey *ApiKey APIKey *APIKey
User *User User *User
Account *Account Account *Account
Subscription *UserSubscription Subscription *UserSubscription
...@@ -949,7 +955,7 @@ type OpenAIRecordUsageInput struct { ...@@ -949,7 +955,7 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance // RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result result := input.Result
apiKey := input.ApiKey apiKey := input.APIKey
user := input.User user := input.User
account := input.Account account := input.Account
subscription := input.Subscription subscription := input.Subscription
...@@ -991,7 +997,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -991,7 +997,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: result.RequestID, RequestID: result.RequestID,
Model: result.Model, Model: result.Model,
...@@ -1020,22 +1026,23 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -1020,22 +1026,23 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.SubscriptionID = &subscription.ID usageLog.SubscriptionID = &subscription.ID
} }
_ = s.usageLogRepo.Create(ctx, usageLog) inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID) s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil return nil
} }
shouldBill := inserted || err != nil
// Deduct based on billing type // Deduct based on billing type
if isSubscriptionBilling { if isSubscriptionBilling {
if cost.TotalCost > 0 { if shouldBill && cost.TotalCost > 0 {
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost) _ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
} }
} else { } else {
if cost.ActualCost > 0 { if shouldBill && cost.ActualCost > 0 {
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost) _ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
} }
......
...@@ -2,6 +2,7 @@ package service ...@@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"encoding/json"
"log" "log"
"net/http" "net/http"
"strconv" "strconv"
...@@ -18,6 +19,7 @@ type RateLimitService struct { ...@@ -18,6 +19,7 @@ type RateLimitService struct {
usageRepo UsageLogRepository usageRepo UsageLogRepository
cfg *config.Config cfg *config.Config
geminiQuotaService *GeminiQuotaService geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
usageCacheMu sync.RWMutex usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry usageCache map[int64]*geminiUsageCacheEntry
} }
...@@ -31,12 +33,13 @@ type geminiUsageCacheEntry struct { ...@@ -31,12 +33,13 @@ type geminiUsageCacheEntry struct {
const geminiPrecheckCacheTTL = time.Minute const geminiPrecheckCacheTTL = time.Minute
// NewRateLimitService 创建RateLimitService实例 // NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService) *RateLimitService { func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService {
return &RateLimitService{ return &RateLimitService{
accountRepo: accountRepo, accountRepo: accountRepo,
usageRepo: usageRepo, usageRepo: usageRepo,
cfg: cfg, cfg: cfg,
geminiQuotaService: geminiQuotaService, geminiQuotaService: geminiQuotaService,
tempUnschedCache: tempUnschedCache,
usageCache: make(map[int64]*geminiUsageCacheEntry), usageCache: make(map[int64]*geminiUsageCacheEntry),
} }
} }
...@@ -51,38 +54,45 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -51,38 +54,45 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return false return false
} }
tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
switch statusCode { switch statusCode {
case 401: case 401:
// 认证失败:停止调度,记录错误 // 认证失败:停止调度,记录错误
s.handleAuthError(ctx, account, "Authentication failed (401): invalid or expired credentials") s.handleAuthError(ctx, account, "Authentication failed (401): invalid or expired credentials")
return true shouldDisable = true
case 402: case 402:
// 支付要求:余额不足或计费问题,停止调度 // 支付要求:余额不足或计费问题,停止调度
s.handleAuthError(ctx, account, "Payment required (402): insufficient balance or billing issue") s.handleAuthError(ctx, account, "Payment required (402): insufficient balance or billing issue")
return true shouldDisable = true
case 403: case 403:
// 禁止访问:停止调度,记录错误 // 禁止访问:停止调度,记录错误
s.handleAuthError(ctx, account, "Access forbidden (403): account may be suspended or lack permissions") s.handleAuthError(ctx, account, "Access forbidden (403): account may be suspended or lack permissions")
return true shouldDisable = true
case 429: case 429:
s.handle429(ctx, account, headers) s.handle429(ctx, account, headers)
return false shouldDisable = false
case 529: case 529:
s.handle529(ctx, account) s.handle529(ctx, account)
return false shouldDisable = false
default: default:
// 其他5xx错误:记录但不停止调度 // 其他5xx错误:记录但不停止调度
if statusCode >= 500 { if statusCode >= 500 {
log.Printf("Account %d received upstream error %d", account.ID, statusCode) log.Printf("Account %d received upstream error %d", account.ID, statusCode)
} }
return false shouldDisable = false
} }
if tempMatched {
return true
}
return shouldDisable
} }
// PreCheckUsage proactively checks local quota before dispatching a request. // PreCheckUsage proactively checks local quota before dispatching a request.
// Returns false when the account should be skipped. // Returns false when the account should be skipped.
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) { func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
if account == nil || !account.IsGeminiCodeAssist() || strings.TrimSpace(requestedModel) == "" { if account == nil || account.Platform != PlatformGemini {
return true, nil return true, nil
} }
if s.usageRepo == nil || s.geminiQuotaService == nil { if s.usageRepo == nil || s.geminiQuotaService == nil {
...@@ -94,44 +104,99 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, ...@@ -94,44 +104,99 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
return true, nil return true, nil
} }
var limit int64
switch geminiModelClassFromName(requestedModel) {
case geminiModelFlash:
limit = quota.FlashRPD
default:
limit = quota.ProRPD
}
if limit <= 0 {
return true, nil
}
now := time.Now() now := time.Now()
start := geminiDailyWindowStart(now) modelClass := geminiModelClassFromName(requestedModel)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok { // 1) Daily quota precheck (RPD; resets at PST midnight)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID) {
if err != nil { var limit int64
return true, err if quota.SharedRPD > 0 {
limit = quota.SharedRPD
} else {
switch modelClass {
case geminiModelFlash:
limit = quota.FlashRPD
default:
limit = quota.ProRPD
}
} }
totals = geminiAggregateUsage(stats)
s.setGeminiUsageTotals(account.ID, start, now, totals)
}
var used int64 if limit > 0 {
switch geminiModelClassFromName(requestedModel) { start := geminiDailyWindowStart(now)
case geminiModelFlash: totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
used = totals.FlashRequests if !ok {
default: stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
used = totals.ProRequests if err != nil {
return true, err
}
totals = geminiAggregateUsage(stats)
s.setGeminiUsageTotals(account.ID, start, now, totals)
}
var used int64
if quota.SharedRPD > 0 {
used = totals.ProRequests + totals.FlashRequests
} else {
switch modelClass {
case geminiModelFlash:
used = totals.FlashRequests
default:
used = totals.ProRequests
}
}
if used >= limit {
resetAt := geminiDailyResetTime(now)
// NOTE:
// - This is a local precheck to reduce upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
return false, nil
}
}
} }
if used >= limit { // 2) Minute quota precheck (RPM; fixed window current minute)
resetAt := geminiDailyResetTime(now) {
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { var limit int64
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) if quota.SharedRPM > 0 {
limit = quota.SharedRPM
} else {
switch modelClass {
case geminiModelFlash:
limit = quota.FlashRPM
default:
limit = quota.ProRPM
}
}
if limit > 0 {
start := now.Truncate(time.Minute)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return true, err
}
totals := geminiAggregateUsage(stats)
var used int64
if quota.SharedRPM > 0 {
used = totals.ProRequests + totals.FlashRequests
} else {
switch modelClass {
case geminiModelFlash:
used = totals.FlashRequests
default:
used = totals.ProRequests
}
}
if used >= limit {
resetAt := start.Add(time.Minute)
// Do not persist "rate limited" status from local precheck. See note above.
log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
return false, nil
}
} }
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), rate limited until %v", account.ID, used, limit, resetAt)
return false, nil
} }
return true, nil return true, nil
...@@ -176,7 +241,10 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account) ...@@ -176,7 +241,10 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
if account == nil { if account == nil {
return 5 * time.Minute return 5 * time.Minute
} }
return s.geminiQuotaService.CooldownForTier(ctx, account.GeminiTierID()) if s.geminiQuotaService == nil {
return 5 * time.Minute
}
return s.geminiQuotaService.CooldownForAccount(ctx, account)
} }
// handleAuthError 处理认证类错误(401/403),停止账号调度 // handleAuthError 处理认证类错误(401/403),停止账号调度
...@@ -287,3 +355,183 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc ...@@ -287,3 +355,183 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error { func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
return s.accountRepo.ClearRateLimit(ctx, accountID) return s.accountRepo.ClearRateLimit(ctx, accountID)
} }
// ClearTempUnschedulable removes the temporary-unschedulable mark of an account.
//
// The repository is updated first and its error is returned to the caller.
// When a cache is configured, the cached entry is deleted afterwards; a cache
// failure is only logged and does not fail the call.
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
	repoErr := s.accountRepo.ClearTempUnschedulable(ctx, accountID)
	if repoErr != nil {
		return repoErr
	}
	if s.tempUnschedCache == nil {
		return nil
	}
	cacheErr := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID)
	if cacheErr != nil {
		log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, cacheErr)
	}
	return nil
}
// GetTempUnschedStatus returns the temporary-unschedulable state that is
// currently in effect for accountID, or (nil, nil) when there is none.
//
// Lookup order:
//  1. The cache, when configured. A cache read error is returned as-is; a
//     cached state is used only if its UntilUnix is still in the future.
//  2. The account record. The state is rebuilt from TempUnschedulableUntil and
//     TempUnschedulableReason, then written back to the cache (a write failure
//     is only logged).
func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID int64) (*TempUnschedState, error) {
	nowUnix := time.Now().Unix()

	if s.tempUnschedCache != nil {
		cached, cacheErr := s.tempUnschedCache.GetTempUnsched(ctx, accountID)
		if cacheErr != nil {
			return nil, cacheErr
		}
		if cached != nil && cached.UntilUnix > nowUnix {
			return cached, nil
		}
	}

	account, err := s.accountRepo.GetByID(ctx, accountID)
	if err != nil {
		return nil, err
	}
	// No deadline recorded, or the deadline has already passed: nothing active.
	if account.TempUnschedulableUntil == nil || account.TempUnschedulableUntil.Unix() <= nowUnix {
		return nil, nil
	}

	state := &TempUnschedState{UntilUnix: account.TempUnschedulableUntil.Unix()}
	if reason := account.TempUnschedulableReason; reason != "" {
		var decoded TempUnschedState
		if jsonErr := json.Unmarshal([]byte(reason), &decoded); jsonErr != nil {
			// Reason is not JSON: keep it as a plain error message.
			state.ErrorMessage = reason
		} else {
			// Reason is a serialized state; fill in the deadline if it lacks one.
			if decoded.UntilUnix == 0 {
				decoded.UntilUnix = state.UntilUnix
			}
			state = &decoded
		}
	}

	if s.tempUnschedCache != nil {
		if setErr := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); setErr != nil {
			log.Printf("SetTempUnsched failed for account %d: %v", accountID, setErr)
		}
	}
	return state, nil
}
// HandleTempUnschedulable evaluates the account's temporary-unschedulable
// rules for an upstream error response.
//
// It returns false for a nil account or when the account reports that it does
// not handle statusCode; otherwise it returns the result of
// tryTempUnschedulable.
func (s *RateLimitService) HandleTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool {
	if account == nil || !account.ShouldHandleErrorCode(statusCode) {
		return false
	}
	return s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
}
// Size limits applied to upstream response bodies by the temp-unschedulable logic.
const (
	// tempUnschedBodyMaxBytes caps how many bytes of a response body are
	// scanned for rule keywords (64 KiB).
	tempUnschedBodyMaxBytes = 64 * 1024
	// tempUnschedMessageMaxBytes caps the length of the error message kept in
	// a TempUnschedState.
	tempUnschedMessageMaxBytes = 2048
)
// tryTempUnschedulable checks the response against the account's
// temp-unschedulable rules and triggers the first rule that matches.
//
// A rule matches when its ErrorCode equals statusCode and one of its keywords
// occurs (case-insensitively) in the first tempUnschedBodyMaxBytes bytes of
// responseBody. It returns true once a matching rule has been triggered
// successfully, and false when the feature is disabled, there are no rules,
// the input is empty, or nothing matched.
func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool {
	if account == nil || !account.IsTempUnschedulableEnabled() {
		return false
	}
	rules := account.GetTempUnschedulableRules()
	if len(rules) == 0 {
		return false
	}
	if statusCode <= 0 || len(responseBody) == 0 {
		return false
	}

	// Only a bounded prefix of the body is lower-cased and searched.
	scanned := responseBody
	if len(scanned) > tempUnschedBodyMaxBytes {
		scanned = scanned[:tempUnschedBodyMaxBytes]
	}
	haystack := strings.ToLower(string(scanned))

	for idx, rule := range rules {
		if rule.ErrorCode == statusCode && len(rule.Keywords) > 0 {
			keyword := matchTempUnschedKeyword(haystack, rule.Keywords)
			// The full (untruncated) body is handed over; the trigger step
			// applies its own message limit.
			if keyword != "" && s.triggerTempUnschedulable(ctx, account, rule, idx, statusCode, keyword, responseBody) {
				return true
			}
		}
	}
	return false
}
// matchTempUnschedKeyword returns the first keyword that occurs in bodyLower,
// or "" when none does.
//
// bodyLower must already be lower-cased by the caller; each keyword is trimmed
// of surrounding whitespace and lower-cased before the substring search.
// Blank keywords are skipped. The returned keyword is the trimmed form in its
// original letter case.
func matchTempUnschedKeyword(bodyLower string, keywords []string) string {
	if len(bodyLower) == 0 {
		return ""
	}
	for i := range keywords {
		candidate := strings.TrimSpace(keywords[i])
		if candidate != "" && strings.Contains(bodyLower, strings.ToLower(candidate)) {
			return candidate
		}
	}
	return ""
}
// triggerTempUnschedulable marks the account as temporarily unschedulable for
// rule.DurationMinutes minutes, starting now.
//
// The trigger details (deadline, trigger time, status code, matched keyword,
// rule index and a truncated copy of the response body) are serialized to JSON
// and stored as the reason; if serialization yields nothing, the trimmed error
// message is stored instead. It returns false for a nil account, a
// non-positive duration, or a repository failure. A cache write failure is
// only logged.
func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account *Account, rule TempUnschedulableRule, ruleIndex int, statusCode int, matchedKeyword string, responseBody []byte) bool {
	if account == nil || rule.DurationMinutes <= 0 {
		return false
	}

	triggeredAt := time.Now()
	until := triggeredAt.Add(time.Duration(rule.DurationMinutes) * time.Minute)

	state := &TempUnschedState{
		UntilUnix:       until.Unix(),
		TriggeredAtUnix: triggeredAt.Unix(),
		StatusCode:      statusCode,
		MatchedKeyword:  matchedKeyword,
		RuleIndex:       ruleIndex,
		ErrorMessage:    truncateTempUnschedMessage(responseBody, tempUnschedMessageMaxBytes),
	}

	// Prefer the JSON form of the state; fall back to the plain message.
	reason := strings.TrimSpace(state.ErrorMessage)
	if encoded, encErr := json.Marshal(state); encErr == nil && len(encoded) > 0 {
		reason = string(encoded)
	}

	if repoErr := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); repoErr != nil {
		log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, repoErr)
		return false
	}

	if s.tempUnschedCache != nil {
		if cacheErr := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); cacheErr != nil {
			log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, cacheErr)
		}
	}

	log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode)
	return true
}
// truncateTempUnschedMessage converts body to a string of at most maxBytes
// bytes with surrounding whitespace trimmed.
//
// It returns "" when maxBytes is not positive or body is empty. When body has
// to be shortened, the cut is moved back (by at most three bytes) so that it
// does not fall inside a multi-byte UTF-8 character; a plain byte cut would
// leave a dangling partial character, which becomes U+FFFD once the message is
// serialized to JSON. If the bytes around the cut are not valid UTF-8 anyway,
// the plain maxBytes cut is used.
func truncateTempUnschedMessage(body []byte, maxBytes int) string {
	if maxBytes <= 0 || len(body) == 0 {
		return ""
	}
	if len(body) > maxBytes {
		cut := maxBytes
		// body[cut] is the first dropped byte. If it is a continuation byte
		// (bit pattern 10xxxxxx), the cut is in the middle of a character.
		for cut > 0 && maxBytes-cut < 3 && body[cut]&0xC0 == 0x80 {
			cut--
		}
		if body[cut]&0xC0 == 0x80 {
			// No character start found within reach: the input is not valid
			// UTF-8 here, so keep the simple byte limit.
			cut = maxBytes
		}
		body = body[:cut]
	}
	return strings.TrimSpace(string(body))
}
...@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeySiteName, SettingKeySiteName,
SettingKeySiteLogo, SettingKeySiteLogo,
SettingKeySiteSubtitle, SettingKeySiteSubtitle,
SettingKeyApiBaseUrl, SettingKeyAPIBaseURL,
SettingKeyContactInfo, SettingKeyContactInfo,
SettingKeyDocUrl, SettingKeyDocURL,
} }
settings, err := s.settingRepo.GetMultiple(ctx, keys) settings, err := s.settingRepo.GetMultiple(ctx, keys)
...@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo], SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[SettingKeyApiBaseUrl], APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo], ContactInfo: settings[SettingKeyContactInfo],
DocUrl: settings[SettingKeyDocUrl], DocURL: settings[SettingKeyDocURL],
}, nil }, nil
} }
...@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
// 邮件服务设置(只有非空才更新密码) // 邮件服务设置(只有非空才更新密码)
updates[SettingKeySmtpHost] = settings.SmtpHost updates[SettingKeySMTPHost] = settings.SMTPHost
updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort) updates[SettingKeySMTPPort] = strconv.Itoa(settings.SMTPPort)
updates[SettingKeySmtpUsername] = settings.SmtpUsername updates[SettingKeySMTPUsername] = settings.SMTPUsername
if settings.SmtpPassword != "" { if settings.SMTPPassword != "" {
updates[SettingKeySmtpPassword] = settings.SmtpPassword updates[SettingKeySMTPPassword] = settings.SMTPPassword
} }
updates[SettingKeySmtpFrom] = settings.SmtpFrom updates[SettingKeySMTPFrom] = settings.SMTPFrom
updates[SettingKeySmtpFromName] = settings.SmtpFromName updates[SettingKeySMTPFromName] = settings.SMTPFromName
updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS) updates[SettingKeySMTPUseTLS] = strconv.FormatBool(settings.SMTPUseTLS)
// Cloudflare Turnstile 设置(只有非空才更新密钥) // Cloudflare Turnstile 设置(只有非空才更新密钥)
updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled) updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
...@@ -115,14 +115,21 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -115,14 +115,21 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeySiteName] = settings.SiteName updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo updates[SettingKeySiteLogo] = settings.SiteLogo
updates[SettingKeySiteSubtitle] = settings.SiteSubtitle updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl updates[SettingKeyAPIBaseURL] = settings.APIBaseURL
updates[SettingKeyContactInfo] = settings.ContactInfo updates[SettingKeyContactInfo] = settings.ContactInfo
updates[SettingKeyDocUrl] = settings.DocUrl updates[SettingKeyDocURL] = settings.DocURL
// 默认配置 // 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
// Model fallback configuration
updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback)
updates[SettingKeyFallbackModelAnthropic] = settings.FallbackModelAnthropic
updates[SettingKeyFallbackModelOpenAI] = settings.FallbackModelOpenAI
updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini
updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity
return s.settingRepo.SetMultiple(ctx, updates) return s.settingRepo.SetMultiple(ctx, updates)
} }
...@@ -198,8 +205,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -198,8 +205,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "", SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySmtpPort: "587", SettingKeySMTPPort: "587",
SettingKeySmtpUseTLS: "false", SettingKeySMTPUseTLS: "false",
// Model fallback defaults
SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
SettingKeyFallbackModelOpenAI: "gpt-4o",
SettingKeyFallbackModelGemini: "gemini-2.5-pro",
SettingKeyFallbackModelAntigravity: "gemini-2.5-pro",
} }
return s.settingRepo.SetMultiple(ctx, defaults) return s.settingRepo.SetMultiple(ctx, defaults)
...@@ -210,26 +223,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -210,26 +223,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result := &SystemSettings{ result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SmtpHost: settings[SettingKeySmtpHost], SMTPHost: settings[SettingKeySMTPHost],
SmtpUsername: settings[SettingKeySmtpUsername], SMTPUsername: settings[SettingKeySMTPUsername],
SmtpFrom: settings[SettingKeySmtpFrom], SMTPFrom: settings[SettingKeySMTPFrom],
SmtpFromName: settings[SettingKeySmtpFromName], SMTPFromName: settings[SettingKeySMTPFromName],
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true", SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo], SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[SettingKeyApiBaseUrl], APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo], ContactInfo: settings[SettingKeyContactInfo],
DocUrl: settings[SettingKeyDocUrl], DocURL: settings[SettingKeyDocURL],
} }
// 解析整数类型 // 解析整数类型
if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil { if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil {
result.SmtpPort = port result.SMTPPort = port
} else { } else {
result.SmtpPort = 587 result.SMTPPort = 587
} }
if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil { if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
...@@ -246,9 +259,16 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -246,9 +259,16 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} }
// 敏感信息直接返回,方便测试连接时使用 // 敏感信息直接返回,方便测试连接时使用
result.SmtpPassword = settings[SettingKeySmtpPassword] result.SMTPPassword = settings[SettingKeySMTPPassword]
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey] result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
// Model fallback settings
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
result.FallbackModelOpenAI = s.getStringOrDefault(settings, SettingKeyFallbackModelOpenAI, "gpt-4o")
result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro")
result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro")
return result return result
} }
...@@ -278,28 +298,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string { ...@@ -278,28 +298,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return value return value
} }
// GenerateAdminApiKey 生成新的管理员 API Key // GenerateAdminAPIKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) { func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
// 生成 32 字节随机数 = 64 位十六进制字符 // 生成 32 字节随机数 = 64 位十六进制字符
bytes := make([]byte, 32) bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil { if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err) return "", fmt.Errorf("generate random bytes: %w", err)
} }
key := AdminApiKeyPrefix + hex.EncodeToString(bytes) key := AdminAPIKeyPrefix + hex.EncodeToString(bytes)
// 存储到 settings 表 // 存储到 settings 表
if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil { if err := s.settingRepo.Set(ctx, SettingKeyAdminAPIKey, key); err != nil {
return "", fmt.Errorf("save admin api key: %w", err) return "", fmt.Errorf("save admin api key: %w", err)
} }
return key, nil return key, nil
} }
// GetAdminApiKeyStatus 获取管理员 API Key 状态 // GetAdminAPIKeyStatus 获取管理员 API Key 状态
// 返回脱敏的 key、是否存在、错误 // 返回脱敏的 key、是否存在、错误
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) { func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey) key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
if err != nil { if err != nil {
if errors.Is(err, ErrSettingNotFound) { if errors.Is(err, ErrSettingNotFound) {
return "", false, nil return "", false, nil
...@@ -320,10 +340,10 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st ...@@ -320,10 +340,10 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
return maskedKey, true, nil return maskedKey, true, nil
} }
// GetAdminApiKey 获取完整的管理员 API Key(仅供内部验证使用) // GetAdminAPIKey 获取完整的管理员 API Key(仅供内部验证使用)
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error // 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) { func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey) key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
if err != nil { if err != nil {
if errors.Is(err, ErrSettingNotFound) { if errors.Is(err, ErrSettingNotFound) {
return "", nil // 未配置,返回空字符串 return "", nil // 未配置,返回空字符串
...@@ -333,7 +353,45 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) { ...@@ -333,7 +353,45 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
return key, nil return key, nil
} }
// DeleteAdminApiKey 删除管理员 API Key // DeleteAdminAPIKey 删除管理员 API Key
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error { func (s *SettingService) DeleteAdminAPIKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey) return s.settingRepo.Delete(ctx, SettingKeyAdminAPIKey)
}
// IsModelFallbackEnabled reports whether the model fallback mechanism is
// enabled. It is true only when the stored setting reads exactly "true"; a
// lookup error is treated as disabled.
func (s *SettingService) IsModelFallbackEnabled(ctx context.Context) bool {
	stored, err := s.settingRepo.GetValue(ctx, SettingKeyEnableModelFallback)
	return err == nil && stored == "true"
}
// GetFallbackModel 获取指定平台的兜底模型
func (s *SettingService) GetFallbackModel(ctx context.Context, platform string) string {
var key string
var defaultModel string
switch platform {
case PlatformAnthropic:
key = SettingKeyFallbackModelAnthropic
defaultModel = "claude-3-5-sonnet-20241022"
case PlatformOpenAI:
key = SettingKeyFallbackModelOpenAI
defaultModel = "gpt-4o"
case PlatformGemini:
key = SettingKeyFallbackModelGemini
defaultModel = "gemini-2.5-pro"
case PlatformAntigravity:
key = SettingKeyFallbackModelAntigravity
defaultModel = "gemini-2.5-pro"
default:
return ""
}
value, err := s.settingRepo.GetValue(ctx, key)
if err != nil || value == "" {
return defaultModel
}
return value
} }
...@@ -4,13 +4,13 @@ type SystemSettings struct { ...@@ -4,13 +4,13 @@ type SystemSettings struct {
RegistrationEnabled bool RegistrationEnabled bool
EmailVerifyEnabled bool EmailVerifyEnabled bool
SmtpHost string SMTPHost string
SmtpPort int SMTPPort int
SmtpUsername string SMTPUsername string
SmtpPassword string SMTPPassword string
SmtpFrom string SMTPFrom string
SmtpFromName string SMTPFromName string
SmtpUseTLS bool SMTPUseTLS bool
TurnstileEnabled bool TurnstileEnabled bool
TurnstileSiteKey string TurnstileSiteKey string
...@@ -19,12 +19,19 @@ type SystemSettings struct { ...@@ -19,12 +19,19 @@ type SystemSettings struct {
SiteName string SiteName string
SiteLogo string SiteLogo string
SiteSubtitle string SiteSubtitle string
ApiBaseUrl string APIBaseURL string
ContactInfo string ContactInfo string
DocUrl string DocURL string
DefaultConcurrency int DefaultConcurrency int
DefaultBalance float64 DefaultBalance float64
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
} }
type PublicSettings struct { type PublicSettings struct {
...@@ -35,8 +42,8 @@ type PublicSettings struct { ...@@ -35,8 +42,8 @@ type PublicSettings struct {
SiteName string SiteName string
SiteLogo string SiteLogo string
SiteSubtitle string SiteSubtitle string
ApiBaseUrl string APIBaseURL string
ContactInfo string ContactInfo string
DocUrl string DocURL string
Version string Version string
} }
package service
import (
"context"
)
// TempUnschedState describes a temporary "unschedulable" state.
// The JSON tags define its serialized form.
type TempUnschedState struct {
UntilUnix int64 `json:"until_unix"` // time at which the state is lifted (Unix timestamp)
TriggeredAtUnix int64 `json:"triggered_at_unix"` // time at which the state was triggered (Unix timestamp)
StatusCode int `json:"status_code"` // error code that triggered the state
MatchedKeyword string `json:"matched_keyword"` // keyword that matched
RuleIndex int `json:"rule_index"` // index of the rule that triggered the state
ErrorMessage string `json:"error_message"` // error message
}
// TempUnschedCache is the cache interface for temporary unschedulable state,
// keyed by account ID.
type TempUnschedCache interface {
SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error
GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error)
DeleteTempUnsched(ctx context.Context, accountID int64) error
}
...@@ -197,7 +197,7 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) { ...@@ -197,7 +197,7 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
{ {
name: "anthropic api-key - cannot refresh", name: "anthropic api-key - cannot refresh",
platform: PlatformAnthropic, platform: PlatformAnthropic,
accType: AccountTypeApiKey, accType: AccountTypeAPIKey,
want: false, want: false,
}, },
{ {
......
...@@ -79,7 +79,7 @@ type ReleaseInfo struct { ...@@ -79,7 +79,7 @@ type ReleaseInfo struct {
Name string `json:"name"` Name string `json:"name"`
Body string `json:"body"` Body string `json:"body"`
PublishedAt string `json:"published_at"` PublishedAt string `json:"published_at"`
HtmlURL string `json:"html_url"` HTMLURL string `json:"html_url"`
Assets []Asset `json:"assets,omitempty"` Assets []Asset `json:"assets,omitempty"`
} }
...@@ -96,13 +96,13 @@ type GitHubRelease struct { ...@@ -96,13 +96,13 @@ type GitHubRelease struct {
Name string `json:"name"` Name string `json:"name"`
Body string `json:"body"` Body string `json:"body"`
PublishedAt string `json:"published_at"` PublishedAt string `json:"published_at"`
HtmlUrl string `json:"html_url"` HTMLURL string `json:"html_url"`
Assets []GitHubAsset `json:"assets"` Assets []GitHubAsset `json:"assets"`
} }
type GitHubAsset struct { type GitHubAsset struct {
Name string `json:"name"` Name string `json:"name"`
BrowserDownloadUrl string `json:"browser_download_url"` BrowserDownloadURL string `json:"browser_download_url"`
Size int64 `json:"size"` Size int64 `json:"size"`
} }
...@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er ...@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
for i, a := range release.Assets { for i, a := range release.Assets {
assets[i] = Asset{ assets[i] = Asset{
Name: a.Name, Name: a.Name,
DownloadURL: a.BrowserDownloadUrl, DownloadURL: a.BrowserDownloadURL,
Size: a.Size, Size: a.Size,
} }
} }
...@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er ...@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
Name: release.Name, Name: release.Name,
Body: release.Body, Body: release.Body,
PublishedAt: release.PublishedAt, PublishedAt: release.PublishedAt,
HtmlURL: release.HtmlUrl, HTMLURL: release.HTMLURL,
Assets: assets, Assets: assets,
}, },
Cached: false, Cached: false,
......
...@@ -10,7 +10,7 @@ const ( ...@@ -10,7 +10,7 @@ const (
type UsageLog struct { type UsageLog struct {
ID int64 ID int64
UserID int64 UserID int64
ApiKeyID int64 APIKeyID int64
AccountID int64 AccountID int64
RequestID string RequestID string
Model string Model string
...@@ -42,7 +42,7 @@ type UsageLog struct { ...@@ -42,7 +42,7 @@ type UsageLog struct {
CreatedAt time.Time CreatedAt time.Time
User *User User *User
ApiKey *ApiKey APIKey *APIKey
Account *Account Account *Account
Group *Group Group *Group
Subscription *UserSubscription Subscription *UserSubscription
......
...@@ -2,9 +2,11 @@ package service ...@@ -2,9 +2,11 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
...@@ -17,7 +19,7 @@ var ( ...@@ -17,7 +19,7 @@ var (
// CreateUsageLogRequest 创建使用日志请求 // CreateUsageLogRequest 创建使用日志请求
type CreateUsageLogRequest struct { type CreateUsageLogRequest struct {
UserID int64 `json:"user_id"` UserID int64 `json:"user_id"`
ApiKeyID int64 `json:"api_key_id"` APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"` AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"` RequestID string `json:"request_id"`
Model string `json:"model"` Model string `json:"model"`
...@@ -54,20 +56,34 @@ type UsageStats struct { ...@@ -54,20 +56,34 @@ type UsageStats struct {
type UsageService struct { type UsageService struct {
usageRepo UsageLogRepository usageRepo UsageLogRepository
userRepo UserRepository userRepo UserRepository
entClient *dbent.Client
} }
// NewUsageService 创建使用统计服务实例 // NewUsageService 创建使用统计服务实例
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository) *UsageService { func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService {
return &UsageService{ return &UsageService{
usageRepo: usageRepo, usageRepo: usageRepo,
userRepo: userRepo, userRepo: userRepo,
entClient: entClient,
} }
} }
// Create 创建使用日志 // Create 创建使用日志
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*UsageLog, error) { func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*UsageLog, error) {
// 使用数据库事务保证「使用日志插入」与「扣费」的原子性,避免重复扣费或漏扣风险。
tx, err := s.entClient.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return nil, fmt.Errorf("begin transaction: %w", err)
}
txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txCtx = dbent.NewTxContext(ctx, tx)
}
// 验证用户存在 // 验证用户存在
_, err := s.userRepo.GetByID(ctx, req.UserID) _, err = s.userRepo.GetByID(txCtx, req.UserID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
...@@ -75,7 +91,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* ...@@ -75,7 +91,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
// 创建使用日志 // 创建使用日志
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: req.UserID, UserID: req.UserID,
ApiKeyID: req.ApiKeyID, APIKeyID: req.APIKeyID,
AccountID: req.AccountID, AccountID: req.AccountID,
RequestID: req.RequestID, RequestID: req.RequestID,
Model: req.Model, Model: req.Model,
...@@ -96,17 +112,24 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* ...@@ -96,17 +112,24 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
DurationMs: req.DurationMs, DurationMs: req.DurationMs,
} }
if err := s.usageRepo.Create(ctx, usageLog); err != nil { inserted, err := s.usageRepo.Create(txCtx, usageLog)
if err != nil {
return nil, fmt.Errorf("create usage log: %w", err) return nil, fmt.Errorf("create usage log: %w", err)
} }
// 扣除用户余额 // 扣除用户余额
if req.ActualCost > 0 { if inserted && req.ActualCost > 0 {
if err := s.userRepo.UpdateBalance(ctx, req.UserID, -req.ActualCost); err != nil { if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil {
return nil, fmt.Errorf("update user balance: %w", err) return nil, fmt.Errorf("update user balance: %w", err)
} }
} }
if tx != nil {
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit transaction: %w", err)
}
}
return usageLog, nil return usageLog, nil
} }
...@@ -128,9 +151,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi ...@@ -128,9 +151,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
return logs, pagination, nil return logs, pagination, nil
} }
// ListByApiKey 获取API Key的使用日志列表 // ListByAPIKey 获取API Key的使用日志列表
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) { func (s *UsageService) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params) logs, pagination, err := s.usageRepo.ListByAPIKey(ctx, apiKeyID, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err) return nil, nil, fmt.Errorf("list usage logs: %w", err)
} }
...@@ -165,9 +188,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi ...@@ -165,9 +188,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi
}, nil }, nil
} }
// GetStatsByApiKey 获取API Key的使用统计 // GetStatsByAPIKey 获取API Key的使用统计
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) { func (s *UsageService) GetStatsByAPIKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
stats, err := s.usageRepo.GetApiKeyStatsAggregated(ctx, apiKeyID, startTime, endTime) stats, err := s.usageRepo.GetAPIKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key stats: %w", err) return nil, fmt.Errorf("get api key stats: %w", err)
} }
...@@ -270,9 +293,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star ...@@ -270,9 +293,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
return stats, nil return stats, nil
} }
// GetBatchApiKeyUsageStats returns today/total actual_cost for given api keys. // GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) { func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs) stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err) return nil, fmt.Errorf("get batch api key usage stats: %w", err)
} }
......
...@@ -21,7 +21,7 @@ type User struct { ...@@ -21,7 +21,7 @@ type User struct {
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
ApiKeys []ApiKey APIKeys []APIKey
Subscriptions []UserSubscription Subscriptions []UserSubscription
} }
......
...@@ -56,6 +56,10 @@ func (s *UserAttributeService) CreateDefinition(ctx context.Context, input Creat ...@@ -56,6 +56,10 @@ func (s *UserAttributeService) CreateDefinition(ctx context.Context, input Creat
Enabled: input.Enabled, Enabled: input.Enabled,
} }
if err := validateDefinitionPattern(def); err != nil {
return nil, err
}
if err := s.defRepo.Create(ctx, def); err != nil { if err := s.defRepo.Create(ctx, def); err != nil {
return nil, fmt.Errorf("create definition: %w", err) return nil, fmt.Errorf("create definition: %w", err)
} }
...@@ -108,6 +112,10 @@ func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, i ...@@ -108,6 +112,10 @@ func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, i
def.Enabled = *input.Enabled def.Enabled = *input.Enabled
} }
if err := validateDefinitionPattern(def); err != nil {
return nil, err
}
if err := s.defRepo.Update(ctx, def); err != nil { if err := s.defRepo.Update(ctx, def); err != nil {
return nil, fmt.Errorf("update definition: %w", err) return nil, fmt.Errorf("update definition: %w", err)
} }
...@@ -231,7 +239,10 @@ func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value ...@@ -231,7 +239,10 @@ func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value
// Pattern validation // Pattern validation
if v.Pattern != nil && *v.Pattern != "" && value != "" { if v.Pattern != nil && *v.Pattern != "" && value != "" {
re, err := regexp.Compile(*v.Pattern) re, err := regexp.Compile(*v.Pattern)
if err == nil && !re.MatchString(value) { if err != nil {
return validationError(def.Name + " has an invalid pattern")
}
if !re.MatchString(value) {
msg := def.Name + " format is invalid" msg := def.Name + " format is invalid"
if v.Message != nil && *v.Message != "" { if v.Message != nil && *v.Message != "" {
msg = *v.Message msg = *v.Message
...@@ -293,3 +304,20 @@ func isValidAttributeType(t UserAttributeType) bool { ...@@ -293,3 +304,20 @@ func isValidAttributeType(t UserAttributeType) bool {
} }
return false return false
} }
// validateDefinitionPattern checks that the regular expression configured on an
// attribute definition compiles.
//
// It returns nil when def is nil, when no pattern is set, or when the pattern is
// empty after trimming surrounding whitespace. Otherwise the trimmed pattern is
// compiled and, on failure, a bad-request error with reason
// INVALID_ATTRIBUTE_PATTERN is returned that names the definition and includes
// the compile error.
//
// NOTE(review): the pattern is trimmed here before compiling, whereas
// validateValue appears to compile the stored pattern untrimmed — confirm that
// the two cannot disagree for patterns with leading/trailing whitespace.
func validateDefinitionPattern(def *UserAttributeDefinition) error {
	if def == nil || def.Validation.Pattern == nil {
		return nil
	}

	expr := strings.TrimSpace(*def.Validation.Pattern)
	if expr == "" {
		// A blank pattern means "no pattern validation"; nothing to compile.
		return nil
	}

	_, compileErr := regexp.Compile(expr)
	if compileErr == nil {
		return nil
	}

	detail := fmt.Sprintf("invalid pattern for %s: %v", def.Name, compileErr)
	return infraerrors.BadRequest("INVALID_ATTRIBUTE_PATTERN", detail)
}
...@@ -75,7 +75,7 @@ var ProviderSet = wire.NewSet( ...@@ -75,7 +75,7 @@ var ProviderSet = wire.NewSet(
// Core services // Core services
NewAuthService, NewAuthService,
NewUserService, NewUserService,
NewApiKeyService, NewAPIKeyService,
NewGroupService, NewGroupService,
NewAccountService, NewAccountService,
NewProxyService, NewProxyService,
......
// Package setup provides CLI commands and application initialization helpers.
package setup package setup
import ( import (
......
...@@ -345,7 +345,7 @@ func writeConfigFile(cfg *SetupConfig) error { ...@@ -345,7 +345,7 @@ func writeConfigFile(cfg *SetupConfig) error {
Default struct { Default struct {
UserConcurrency int `yaml:"user_concurrency"` UserConcurrency int `yaml:"user_concurrency"`
UserBalance float64 `yaml:"user_balance"` UserBalance float64 `yaml:"user_balance"`
ApiKeyPrefix string `yaml:"api_key_prefix"` APIKeyPrefix string `yaml:"api_key_prefix"`
RateMultiplier float64 `yaml:"rate_multiplier"` RateMultiplier float64 `yaml:"rate_multiplier"`
} `yaml:"default"` } `yaml:"default"`
RateLimit struct { RateLimit struct {
...@@ -367,12 +367,12 @@ func writeConfigFile(cfg *SetupConfig) error { ...@@ -367,12 +367,12 @@ func writeConfigFile(cfg *SetupConfig) error {
Default: struct { Default: struct {
UserConcurrency int `yaml:"user_concurrency"` UserConcurrency int `yaml:"user_concurrency"`
UserBalance float64 `yaml:"user_balance"` UserBalance float64 `yaml:"user_balance"`
ApiKeyPrefix string `yaml:"api_key_prefix"` APIKeyPrefix string `yaml:"api_key_prefix"`
RateMultiplier float64 `yaml:"rate_multiplier"` RateMultiplier float64 `yaml:"rate_multiplier"`
}{ }{
UserConcurrency: 5, UserConcurrency: 5,
UserBalance: 0, UserBalance: 0,
ApiKeyPrefix: "sk-", APIKeyPrefix: "sk-",
RateMultiplier: 1.0, RateMultiplier: 1.0,
}, },
RateLimit: struct { RateLimit: struct {
......
//go:build !embed //go:build !embed
// Package web provides embedded web assets for the application.
package web package web
import ( import (
......
-- 020_add_temp_unschedulable.sql
-- Adds the columns that back the "temporarily unschedulable" feature for accounts.
-- Every statement uses IF NOT EXISTS, so re-running the script does not fail on
-- objects that already exist.
-- Timestamp at which the temporarily-unschedulable state is lifted.
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS temp_unschedulable_until timestamptz;
-- Reason the account was marked temporarily unschedulable (for troubleshooting and auditing).
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS temp_unschedulable_reason text;
-- Partial index (rows with deleted_at IS NULL only) to speed up scheduling queries.
CREATE INDEX IF NOT EXISTS idx_accounts_temp_unschedulable_until ON accounts(temp_unschedulable_until) WHERE deleted_at IS NULL;
-- Column comments describing each field's purpose. The comment text itself is
-- stored in the database and is intentionally left in its original language:
--   temp_unschedulable_until:  time the temporarily-unschedulable state ends; set when a
--                              temporary-unschedulable rule fires (based on error code or
--                              error-description keywords).
--   temp_unschedulable_reason: the specific reason that triggered the state (for
--                              troubleshooting and auditing).
COMMENT ON COLUMN accounts.temp_unschedulable_until IS '临时不可调度状态解除时间,当触发临时不可调度规则时设置(基于错误码或错误描述关键词)';
COMMENT ON COLUMN accounts.temp_unschedulable_reason IS '临时不可调度原因,记录触发临时不可调度的具体原因(用于排障和审计)';
-- Ops monitoring: pre-aggregation tables for dashboard queries
--
-- Problem:
-- The ops dashboard currently runs percentile_cont + GROUP BY queries over large raw tables
-- (usage_logs, ops_error_logs). These will get slower as data grows.
--
-- This migration adds schema-only aggregation tables that can be populated by a future background job.
-- No triggers/functions/jobs are created here (schema only).
-- ============================================
-- Hourly aggregates (per provider/platform)
-- ============================================
CREATE TABLE IF NOT EXISTS ops_metrics_hourly (
-- Start of the hour bucket (recommended: UTC).
bucket_start TIMESTAMPTZ NOT NULL,
-- Provider/platform label (e.g. anthropic/openai/gemini). Mirrors ops_* queries that GROUP BY platform.
platform VARCHAR(50) NOT NULL,
-- Traffic counts (use these to compute rates reliably across ranges).
request_count BIGINT NOT NULL DEFAULT 0,
success_count BIGINT NOT NULL DEFAULT 0,
error_count BIGINT NOT NULL DEFAULT 0,
-- Error breakdown used by provider health UI.
error_4xx_count BIGINT NOT NULL DEFAULT 0,
error_5xx_count BIGINT NOT NULL DEFAULT 0,
timeout_count BIGINT NOT NULL DEFAULT 0,
-- Latency aggregates (ms).
avg_latency_ms DOUBLE PRECISION,
p99_latency_ms DOUBLE PRECISION,
-- Convenience rate (percentage, 0-100). Still keep counts as source of truth.
error_rate DOUBLE PRECISION NOT NULL DEFAULT 0,
-- When this row was last (re)computed by the background job.
computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (bucket_start, platform)
);
CREATE INDEX IF NOT EXISTS idx_ops_metrics_hourly_platform_bucket_start
ON ops_metrics_hourly (platform, bucket_start DESC);
COMMENT ON TABLE ops_metrics_hourly IS 'Pre-aggregated hourly ops metrics by provider/platform to speed up dashboard queries.';
COMMENT ON COLUMN ops_metrics_hourly.bucket_start IS 'Start timestamp of the hour bucket (recommended UTC).';
COMMENT ON COLUMN ops_metrics_hourly.platform IS 'Provider/platform label (anthropic/openai/gemini, etc).';
COMMENT ON COLUMN ops_metrics_hourly.error_rate IS 'Error rate percentage for the bucket (0-100). Counts remain the source of truth.';
COMMENT ON COLUMN ops_metrics_hourly.computed_at IS 'When the row was last computed/refreshed.';
-- ============================================
-- Daily aggregates (per provider/platform)
-- ============================================
-- Daily counterpart of ops_metrics_hourly: one row per (bucket_date, platform).
CREATE TABLE IF NOT EXISTS ops_metrics_daily (
-- Day bucket (recommended: UTC date).
bucket_date DATE NOT NULL,
-- Provider/platform label; same column shape as ops_metrics_hourly.platform.
platform VARCHAR(50) NOT NULL,
-- Traffic counts; all default to 0.
request_count BIGINT NOT NULL DEFAULT 0,
success_count BIGINT NOT NULL DEFAULT 0,
error_count BIGINT NOT NULL DEFAULT 0,
-- Error breakdown counters.
error_4xx_count BIGINT NOT NULL DEFAULT 0,
error_5xx_count BIGINT NOT NULL DEFAULT 0,
timeout_count BIGINT NOT NULL DEFAULT 0,
-- Latency aggregates (ms); nullable, no default.
avg_latency_ms DOUBLE PRECISION,
p99_latency_ms DOUBLE PRECISION,
-- Convenience rate. NOTE(review): presumably the same 0-100 percentage scale as
-- ops_metrics_hourly.error_rate — confirm when the populating job is written.
error_rate DOUBLE PRECISION NOT NULL DEFAULT 0,
-- Defaults to the row's insert time; presumably refreshed by the populating job on recompute.
computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (bucket_date, platform)
);
CREATE INDEX IF NOT EXISTS idx_ops_metrics_daily_platform_bucket_date
ON ops_metrics_daily (platform, bucket_date DESC);
COMMENT ON TABLE ops_metrics_daily IS 'Pre-aggregated daily ops metrics by provider/platform for longer-term trends.';
COMMENT ON COLUMN ops_metrics_daily.bucket_date IS 'UTC date of the day bucket (recommended).';
-- ============================================
-- Population strategy (future background job)
-- ============================================
--
-- Suggested approach:
-- 1) Compute hourly buckets from raw logs using UTC time-bucketing, then UPSERT into ops_metrics_hourly.
-- 2) Compute daily buckets either directly from raw logs or by rolling up ops_metrics_hourly.
--
-- Notes:
-- - Ensure the job uses a consistent timezone (recommended: SET TIME ZONE 'UTC') to avoid bucket drift.
-- - Derive the provider/platform similarly to existing dashboard queries:
-- usage_logs: COALESCE(NULLIF(groups.platform, ''), accounts.platform, '')
-- ops_error_logs: COALESCE(NULLIF(ops_error_logs.platform, ''), groups.platform, accounts.platform, '')
-- - Keep request_count/success_count/error_count as the authoritative values; compute error_rate from counts.
--
-- Example (hourly) shape (pseudo-SQL):
-- INSERT INTO ops_metrics_hourly (...)
-- SELECT date_trunc('hour', created_at) AS bucket_start, platform, ...
-- FROM (/* aggregate usage_logs + ops_error_logs */) s
-- ON CONFLICT (bucket_start, platform) DO UPDATE SET ...;
-- 027_usage_billing_consistency.sql
-- Ensure usage_logs idempotency (request_id, api_key_id) and add reconciliation infrastructure.
-- -----------------------------------------------------------------------------
-- 1) Normalize legacy request_id values
-- -----------------------------------------------------------------------------
-- Historically request_id may be inserted as empty string. Convert it to NULL so
-- the upcoming unique index does not break on repeated "" values.
UPDATE usage_logs
SET request_id = NULL
WHERE request_id = '';
-- If duplicates already exist for the same (request_id, api_key_id), keep the
-- first row and NULL-out request_id for the rest so the unique index can be
-- created without deleting historical logs.
WITH ranked AS (
SELECT
id,
ROW_NUMBER() OVER (PARTITION BY api_key_id, request_id ORDER BY id) AS rn
FROM usage_logs
WHERE request_id IS NOT NULL
)
UPDATE usage_logs ul
SET request_id = NULL
FROM ranked r
WHERE ul.id = r.id
AND r.rn > 1;
-- -----------------------------------------------------------------------------
-- 2) Idempotency constraint for usage_logs
-- -----------------------------------------------------------------------------
-- Allows at most one usage_logs row per (request_id, api_key_id) pair.
-- PostgreSQL unique indexes treat NULLs as distinct by default, so rows whose
-- request_id is NULL (including the legacy rows normalized in step 1) never
-- conflict with each other and do not block index creation.
CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_logs_request_id_api_key_unique
ON usage_logs (request_id, api_key_id);
-- -----------------------------------------------------------------------------
-- 3) Reconciliation infrastructure: billing ledger for usage charges
-- -----------------------------------------------------------------------------
CREATE TABLE IF NOT EXISTS billing_usage_entries (
id BIGSERIAL PRIMARY KEY,
usage_log_id BIGINT NOT NULL REFERENCES usage_logs(id) ON DELETE CASCADE,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
api_key_id BIGINT NOT NULL REFERENCES api_keys(id) ON DELETE CASCADE,
subscription_id BIGINT REFERENCES user_subscriptions(id) ON DELETE SET NULL,
billing_type SMALLINT NOT NULL,
applied BOOLEAN NOT NULL DEFAULT TRUE,
delta_usd DECIMAL(20, 10) NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS billing_usage_entries_usage_log_id_unique
ON billing_usage_entries (usage_log_id);
CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_user_time
ON billing_usage_entries (user_id, created_at);
CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_created_at
ON billing_usage_entries (created_at);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment