Commit 7dddd065 authored by yangjianbo's avatar yangjianbo
Browse files
parents 25a0d49a e78c8646
package service
import "testing"
import (
"context"
"net/url"
"strings"
"testing"
func TestInferGoogleOneTier(t *testing.T) {
tests := []struct {
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
)
func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
t.Parallel()
type testCase struct {
name string
storageBytes int64
expectedTier string
}{
{"Negative storage", -1, TierGoogleOneUnknown},
{"Zero storage", 0, TierGoogleOneUnknown},
// Free tier boundary (15GB)
{"Below free tier", 10 * GB, TierGoogleOneUnknown},
{"Just below free tier", StorageTierFree - 1, TierGoogleOneUnknown},
{"Free tier (15GB)", StorageTierFree, TierFree},
// Basic tier boundary (100GB)
{"Between free and basic", 50 * GB, TierFree},
{"Just below basic tier", StorageTierBasic - 1, TierFree},
{"Basic tier (100GB)", StorageTierBasic, TierGoogleOneBasic},
// Standard tier boundary (200GB)
{"Between basic and standard", 150 * GB, TierGoogleOneBasic},
{"Just below standard tier", StorageTierStandard - 1, TierGoogleOneBasic},
{"Standard tier (200GB)", StorageTierStandard, TierGoogleOneStandard},
// AI Premium tier boundary (2TB)
{"Between standard and premium", 1 * TB, TierGoogleOneStandard},
{"Just below AI Premium tier", StorageTierAIPremium - 1, TierGoogleOneStandard},
{"AI Premium tier (2TB)", StorageTierAIPremium, TierAIPremium},
// Unlimited tier boundary (> 100TB)
{"Between premium and unlimited", 50 * TB, TierAIPremium},
{"At unlimited threshold (100TB)", StorageTierUnlimited, TierAIPremium},
{"Unlimited tier (100TB+)", StorageTierUnlimited + 1, TierGoogleOneUnlimited},
{"Unlimited tier (101TB+)", 101 * TB, TierGoogleOneUnlimited},
{"Very large storage", 1000 * TB, TierGoogleOneUnlimited},
cfg *config.Config
oauthType string
projectID string
wantClientID string
wantRedirect string
wantScope string
wantProjectID string
wantErrSubstr string
}
tests := []testCase{
{
name: "google_one uses built-in client when not configured and redirects to upstream",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{},
},
},
oauthType: "google_one",
wantClientID: geminicli.GeminiCLIOAuthClientID,
wantRedirect: geminicli.GeminiCLIRedirectURI,
wantScope: geminicli.DefaultCodeAssistScopes,
wantProjectID: "",
},
{
name: "google_one uses custom client when configured and redirects to localhost",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{
ClientID: "custom-client-id",
ClientSecret: "custom-client-secret",
},
},
},
oauthType: "google_one",
wantClientID: "custom-client-id",
wantRedirect: geminicli.AIStudioOAuthRedirectURI,
wantScope: geminicli.DefaultGoogleOneScopes,
wantProjectID: "",
},
{
name: "code_assist always forces built-in client even when custom client configured",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{
ClientID: "custom-client-id",
ClientSecret: "custom-client-secret",
},
},
},
oauthType: "code_assist",
projectID: "my-gcp-project",
wantClientID: geminicli.GeminiCLIOAuthClientID,
wantRedirect: geminicli.GeminiCLIRedirectURI,
wantScope: geminicli.DefaultCodeAssistScopes,
wantProjectID: "my-gcp-project",
},
{
name: "ai_studio requires custom client",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{},
},
},
oauthType: "ai_studio",
wantErrSubstr: "AI Studio OAuth requires a custom OAuth Client",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
result := inferGoogleOneTier(tt.storageBytes)
if result != tt.expectedTier {
t.Errorf("inferGoogleOneTier(%d) = %s, want %s",
tt.storageBytes, result, tt.expectedTier)
t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg)
got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "")
if tt.wantErrSubstr != "" {
if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr)
}
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
t.Fatalf("expected error containing %q, got: %v", tt.wantErrSubstr, err)
}
return
}
if err != nil {
t.Fatalf("GenerateAuthURL returned error: %v", err)
}
parsed, err := url.Parse(got.AuthURL)
if err != nil {
t.Fatalf("failed to parse auth_url: %v", err)
}
q := parsed.Query()
if gotState := q.Get("state"); gotState != got.State {
t.Fatalf("state mismatch: query=%q result=%q", gotState, got.State)
}
if gotClientID := q.Get("client_id"); gotClientID != tt.wantClientID {
t.Fatalf("client_id mismatch: got=%q want=%q", gotClientID, tt.wantClientID)
}
if gotRedirect := q.Get("redirect_uri"); gotRedirect != tt.wantRedirect {
t.Fatalf("redirect_uri mismatch: got=%q want=%q", gotRedirect, tt.wantRedirect)
}
if gotScope := q.Get("scope"); gotScope != tt.wantScope {
t.Fatalf("scope mismatch: got=%q want=%q", gotScope, tt.wantScope)
}
if gotProjectID := q.Get("project_id"); gotProjectID != tt.wantProjectID {
t.Fatalf("project_id mismatch: got=%q want=%q", gotProjectID, tt.wantProjectID)
}
})
}
......
......@@ -20,13 +20,24 @@ const (
geminiModelFlash geminiModelClass = "flash"
)
type GeminiDailyQuota struct {
ProRPD int64
FlashRPD int64
type GeminiQuota struct {
// SharedRPD is a shared requests-per-day pool across models.
// When SharedRPD > 0, callers should treat ProRPD/FlashRPD as not applicable for daily quota checks.
SharedRPD int64 `json:"shared_rpd,omitempty"`
// SharedRPM is a shared requests-per-minute pool across models.
// When SharedRPM > 0, callers should treat ProRPM/FlashRPM as not applicable for minute quota checks.
SharedRPM int64 `json:"shared_rpm,omitempty"`
// Per-model quotas (AI Studio / API key).
// A value of -1 means "unlimited" (pay-as-you-go).
ProRPD int64 `json:"pro_rpd,omitempty"`
ProRPM int64 `json:"pro_rpm,omitempty"`
FlashRPD int64 `json:"flash_rpd,omitempty"`
FlashRPM int64 `json:"flash_rpm,omitempty"`
}
type GeminiTierPolicy struct {
Quota GeminiDailyQuota
Quota GeminiQuota
Cooldown time.Duration
}
......@@ -45,10 +56,27 @@ type GeminiUsageTotals struct {
const geminiQuotaCacheTTL = time.Minute
type geminiQuotaOverrides struct {
type geminiQuotaOverridesV1 struct {
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
}
// geminiQuotaOverridesV2 is the decoding target for the "quota_rules" form of
// a quota policy document: a JSON object whose "quota_rules" member maps a
// tier identifier to the overrides for that tier.
type geminiQuotaOverridesV2 struct {
// QuotaRules is keyed by raw tier identifier; ApplyQuotaRulesOverrides
// normalizes each key before applying the rule.
QuotaRules map[string]geminiQuotaRuleOverride `json:"quota_rules"`
}
// geminiQuotaRuleOverride holds the optional overrides for a single tier in a
// "quota_rules" policy document. Every field is a pointer so that a member
// missing from the JSON (nil) can be told apart from an explicit value;
// ApplyQuotaRulesOverrides only touches the quota fields whose pointer is set.
type geminiQuotaRuleOverride struct {
// SharedRPD overrides the shared requests-per-day pool.
SharedRPD *int64 `json:"shared_rpd,omitempty"`
// SharedRPM overrides the shared requests-per-minute pool.
// NOTE(review): the JSON key is "rpm", whereas SharedRPD uses "shared_rpd"
// and GeminiQuota.SharedRPM uses "shared_rpm" — confirm the asymmetry is intended.
SharedRPM *int64 `json:"rpm,omitempty"`
// GeminiPro and GeminiFlash carry the per-model overrides.
GeminiPro *geminiModelQuotaOverride `json:"gemini_pro,omitempty"`
GeminiFlash *geminiModelQuotaOverride `json:"gemini_flash,omitempty"`
// Desc is decoded but not read by ApplyQuotaRulesOverrides.
Desc *string `json:"desc,omitempty"`
}
// geminiModelQuotaOverride carries the optional per-model limits nested inside
// a geminiQuotaRuleOverride (under "gemini_pro" / "gemini_flash"). A nil field
// leaves the corresponding limit untouched when the override is applied.
type geminiModelQuotaOverride struct {
// RPD is the requests-per-day override.
RPD *int64 `json:"rpd,omitempty"`
// RPM is the requests-per-minute override.
RPM *int64 `json:"rpm,omitempty"`
}
type GeminiQuotaService struct {
cfg *config.Config
settingRepo SettingRepository
......@@ -82,11 +110,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if s.cfg != nil {
policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers)
if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" {
var overrides geminiQuotaOverrides
if err := json.Unmarshal([]byte(s.cfg.Gemini.Quota.Policy), &overrides); err != nil {
raw := []byte(s.cfg.Gemini.Quota.Policy)
var overridesV2 geminiQuotaOverridesV2
if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
} else {
var overridesV1 geminiQuotaOverridesV1
if err := json.Unmarshal(raw, &overridesV1); err != nil {
log.Printf("gemini quota: parse config policy failed: %v", err)
} else {
policy.ApplyOverrides(overrides.Tiers)
policy.ApplyOverrides(overridesV1.Tiers)
}
}
}
}
......@@ -96,11 +130,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if err != nil && !errors.Is(err, ErrSettingNotFound) {
log.Printf("gemini quota: load setting failed: %v", err)
} else if strings.TrimSpace(value) != "" {
var overrides geminiQuotaOverrides
if err := json.Unmarshal([]byte(value), &overrides); err != nil {
raw := []byte(value)
var overridesV2 geminiQuotaOverridesV2
if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
} else {
var overridesV1 geminiQuotaOverridesV1
if err := json.Unmarshal(raw, &overridesV1); err != nil {
log.Printf("gemini quota: parse setting failed: %v", err)
} else {
policy.ApplyOverrides(overrides.Tiers)
policy.ApplyOverrides(overridesV1.Tiers)
}
}
}
}
......@@ -113,12 +153,20 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
return policy
}
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiDailyQuota, bool) {
if account == nil || !account.IsGeminiCodeAssist() {
return GeminiDailyQuota{}, false
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiQuota, bool) {
if account == nil || account.Platform != PlatformGemini {
return GeminiQuota{}, false
}
// Map (oauth_type + tier_id) to a canonical policy tier key.
// This keeps the policy table stable even if upstream tier_id strings vary.
tierKey := geminiQuotaTierKeyForAccount(account)
if tierKey == "" {
return GeminiQuota{}, false
}
policy := s.Policy(ctx)
return policy.QuotaForTier(account.GeminiTierID())
return policy.QuotaForTier(tierKey)
}
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
......@@ -126,12 +174,36 @@ func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string)
return policy.CooldownForTier(tierID)
}
// CooldownForAccount resolves the cooldown that applies to a Gemini account.
//
// The account is mapped to a policy tier key and the cooldown configured for
// that tier is returned. A nil receiver, a nil or non-Gemini account, or a
// blank tier key all yield a fixed 5-minute cooldown.
func (s *GeminiQuotaService) CooldownForAccount(ctx context.Context, account *Account) time.Duration {
	const fallback = 5 * time.Minute

	// Short-circuit order matters: account is only dereferenced once it is
	// known to be non-nil.
	usable := s != nil && account != nil && account.Platform == PlatformGemini
	if !usable {
		return fallback
	}

	key := geminiQuotaTierKeyForAccount(account)
	if strings.TrimSpace(key) != "" {
		return s.CooldownForTier(ctx, key)
	}
	return fallback
}
func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
return &GeminiQuotaPolicy{
tiers: map[string]GeminiTierPolicy{
"LEGACY": {Quota: GeminiDailyQuota{ProRPD: 50, FlashRPD: 1500}, Cooldown: 30 * time.Minute},
"PRO": {Quota: GeminiDailyQuota{ProRPD: 1500, FlashRPD: 4000}, Cooldown: 5 * time.Minute},
"ULTRA": {Quota: GeminiDailyQuota{ProRPD: 2000, FlashRPD: 0}, Cooldown: 5 * time.Minute},
// --- AI Studio / API Key (per-model) ---
// aistudio_free:
// - gemini_pro: 50 RPD / 2 RPM
// - gemini_flash: 1500 RPD / 15 RPM
GeminiTierAIStudioFree: {Quota: GeminiQuota{ProRPD: 50, ProRPM: 2, FlashRPD: 1500, FlashRPM: 15}, Cooldown: 30 * time.Minute},
// aistudio_paid: -1 means "unlimited/pay-as-you-go" for RPD.
GeminiTierAIStudioPaid: {Quota: GeminiQuota{ProRPD: -1, ProRPM: 1000, FlashRPD: -1, FlashRPM: 2000}, Cooldown: 5 * time.Minute},
// --- Google One (shared pool) ---
GeminiTierGoogleOneFree: {Quota: GeminiQuota{SharedRPD: 1000, SharedRPM: 60}, Cooldown: 30 * time.Minute},
GeminiTierGoogleAIPro: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
GeminiTierGoogleAIUltra: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
// --- GCP Code Assist (shared pool) ---
GeminiTierGCPStandard: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
GeminiTierGCPEnterprise: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
},
}
}
......@@ -149,11 +221,22 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
if !ok {
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
}
// Backward-compatible overrides:
// - If the tier uses shared quota, interpret pro_rpd as shared_rpd.
// - Otherwise apply per-model overrides.
if override.ProRPD != nil {
policy.Quota.ProRPD = clampGeminiQuotaInt64(*override.ProRPD)
if policy.Quota.SharedRPD > 0 {
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
} else {
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
}
}
if override.FlashRPD != nil {
policy.Quota.FlashRPD = clampGeminiQuotaInt64(*override.FlashRPD)
if policy.Quota.SharedRPD > 0 {
// No separate flash RPD for shared tiers.
} else {
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.FlashRPD)
}
}
if override.CooldownMinutes != nil {
minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
......@@ -163,10 +246,51 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
}
}
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiDailyQuota, bool) {
// ApplyQuotaRulesOverrides merges "quota_rules"-style overrides into the policy.
//
// Each map key is normalized to a tier ID; keys that normalize to "" are
// skipped. A tier that does not exist yet starts from an empty quota with a
// 5-minute cooldown. Only fields whose override pointer is non-nil are
// written: RPD values go through clampGeminiQuotaInt64WithUnlimited and RPM
// values through clampGeminiQuotaRPM. A nil receiver or an empty rule set is
// a no-op.
func (p *GeminiQuotaPolicy) ApplyQuotaRulesOverrides(rules map[string]geminiQuotaRuleOverride) {
	if p == nil || len(rules) == 0 {
		return
	}

	for key, rule := range rules {
		tierID := normalizeGeminiTierID(key)
		if tierID == "" {
			continue
		}

		entry, found := p.tiers[tierID]
		if !found {
			entry = GeminiTierPolicy{Cooldown: 5 * time.Minute}
		}
		quota := &entry.Quota

		// Shared pool overrides.
		if v := rule.SharedRPD; v != nil {
			quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*v)
		}
		if v := rule.SharedRPM; v != nil {
			quota.SharedRPM = clampGeminiQuotaRPM(*v)
		}

		// Per-model overrides.
		if pro := rule.GeminiPro; pro != nil {
			if pro.RPD != nil {
				quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*pro.RPD)
			}
			if pro.RPM != nil {
				quota.ProRPM = clampGeminiQuotaRPM(*pro.RPM)
			}
		}
		if flash := rule.GeminiFlash; flash != nil {
			if flash.RPD != nil {
				quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*flash.RPD)
			}
			if flash.RPM != nil {
				quota.FlashRPM = clampGeminiQuotaRPM(*flash.RPM)
			}
		}

		// entry is a copy of the map value, so it must be stored back.
		p.tiers[tierID] = entry
	}
}
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiQuota, bool) {
policy, ok := p.policyForTier(tierID)
if !ok {
return GeminiDailyQuota{}, false
return GeminiQuota{}, false
}
return policy.Quota, true
}
......@@ -184,22 +308,43 @@ func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool
return GeminiTierPolicy{}, false
}
normalized := normalizeGeminiTierID(tierID)
if normalized == "" {
normalized = "LEGACY"
}
if policy, ok := p.tiers[normalized]; ok {
return policy, true
}
policy, ok := p.tiers["LEGACY"]
return policy, ok
return GeminiTierPolicy{}, false
}
func normalizeGeminiTierID(tierID string) string {
return strings.ToUpper(strings.TrimSpace(tierID))
tierID = strings.TrimSpace(tierID)
if tierID == "" {
return ""
}
// Prefer canonical mapping (handles legacy tier strings).
if canonical := canonicalGeminiTierID(tierID); canonical != "" {
return canonical
}
// Accept older policy keys that used uppercase names.
switch strings.ToUpper(tierID) {
case "AISTUDIO_FREE":
return GeminiTierAIStudioFree
case "AISTUDIO_PAID":
return GeminiTierAIStudioPaid
case "GOOGLE_ONE_FREE":
return GeminiTierGoogleOneFree
case "GOOGLE_AI_PRO":
return GeminiTierGoogleAIPro
case "GOOGLE_AI_ULTRA":
return GeminiTierGoogleAIUltra
case "GCP_STANDARD":
return GeminiTierGCPStandard
case "GCP_ENTERPRISE":
return GeminiTierGCPEnterprise
}
return strings.ToLower(tierID)
}
func clampGeminiQuotaInt64(value int64) int64 {
if value < 0 {
func clampGeminiQuotaInt64WithUnlimited(value int64) int64 {
if value < -1 {
return 0
}
return value
......@@ -212,11 +357,46 @@ func clampGeminiQuotaInt(value int) int {
return value
}
// clampGeminiQuotaRPM sanitizes a configured requests-per-minute value:
// any negative input collapses to 0 and everything else passes through
// unchanged. Note that -1 is not preserved here.
func clampGeminiQuotaRPM(value int64) int64 {
	if value >= 0 {
		return value
	}
	return 0
}
// geminiCooldownForTier looks up the cooldown for tierID in a freshly built
// default policy, i.e. without any configured overrides applied.
func geminiCooldownForTier(tierID string) time.Duration {
	return newGeminiQuotaPolicy().CooldownForTier(tierID)
}
// geminiQuotaTierKeyForAccount maps an account to the policy tier key used for
// quota and cooldown lookups.
//
// It returns "" for a nil or non-Gemini account. Otherwise the canonical tier
// derived from the account's OAuth type and stored tier ID is used, unless
// that tier is empty or GeminiTierGoogleOneUnknown, in which case a default
// is chosen from the OAuth type alone.
func geminiQuotaTierKeyForAccount(account *Account) string {
	if account == nil {
		return ""
	}
	if account.Platform != PlatformGemini {
		return ""
	}

	// GeminiOAuthType() is noted as already defaulting legacy accounts
	// (project_id present) to code_assist, so no extra handling is done here.
	oauthType := strings.ToLower(strings.TrimSpace(account.GeminiOAuthType()))
	storedTier := strings.TrimSpace(account.GeminiTierID())

	// The canonical tier stored in credentials wins when it is usable.
	canonical := canonicalGeminiTierIDForOAuthType(oauthType, storedTier)
	if canonical != "" && canonical != GeminiTierGoogleOneUnknown {
		return canonical
	}

	// Tier is missing or unknown: fall back to a default per OAuth type.
	if oauthType == "google_one" {
		return GeminiTierGoogleOneFree
	}
	if oauthType == "code_assist" {
		return GeminiTierGCPStandard
	}
	// "ai_studio" and every other value land here. API key accounts
	// (type=apikey) have an empty oauth_type and are treated as AI Studio.
	return GeminiTierAIStudioFree
}
func geminiModelClassFromName(model string) geminiModelClass {
name := strings.ToLower(strings.TrimSpace(model))
if strings.Contains(name, "flash") || strings.Contains(name, "lite") {
......
......@@ -490,7 +490,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
return "", "", errors.New("access_token not found in credentials")
}
return accessToken, "oauth", nil
case AccountTypeApiKey:
case AccountTypeAPIKey:
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
......@@ -630,7 +630,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL
case AccountTypeApiKey:
case AccountTypeAPIKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
......@@ -710,7 +710,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
}
// Handle upstream error (mark account status)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// Return appropriate error response
var errType, errMsg string
......@@ -1065,7 +1071,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
ApiKey *ApiKey
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription
......@@ -1074,7 +1080,7 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result
apiKey := input.ApiKey
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
......@@ -1116,7 +1122,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
......@@ -1145,22 +1151,23 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.SubscriptionID = &subscription.ID
}
_ = s.usageLogRepo.Create(ctx, usageLog)
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
shouldBill := inserted || err != nil
// Deduct based on billing type
if isSubscriptionBilling {
if cost.TotalCost > 0 {
if shouldBill && cost.TotalCost > 0 {
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
} else {
if cost.ActualCost > 0 {
if shouldBill && cost.ActualCost > 0 {
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
......
......@@ -2,6 +2,7 @@ package service
import (
"context"
"encoding/json"
"log"
"net/http"
"strconv"
......@@ -18,6 +19,7 @@ type RateLimitService struct {
usageRepo UsageLogRepository
cfg *config.Config
geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
}
......@@ -31,12 +33,13 @@ type geminiUsageCacheEntry struct {
const geminiPrecheckCacheTTL = time.Minute
// NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService) *RateLimitService {
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService {
return &RateLimitService{
accountRepo: accountRepo,
usageRepo: usageRepo,
cfg: cfg,
geminiQuotaService: geminiQuotaService,
tempUnschedCache: tempUnschedCache,
usageCache: make(map[int64]*geminiUsageCacheEntry),
}
}
......@@ -51,38 +54,45 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return false
}
tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
switch statusCode {
case 401:
// 认证失败:停止调度,记录错误
s.handleAuthError(ctx, account, "Authentication failed (401): invalid or expired credentials")
return true
shouldDisable = true
case 402:
// 支付要求:余额不足或计费问题,停止调度
s.handleAuthError(ctx, account, "Payment required (402): insufficient balance or billing issue")
return true
shouldDisable = true
case 403:
// 禁止访问:停止调度,记录错误
s.handleAuthError(ctx, account, "Access forbidden (403): account may be suspended or lack permissions")
return true
shouldDisable = true
case 429:
s.handle429(ctx, account, headers)
return false
shouldDisable = false
case 529:
s.handle529(ctx, account)
return false
shouldDisable = false
default:
// 其他5xx错误:记录但不停止调度
if statusCode >= 500 {
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
}
return false
shouldDisable = false
}
if tempMatched {
return true
}
return shouldDisable
}
// PreCheckUsage proactively checks local quota before dispatching a request.
// Returns false when the account should be skipped.
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
if account == nil || !account.IsGeminiCodeAssist() || strings.TrimSpace(requestedModel) == "" {
if account == nil || account.Platform != PlatformGemini {
return true, nil
}
if s.usageRepo == nil || s.geminiQuotaService == nil {
......@@ -94,18 +104,24 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
return true, nil
}
now := time.Now()
modelClass := geminiModelClassFromName(requestedModel)
// 1) Daily quota precheck (RPD; resets at PST midnight)
{
var limit int64
switch geminiModelClassFromName(requestedModel) {
if quota.SharedRPD > 0 {
limit = quota.SharedRPD
} else {
switch modelClass {
case geminiModelFlash:
limit = quota.FlashRPD
default:
limit = quota.ProRPD
}
if limit <= 0 {
return true, nil
}
now := time.Now()
if limit > 0 {
start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok {
......@@ -118,21 +134,70 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
}
var used int64
switch geminiModelClassFromName(requestedModel) {
if quota.SharedRPD > 0 {
used = totals.ProRequests + totals.FlashRequests
} else {
switch modelClass {
case geminiModelFlash:
used = totals.FlashRequests
default:
used = totals.ProRequests
}
}
if used >= limit {
resetAt := geminiDailyResetTime(now)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
// NOTE:
// - This is a local precheck to reduce upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
return false, nil
}
}
}
// 2) Minute quota precheck (RPM; fixed window current minute)
{
var limit int64
if quota.SharedRPM > 0 {
limit = quota.SharedRPM
} else {
switch modelClass {
case geminiModelFlash:
limit = quota.FlashRPM
default:
limit = quota.ProRPM
}
}
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), rate limited until %v", account.ID, used, limit, resetAt)
if limit > 0 {
start := now.Truncate(time.Minute)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return true, err
}
totals := geminiAggregateUsage(stats)
var used int64
if quota.SharedRPM > 0 {
used = totals.ProRequests + totals.FlashRequests
} else {
switch modelClass {
case geminiModelFlash:
used = totals.FlashRequests
default:
used = totals.ProRequests
}
}
if used >= limit {
resetAt := start.Add(time.Minute)
// Do not persist "rate limited" status from local precheck. See note above.
log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
return false, nil
}
}
}
return true, nil
}
......@@ -176,7 +241,10 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
if account == nil {
return 5 * time.Minute
}
return s.geminiQuotaService.CooldownForTier(ctx, account.GeminiTierID())
if s.geminiQuotaService == nil {
return 5 * time.Minute
}
return s.geminiQuotaService.CooldownForAccount(ctx, account)
}
// handleAuthError 处理认证类错误(401/403),停止账号调度
......@@ -287,3 +355,183 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
// Pure delegation: the account repository does the clearing and its error
// is returned to the caller unchanged.
return s.accountRepo.ClearRateLimit(ctx, accountID)
}
// ClearTempUnschedulable lifts the temporary-unschedulable mark from an account.
//
// The repository is cleared first and any error from it is returned as-is.
// When a cache is configured, its entry is then deleted on a best-effort
// basis: a cache failure is logged and does not fail the call.
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
	err := s.accountRepo.ClearTempUnschedulable(ctx, accountID)
	if err != nil {
		return err
	}

	if s.tempUnschedCache == nil {
		return nil
	}
	if cacheErr := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); cacheErr != nil {
		log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, cacheErr)
	}
	return nil
}
// GetTempUnschedStatus reports the active temporary-unschedulable state of an
// account, or (nil, nil) when there is none.
//
// Lookup order:
//  1. The cache, if configured. A cache read error is returned to the caller;
//     an entry that has not expired yet is returned directly.
//  2. The account record. A missing or already elapsed "until" time means no
//     active state. Otherwise the stored reason is decoded as a JSON
//     TempUnschedState (its UntilUnix is filled from the record when zero);
//     if it is not valid JSON, the raw reason becomes the ErrorMessage.
//
// A state rebuilt from the account record is written back to the cache on a
// best-effort basis (failures are only logged).
func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID int64) (*TempUnschedState, error) {
	nowUnix := time.Now().Unix()

	if s.tempUnschedCache != nil {
		cached, err := s.tempUnschedCache.GetTempUnsched(ctx, accountID)
		if err != nil {
			return nil, err
		}
		if cached != nil && cached.UntilUnix > nowUnix {
			return cached, nil
		}
	}

	account, err := s.accountRepo.GetByID(ctx, accountID)
	if err != nil {
		return nil, err
	}
	until := account.TempUnschedulableUntil
	if until == nil || until.Unix() <= nowUnix {
		return nil, nil
	}
	untilUnix := until.Unix()

	state := &TempUnschedState{UntilUnix: untilUnix}
	if reason := account.TempUnschedulableReason; reason != "" {
		var decoded TempUnschedState
		if json.Unmarshal([]byte(reason), &decoded) != nil {
			// Not JSON: keep the text as a plain error message.
			state.ErrorMessage = reason
		} else {
			if decoded.UntilUnix == 0 {
				decoded.UntilUnix = untilUnix
			}
			state = &decoded
		}
	}

	if s.tempUnschedCache != nil {
		if cacheErr := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); cacheErr != nil {
			log.Printf("SetTempUnsched failed for account %d: %v", accountID, cacheErr)
		}
	}
	return state, nil
}
// HandleTempUnschedulable runs the temporary-unschedulable rule check for an
// upstream error, but only when the account is non-nil and
// account.ShouldHandleErrorCode accepts the status code. It reports whether
// the account was marked temporarily unschedulable.
func (s *RateLimitService) HandleTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool {
	return account != nil &&
		account.ShouldHandleErrorCode(statusCode) &&
		s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
}
// tempUnschedBodyMaxBytes caps how much of an upstream response body
// tryTempUnschedulable lowercases and scans for rule keywords (64 KiB).
const tempUnschedBodyMaxBytes = 64 << 10
// tempUnschedMessageMaxBytes caps the response-body excerpt that
// triggerTempUnschedulable stores as TempUnschedState.ErrorMessage.
const tempUnschedMessageMaxBytes = 2048
// tryTempUnschedulable checks an upstream error against the account's
// temporary-unschedulable rules and triggers the first rule that applies.
//
// Nothing happens (false is returned) when the account is nil, the feature is
// disabled for it, it has no rules, the status code is not positive, or the
// body is empty. A rule applies when its ErrorCode equals statusCode and one
// of its keywords occurs, case-insensitively, in the first
// tempUnschedBodyMaxBytes bytes of the body. If triggering a matching rule
// fails, later rules are still tried. The untruncated body is handed to
// triggerTempUnschedulable.
func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool {
	if account == nil || !account.IsTempUnschedulableEnabled() {
		return false
	}
	rules := account.GetTempUnschedulableRules()
	if len(rules) == 0 || statusCode <= 0 || len(responseBody) == 0 {
		return false
	}

	// Only a bounded prefix of the body is scanned for keywords.
	scanned := responseBody
	if len(scanned) > tempUnschedBodyMaxBytes {
		scanned = scanned[:tempUnschedBodyMaxBytes]
	}
	haystack := strings.ToLower(string(scanned))

	for i, rule := range rules {
		applicable := rule.ErrorCode == statusCode && len(rule.Keywords) > 0
		if !applicable {
			continue
		}
		hit := matchTempUnschedKeyword(haystack, rule.Keywords)
		if hit != "" && s.triggerTempUnschedulable(ctx, account, rule, i, statusCode, hit, responseBody) {
			return true
		}
	}
	return false
}
// matchTempUnschedKeyword returns the first keyword that occurs in bodyLower,
// or "" if none does.
//
// bodyLower must already be lowercased by the caller; each keyword is trimmed
// and lowercased before the substring test. Blank keywords are ignored. The
// returned value is the trimmed keyword in its original casing.
func matchTempUnschedKeyword(bodyLower string, keywords []string) string {
	if bodyLower == "" {
		return ""
	}
	for i := range keywords {
		candidate := strings.TrimSpace(keywords[i])
		if candidate != "" && strings.Contains(bodyLower, strings.ToLower(candidate)) {
			return candidate
		}
	}
	return ""
}
// triggerTempUnschedulable marks the account as temporarily unschedulable for
// rule.DurationMinutes starting now, and reports whether the mark was stored.
//
// It returns false without side effects when the account is nil or the rule's
// duration is not positive. The state (expiry, trigger time, status code,
// matched keyword, rule index, truncated body) is serialized to JSON and
// persisted as the reason; if serialization yields nothing, the trimmed error
// message is used instead. A repository failure is logged and returns false.
// The cache, when configured, is updated afterwards on a best-effort basis.
func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account *Account, rule TempUnschedulableRule, ruleIndex int, statusCode int, matchedKeyword string, responseBody []byte) bool {
	if account == nil || rule.DurationMinutes <= 0 {
		return false
	}

	triggeredAt := time.Now()
	until := triggeredAt.Add(time.Duration(rule.DurationMinutes) * time.Minute)

	state := &TempUnschedState{
		UntilUnix:       until.Unix(),
		TriggeredAtUnix: triggeredAt.Unix(),
		StatusCode:      statusCode,
		MatchedKeyword:  matchedKeyword,
		RuleIndex:       ruleIndex,
		ErrorMessage:    truncateTempUnschedMessage(responseBody, tempUnschedMessageMaxBytes),
	}

	var reason string
	if encoded, err := json.Marshal(state); err == nil {
		reason = string(encoded)
	}
	if reason == "" {
		reason = strings.TrimSpace(state.ErrorMessage)
	}

	// The repository is the source of truth: if it cannot be updated, the
	// cache is left alone and the trigger counts as failed.
	if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
		log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err)
		return false
	}

	if cache := s.tempUnschedCache; cache != nil {
		if err := cache.SetTempUnsched(ctx, account.ID, state); err != nil {
			log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err)
		}
	}

	log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode)
	return true
}
// truncateTempUnschedMessage converts an upstream response body into a short
// message of at most maxBytes bytes, with surrounding whitespace removed.
//
// It returns "" when maxBytes is not positive or body is empty. When the body
// has to be shortened, the cut is moved back to the nearest UTF-8 rune
// boundary so that a multi-byte character is never split in half. A split
// character would leave invalid UTF-8 in the message, which JSON encoding
// later turns into U+FFFD replacement characters.
func truncateTempUnschedMessage(body []byte, maxBytes int) string {
	if maxBytes <= 0 || len(body) == 0 {
		return ""
	}
	if len(body) > maxBytes {
		cut := maxBytes
		// body[cut] is the first dropped byte. If it is a UTF-8 continuation
		// byte (10xxxxxx), the cut lies inside a rune; step back to the rune's
		// lead byte. A rune has at most 3 continuation bytes, so 3 steps are
		// enough for valid input and bound the loss on malformed input.
		for steps := 0; steps < 3 && cut > 0 && body[cut]&0xC0 == 0x80; steps++ {
			cut--
		}
		body = body[:cut]
	}
	return strings.TrimSpace(string(body))
}
......@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeySiteName,
SettingKeySiteLogo,
SettingKeySiteSubtitle,
SettingKeyApiBaseUrl,
SettingKeyAPIBaseURL,
SettingKeyContactInfo,
SettingKeyDocUrl,
SettingKeyDocURL,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
......@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocUrl: settings[SettingKeyDocUrl],
DocURL: settings[SettingKeyDocURL],
}, nil
}
......@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[SettingKeySmtpHost] = settings.SmtpHost
updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
updates[SettingKeySmtpUsername] = settings.SmtpUsername
if settings.SmtpPassword != "" {
updates[SettingKeySmtpPassword] = settings.SmtpPassword
updates[SettingKeySMTPHost] = settings.SMTPHost
updates[SettingKeySMTPPort] = strconv.Itoa(settings.SMTPPort)
updates[SettingKeySMTPUsername] = settings.SMTPUsername
if settings.SMTPPassword != "" {
updates[SettingKeySMTPPassword] = settings.SMTPPassword
}
updates[SettingKeySmtpFrom] = settings.SmtpFrom
updates[SettingKeySmtpFromName] = settings.SmtpFromName
updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
updates[SettingKeySMTPFrom] = settings.SMTPFrom
updates[SettingKeySMTPFromName] = settings.SMTPFromName
updates[SettingKeySMTPUseTLS] = strconv.FormatBool(settings.SMTPUseTLS)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
......@@ -115,14 +115,21 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo
updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl
updates[SettingKeyAPIBaseURL] = settings.APIBaseURL
updates[SettingKeyContactInfo] = settings.ContactInfo
updates[SettingKeyDocUrl] = settings.DocUrl
updates[SettingKeyDocURL] = settings.DocURL
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
// Model fallback configuration
updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback)
updates[SettingKeyFallbackModelAnthropic] = settings.FallbackModelAnthropic
updates[SettingKeyFallbackModelOpenAI] = settings.FallbackModelOpenAI
updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini
updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity
return s.settingRepo.SetMultiple(ctx, updates)
}
......@@ -198,8 +205,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySmtpPort: "587",
SettingKeySmtpUseTLS: "false",
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
// Model fallback defaults
SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
SettingKeyFallbackModelOpenAI: "gpt-4o",
SettingKeyFallbackModelGemini: "gemini-2.5-pro",
SettingKeyFallbackModelAntigravity: "gemini-2.5-pro",
}
return s.settingRepo.SetMultiple(ctx, defaults)
......@@ -210,28 +223,28 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SmtpHost: settings[SettingKeySmtpHost],
SmtpUsername: settings[SettingKeySmtpUsername],
SmtpFrom: settings[SettingKeySmtpFrom],
SmtpFromName: settings[SettingKeySmtpFromName],
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
SmtpPasswordConfigured: settings[SettingKeySmtpPassword] != "",
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
SMTPFromName: settings[SettingKeySMTPFromName],
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocUrl: settings[SettingKeyDocUrl],
DocURL: settings[SettingKeyDocURL],
}
// 解析整数类型
if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil {
result.SmtpPort = port
if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil {
result.SMTPPort = port
} else {
result.SmtpPort = 587
result.SMTPPort = 587
}
if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
......@@ -247,6 +260,17 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.DefaultBalance = s.cfg.Default.UserBalance
}
// 敏感信息直接返回,方便测试连接时使用
result.SMTPPassword = settings[SettingKeySMTPPassword]
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
// Model fallback settings
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
result.FallbackModelOpenAI = s.getStringOrDefault(settings, SettingKeyFallbackModelOpenAI, "gpt-4o")
result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro")
result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro")
return result
}
......@@ -276,28 +300,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return value
}
// GenerateAdminApiKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) {
// GenerateAdminAPIKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
// 生成 32 字节随机数 = 64 位十六进制字符
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
key := AdminApiKeyPrefix + hex.EncodeToString(bytes)
key := AdminAPIKeyPrefix + hex.EncodeToString(bytes)
// 存储到 settings 表
if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil {
if err := s.settingRepo.Set(ctx, SettingKeyAdminAPIKey, key); err != nil {
return "", fmt.Errorf("save admin api key: %w", err)
}
return key, nil
}
// GetAdminApiKeyStatus 获取管理员 API Key 状态
// GetAdminAPIKeyStatus 获取管理员 API Key 状态
// 返回脱敏的 key、是否存在、错误
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", false, nil
......@@ -318,10 +342,10 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
return maskedKey, true, nil
}
// GetAdminApiKey 获取完整的管理员 API Key(仅供内部验证使用)
// GetAdminAPIKey 获取完整的管理员 API Key(仅供内部验证使用)
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", nil // 未配置,返回空字符串
......@@ -331,7 +355,45 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
return key, nil
}
// DeleteAdminApiKey 删除管理员 API Key
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey)
// DeleteAdminAPIKey 删除管理员 API Key
func (s *SettingService) DeleteAdminAPIKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, SettingKeyAdminAPIKey)
}
// IsModelFallbackEnabled reports whether the model fallback mechanism is
// switched on. It reads SettingKeyEnableModelFallback from the setting
// repository and returns true only when the stored value is exactly "true".
// Any lookup error is treated as "disabled".
func (s *SettingService) IsModelFallbackEnabled(ctx context.Context) bool {
	stored, lookupErr := s.settingRepo.GetValue(ctx, SettingKeyEnableModelFallback)
	// Default: disabled whenever the value cannot be read.
	return lookupErr == nil && stored == "true"
}
// GetFallbackModel returns the fallback model configured for the given
// platform.
//
// For an unrecognised platform it returns the empty string without touching
// the setting repository. For a recognised platform it returns the stored
// value, or the built-in default for that platform when the lookup fails or
// the stored value is empty.
func (s *SettingService) GetFallbackModel(ctx context.Context, platform string) string {
	var settingKey, builtin string

	switch platform {
	case PlatformAnthropic:
		settingKey, builtin = SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022"
	case PlatformOpenAI:
		settingKey, builtin = SettingKeyFallbackModelOpenAI, "gpt-4o"
	case PlatformGemini:
		settingKey, builtin = SettingKeyFallbackModelGemini, "gemini-2.5-pro"
	case PlatformAntigravity:
		settingKey, builtin = SettingKeyFallbackModelAntigravity, "gemini-2.5-pro"
	default:
		return ""
	}

	// Prefer the persisted setting; anything unusable falls through to the default.
	if stored, err := s.settingRepo.GetValue(ctx, settingKey); err == nil && stored != "" {
		return stored
	}
	return builtin
}
......@@ -4,14 +4,14 @@ type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
SmtpHost string
SmtpPort int
SmtpUsername string
SmtpPassword string
SmtpPasswordConfigured bool
SmtpFrom string
SmtpFromName string
SmtpUseTLS bool
SMTPHost string
SMTPPort int
SMTPUsername string
SMTPPassword string
SMTPPasswordConfigured bool
SMTPFrom string
SMTPFromName string
SMTPUseTLS bool
TurnstileEnabled bool
TurnstileSiteKey string
......@@ -21,12 +21,19 @@ type SystemSettings struct {
SiteName string
SiteLogo string
SiteSubtitle string
ApiBaseUrl string
APIBaseURL string
ContactInfo string
DocUrl string
DocURL string
DefaultConcurrency int
DefaultBalance float64
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
}
type PublicSettings struct {
......@@ -37,8 +44,8 @@ type PublicSettings struct {
SiteName string
SiteLogo string
SiteSubtitle string
ApiBaseUrl string
APIBaseURL string
ContactInfo string
DocUrl string
DocURL string
Version string
}
package service
import (
"context"
)
// TempUnschedState describes the temporary-unschedulable state recorded for an
// account. All fields are serialized to JSON under the snake_case names given
// in the struct tags.
type TempUnschedState struct {
UntilUnix int64 `json:"until_unix"` // Time at which the state is lifted (Unix timestamp).
TriggeredAtUnix int64 `json:"triggered_at_unix"` // Time at which the state was triggered (Unix timestamp).
StatusCode int `json:"status_code"` // Error code that triggered the state.
MatchedKeyword string `json:"matched_keyword"` // Keyword that was matched.
RuleIndex int `json:"rule_index"` // Index of the rule that was triggered.
ErrorMessage string `json:"error_message"` // Error message.
}
// TempUnschedCache is the cache interface for temporary-unschedulable state.
// Every operation is addressed by account ID and takes a context.
type TempUnschedCache interface {
SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error
GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error)
DeleteTempUnsched(ctx context.Context, accountID int64) error
}
......@@ -197,7 +197,7 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
{
name: "anthropic api-key - cannot refresh",
platform: PlatformAnthropic,
accType: AccountTypeApiKey,
accType: AccountTypeAPIKey,
want: false,
},
{
......
......@@ -79,7 +79,7 @@ type ReleaseInfo struct {
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HtmlURL string `json:"html_url"`
HTMLURL string `json:"html_url"`
Assets []Asset `json:"assets,omitempty"`
}
......@@ -96,13 +96,13 @@ type GitHubRelease struct {
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HtmlUrl string `json:"html_url"`
HTMLURL string `json:"html_url"`
Assets []GitHubAsset `json:"assets"`
}
type GitHubAsset struct {
Name string `json:"name"`
BrowserDownloadUrl string `json:"browser_download_url"`
BrowserDownloadURL string `json:"browser_download_url"`
Size int64 `json:"size"`
}
......@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
for i, a := range release.Assets {
assets[i] = Asset{
Name: a.Name,
DownloadURL: a.BrowserDownloadUrl,
DownloadURL: a.BrowserDownloadURL,
Size: a.Size,
}
}
......@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
Name: release.Name,
Body: release.Body,
PublishedAt: release.PublishedAt,
HtmlURL: release.HtmlUrl,
HTMLURL: release.HTMLURL,
Assets: assets,
},
Cached: false,
......
......@@ -10,7 +10,7 @@ const (
type UsageLog struct {
ID int64
UserID int64
ApiKeyID int64
APIKeyID int64
AccountID int64
RequestID string
Model string
......@@ -42,7 +42,7 @@ type UsageLog struct {
CreatedAt time.Time
User *User
ApiKey *ApiKey
APIKey *APIKey
Account *Account
Group *Group
Subscription *UserSubscription
......
......@@ -2,9 +2,11 @@ package service
import (
"context"
"errors"
"fmt"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
......@@ -17,7 +19,7 @@ var (
// CreateUsageLogRequest 创建使用日志请求
type CreateUsageLogRequest struct {
UserID int64 `json:"user_id"`
ApiKeyID int64 `json:"api_key_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
......@@ -54,20 +56,34 @@ type UsageStats struct {
type UsageService struct {
usageRepo UsageLogRepository
userRepo UserRepository
entClient *dbent.Client
}
// NewUsageService 创建使用统计服务实例
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository) *UsageService {
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService {
return &UsageService{
usageRepo: usageRepo,
userRepo: userRepo,
entClient: entClient,
}
}
// Create 创建使用日志
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*UsageLog, error) {
// 使用数据库事务保证「使用日志插入」与「扣费」的原子性,避免重复扣费或漏扣风险。
tx, err := s.entClient.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return nil, fmt.Errorf("begin transaction: %w", err)
}
txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txCtx = dbent.NewTxContext(ctx, tx)
}
// 验证用户存在
_, err := s.userRepo.GetByID(ctx, req.UserID)
_, err = s.userRepo.GetByID(txCtx, req.UserID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
......@@ -75,7 +91,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
// 创建使用日志
usageLog := &UsageLog{
UserID: req.UserID,
ApiKeyID: req.ApiKeyID,
APIKeyID: req.APIKeyID,
AccountID: req.AccountID,
RequestID: req.RequestID,
Model: req.Model,
......@@ -96,17 +112,24 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
DurationMs: req.DurationMs,
}
if err := s.usageRepo.Create(ctx, usageLog); err != nil {
inserted, err := s.usageRepo.Create(txCtx, usageLog)
if err != nil {
return nil, fmt.Errorf("create usage log: %w", err)
}
// 扣除用户余额
if req.ActualCost > 0 {
if err := s.userRepo.UpdateBalance(ctx, req.UserID, -req.ActualCost); err != nil {
if inserted && req.ActualCost > 0 {
if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
}
}
if tx != nil {
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit transaction: %w", err)
}
}
return usageLog, nil
}
......@@ -128,9 +151,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
return logs, pagination, nil
}
// ListByApiKey 获取API Key的使用日志列表
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
// ListByAPIKey 获取API Key的使用日志列表
func (s *UsageService) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByAPIKey(ctx, apiKeyID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
}
......@@ -165,9 +188,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi
}, nil
}
// GetStatsByApiKey 获取API Key的使用统计
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
stats, err := s.usageRepo.GetApiKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
// GetStatsByAPIKey 获取API Key的使用统计
func (s *UsageService) GetStatsByAPIKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
stats, err := s.usageRepo.GetAPIKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get api key stats: %w", err)
}
......@@ -270,9 +293,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
return stats, nil
}
// GetBatchApiKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}
......
......@@ -21,7 +21,7 @@ type User struct {
CreatedAt time.Time
UpdatedAt time.Time
ApiKeys []ApiKey
APIKeys []APIKey
Subscriptions []UserSubscription
}
......
......@@ -56,6 +56,10 @@ func (s *UserAttributeService) CreateDefinition(ctx context.Context, input Creat
Enabled: input.Enabled,
}
if err := validateDefinitionPattern(def); err != nil {
return nil, err
}
if err := s.defRepo.Create(ctx, def); err != nil {
return nil, fmt.Errorf("create definition: %w", err)
}
......@@ -108,6 +112,10 @@ func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, i
def.Enabled = *input.Enabled
}
if err := validateDefinitionPattern(def); err != nil {
return nil, err
}
if err := s.defRepo.Update(ctx, def); err != nil {
return nil, fmt.Errorf("update definition: %w", err)
}
......@@ -231,7 +239,10 @@ func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value
// Pattern validation
if v.Pattern != nil && *v.Pattern != "" && value != "" {
re, err := regexp.Compile(*v.Pattern)
if err == nil && !re.MatchString(value) {
if err != nil {
return validationError(def.Name + " has an invalid pattern")
}
if !re.MatchString(value) {
msg := def.Name + " format is invalid"
if v.Message != nil && *v.Message != "" {
msg = *v.Message
......@@ -293,3 +304,20 @@ func isValidAttributeType(t UserAttributeType) bool {
}
return false
}
// validateDefinitionPattern checks that the validation pattern attached to an
// attribute definition is a compilable regular expression.
//
// It returns nil when def is nil, when no pattern is set, or when the pattern
// is empty after trimming surrounding whitespace. Otherwise the trimmed
// pattern is compiled, and a BadRequest error with code
// INVALID_ATTRIBUTE_PATTERN is returned if compilation fails.
func validateDefinitionPattern(def *UserAttributeDefinition) error {
	if def == nil || def.Validation.Pattern == nil {
		return nil
	}

	expr := strings.TrimSpace(*def.Validation.Pattern)
	if expr == "" {
		return nil
	}

	_, compileErr := regexp.Compile(expr)
	if compileErr == nil {
		return nil
	}

	detail := fmt.Sprintf("invalid pattern for %s: %v", def.Name, compileErr)
	return infraerrors.BadRequest("INVALID_ATTRIBUTE_PATTERN", detail)
}
......@@ -75,7 +75,7 @@ var ProviderSet = wire.NewSet(
// Core services
NewAuthService,
NewUserService,
NewApiKeyService,
NewAPIKeyService,
NewGroupService,
NewAccountService,
NewProxyService,
......
// Package setup provides CLI commands and application initialization helpers.
package setup
import (
......
......@@ -352,7 +352,7 @@ func writeConfigFile(cfg *SetupConfig) error {
Default struct {
UserConcurrency int `yaml:"user_concurrency"`
UserBalance float64 `yaml:"user_balance"`
ApiKeyPrefix string `yaml:"api_key_prefix"`
APIKeyPrefix string `yaml:"api_key_prefix"`
RateMultiplier float64 `yaml:"rate_multiplier"`
} `yaml:"default"`
RateLimit struct {
......@@ -374,12 +374,12 @@ func writeConfigFile(cfg *SetupConfig) error {
Default: struct {
UserConcurrency int `yaml:"user_concurrency"`
UserBalance float64 `yaml:"user_balance"`
ApiKeyPrefix string `yaml:"api_key_prefix"`
APIKeyPrefix string `yaml:"api_key_prefix"`
RateMultiplier float64 `yaml:"rate_multiplier"`
}{
UserConcurrency: 5,
UserBalance: 0,
ApiKeyPrefix: "sk-",
APIKeyPrefix: "sk-",
RateMultiplier: 1.0,
},
RateLimit: struct {
......
//go:build !embed
// Package web provides embedded web assets for the application.
package web
import (
......
-- 020_add_temp_unschedulable.sql
-- Add the columns used by the temporary-unschedulable feature.
-- Every statement uses IF NOT EXISTS, so existing columns and indexes are left alone.
-- Add the column holding the time at which the temporary-unschedulable state is lifted.
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS temp_unschedulable_until timestamptz;
-- Add the column holding the temporary-unschedulable reason (for troubleshooting and auditing).
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS temp_unschedulable_reason text;
-- Add a partial index (rows with deleted_at IS NULL only) to speed up scheduling queries.
CREATE INDEX IF NOT EXISTS idx_accounts_temp_unschedulable_until ON accounts(temp_unschedulable_until) WHERE deleted_at IS NULL;
-- Attach column comments describing what each column is for.
COMMENT ON COLUMN accounts.temp_unschedulable_until IS '临时不可调度状态解除时间,当触发临时不可调度规则时设置(基于错误码或错误描述关键词)';
COMMENT ON COLUMN accounts.temp_unschedulable_reason IS '临时不可调度原因,记录触发临时不可调度的具体原因(用于排障和审计)';
-- Ops monitoring: pre-aggregation tables for dashboard queries
--
-- Problem:
-- The ops dashboard currently runs percentile_cont + GROUP BY queries over large raw tables
-- (usage_logs, ops_error_logs). These will get slower as data grows.
--
-- This migration adds schema-only aggregation tables that can be populated by a future background job.
-- No triggers/functions/jobs are created here (schema only).
-- ============================================
-- Hourly aggregates (per provider/platform)
-- ============================================
-- Hourly aggregate table: one row per (bucket_start, platform), enforced by the primary key.
CREATE TABLE IF NOT EXISTS ops_metrics_hourly (
-- Start of the hour bucket (recommended: UTC).
bucket_start TIMESTAMPTZ NOT NULL,
-- Provider/platform label (e.g. anthropic/openai/gemini). Mirrors ops_* queries that GROUP BY platform.
platform VARCHAR(50) NOT NULL,
-- Traffic counts (use these to compute rates reliably across ranges).
request_count BIGINT NOT NULL DEFAULT 0,
success_count BIGINT NOT NULL DEFAULT 0,
error_count BIGINT NOT NULL DEFAULT 0,
-- Error breakdown used by provider health UI.
error_4xx_count BIGINT NOT NULL DEFAULT 0,
error_5xx_count BIGINT NOT NULL DEFAULT 0,
timeout_count BIGINT NOT NULL DEFAULT 0,
-- Latency aggregates (ms). Both columns are nullable: no NOT NULL constraint or default is declared.
avg_latency_ms DOUBLE PRECISION,
p99_latency_ms DOUBLE PRECISION,
-- Convenience rate (percentage, 0-100). Still keep counts as source of truth.
error_rate DOUBLE PRECISION NOT NULL DEFAULT 0,
-- When this row was last (re)computed by the background job.
computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (bucket_start, platform)
);
CREATE INDEX IF NOT EXISTS idx_ops_metrics_hourly_platform_bucket_start
ON ops_metrics_hourly (platform, bucket_start DESC);
COMMENT ON TABLE ops_metrics_hourly IS 'Pre-aggregated hourly ops metrics by provider/platform to speed up dashboard queries.';
COMMENT ON COLUMN ops_metrics_hourly.bucket_start IS 'Start timestamp of the hour bucket (recommended UTC).';
COMMENT ON COLUMN ops_metrics_hourly.platform IS 'Provider/platform label (anthropic/openai/gemini, etc).';
COMMENT ON COLUMN ops_metrics_hourly.error_rate IS 'Error rate percentage for the bucket (0-100). Counts remain the source of truth.';
COMMENT ON COLUMN ops_metrics_hourly.computed_at IS 'When the row was last computed/refreshed.';
-- ============================================
-- Daily aggregates (per provider/platform)
-- ============================================
-- Daily aggregate table: one row per (bucket_date, platform), enforced by the primary key.
-- Apart from the DATE bucket column, the columns have the same names and types as ops_metrics_hourly.
CREATE TABLE IF NOT EXISTS ops_metrics_daily (
-- Day bucket (recommended: UTC date).
bucket_date DATE NOT NULL,
-- Provider/platform label.
platform VARCHAR(50) NOT NULL,
-- Traffic counts.
request_count BIGINT NOT NULL DEFAULT 0,
success_count BIGINT NOT NULL DEFAULT 0,
error_count BIGINT NOT NULL DEFAULT 0,
-- Error breakdown.
error_4xx_count BIGINT NOT NULL DEFAULT 0,
error_5xx_count BIGINT NOT NULL DEFAULT 0,
timeout_count BIGINT NOT NULL DEFAULT 0,
-- Latency aggregates (ms, per the column names). Both columns are nullable.
avg_latency_ms DOUBLE PRECISION,
p99_latency_ms DOUBLE PRECISION,
-- Error rate; NOTE(review): presumably a 0-100 percentage like ops_metrics_hourly.error_rate -- confirm.
error_rate DOUBLE PRECISION NOT NULL DEFAULT 0,
-- Defaults to the insertion time via NOW().
computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (bucket_date, platform)
);
CREATE INDEX IF NOT EXISTS idx_ops_metrics_daily_platform_bucket_date
ON ops_metrics_daily (platform, bucket_date DESC);
COMMENT ON TABLE ops_metrics_daily IS 'Pre-aggregated daily ops metrics by provider/platform for longer-term trends.';
COMMENT ON COLUMN ops_metrics_daily.bucket_date IS 'UTC date of the day bucket (recommended).';
-- ============================================
-- Population strategy (future background job)
-- ============================================
--
-- Suggested approach:
-- 1) Compute hourly buckets from raw logs using UTC time-bucketing, then UPSERT into ops_metrics_hourly.
-- 2) Compute daily buckets either directly from raw logs or by rolling up ops_metrics_hourly.
--
-- Notes:
-- - Ensure the job uses a consistent timezone (recommended: SET TIME ZONE 'UTC') to avoid bucket drift.
-- - Derive the provider/platform similarly to existing dashboard queries:
-- usage_logs: COALESCE(NULLIF(groups.platform, ''), accounts.platform, '')
-- ops_error_logs: COALESCE(NULLIF(ops_error_logs.platform, ''), groups.platform, accounts.platform, '')
-- - Keep request_count/success_count/error_count as the authoritative values; compute error_rate from counts.
--
-- Example (hourly) shape (pseudo-SQL):
-- INSERT INTO ops_metrics_hourly (...)
-- SELECT date_trunc('hour', created_at) AS bucket_start, platform, ...
-- FROM (/* aggregate usage_logs + ops_error_logs */) s
-- ON CONFLICT (bucket_start, platform) DO UPDATE SET ...;
-- 027_usage_billing_consistency.sql
-- Ensure usage_logs idempotency (request_id, api_key_id) and add reconciliation infrastructure.
-- -----------------------------------------------------------------------------
-- 1) Normalize legacy request_id values
-- -----------------------------------------------------------------------------
-- Historically request_id may be inserted as empty string. Convert it to NULL so
-- the upcoming unique index does not break on repeated "" values.
UPDATE usage_logs
SET request_id = NULL
WHERE request_id = '';
-- If duplicates already exist for the same (request_id, api_key_id), keep the
-- first row and NULL-out request_id for the rest so the unique index can be
-- created without deleting historical logs.
WITH ranked AS (
SELECT
id,
ROW_NUMBER() OVER (PARTITION BY api_key_id, request_id ORDER BY id) AS rn
FROM usage_logs
WHERE request_id IS NOT NULL
)
UPDATE usage_logs ul
SET request_id = NULL
FROM ranked r
WHERE ul.id = r.id
AND r.rn > 1;
-- -----------------------------------------------------------------------------
-- 2) Idempotency constraint for usage_logs
-- -----------------------------------------------------------------------------
CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_logs_request_id_api_key_unique
ON usage_logs (request_id, api_key_id);
-- -----------------------------------------------------------------------------
-- 3) Reconciliation infrastructure: billing ledger for usage charges
-- -----------------------------------------------------------------------------
-- Billing ledger for usage charges; each entry references exactly one usage_logs row.
CREATE TABLE IF NOT EXISTS billing_usage_entries (
id BIGSERIAL PRIMARY KEY,
-- Owning rows: deleting the referenced usage log, user, or API key deletes the entry (ON DELETE CASCADE).
usage_log_id BIGINT NOT NULL REFERENCES usage_logs(id) ON DELETE CASCADE,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
api_key_id BIGINT NOT NULL REFERENCES api_keys(id) ON DELETE CASCADE,
-- Optional link: set to NULL when the referenced subscription is deleted.
subscription_id BIGINT REFERENCES user_subscriptions(id) ON DELETE SET NULL,
-- NOTE(review): numeric code whose values are not defined in this migration -- confirm against application code.
billing_type SMALLINT NOT NULL,
-- NOTE(review): presumably marks whether the charge was actually applied -- confirm.
applied BOOLEAN NOT NULL DEFAULT TRUE,
-- Amount in USD (per the column name); NOTE(review): sign convention is not shown here -- confirm.
delta_usd DECIMAL(20, 10) NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS billing_usage_entries_usage_log_id_unique
ON billing_usage_entries (usage_log_id);
CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_user_time
ON billing_usage_entries (user_id, created_at);
CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_created_at
ON billing_usage_entries (created_at);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment