Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
7dddd065
Commit
7dddd065
authored
Jan 04, 2026
by
yangjianbo
Browse files
Merge branch 'main' of
https://github.com/mt21625457/aicodex2api
parents
25a0d49a
e78c8646
Changes
183
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/gemini_oauth_service_test.go
View file @
7dddd065
package
service
package
service
import
"testing"
import
(
"context"
func
TestInferGoogleOneTier
(
t
*
testing
.
T
)
{
"net/url"
tests
:=
[]
struct
{
"strings"
name
string
"testing"
storageBytes
int64
expectedTier
string
"github.com/Wei-Shaw/sub2api/internal/config"
}{
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
{
"Negative storage"
,
-
1
,
TierGoogleOneUnknown
},
)
{
"Zero storage"
,
0
,
TierGoogleOneUnknown
},
func
TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy
(
t
*
testing
.
T
)
{
// Free tier boundary (15GB)
t
.
Parallel
()
{
"Below free tier"
,
10
*
GB
,
TierGoogleOneUnknown
},
{
"Just below free tier"
,
StorageTierFree
-
1
,
TierGoogleOneUnknown
},
type
testCase
struct
{
{
"Free tier (15GB)"
,
StorageTierFree
,
TierFree
},
name
string
cfg
*
config
.
Config
// Basic tier boundary (100GB)
oauthType
string
{
"Between free and basic"
,
50
*
GB
,
TierFree
},
projectID
string
{
"Just below basic tier"
,
StorageTierBasic
-
1
,
TierFree
},
wantClientID
string
{
"Basic tier (100GB)"
,
StorageTierBasic
,
TierGoogleOneBasic
},
wantRedirect
string
wantScope
string
// Standard tier boundary (200GB)
wantProjectID
string
{
"Between basic and standard"
,
150
*
GB
,
TierGoogleOneBasic
},
wantErrSubstr
string
{
"Just below standard tier"
,
StorageTierStandard
-
1
,
TierGoogleOneBasic
},
}
{
"Standard tier (200GB)"
,
StorageTierStandard
,
TierGoogleOneStandard
},
tests
:=
[]
testCase
{
// AI Premium tier boundary (2TB)
{
{
"Between standard and premium"
,
1
*
TB
,
TierGoogleOneStandard
},
name
:
"google_one uses built-in client when not configured and redirects to upstream"
,
{
"Just below AI Premium tier"
,
StorageTierAIPremium
-
1
,
TierGoogleOneStandard
},
cfg
:
&
config
.
Config
{
{
"AI Premium tier (2TB)"
,
StorageTierAIPremium
,
TierAIPremium
},
Gemini
:
config
.
GeminiConfig
{
OAuth
:
config
.
GeminiOAuthConfig
{},
// Unlimited tier boundary (> 100TB)
},
{
"Between premium and unlimited"
,
50
*
TB
,
TierAIPremium
},
},
{
"At unlimited threshold (100TB)"
,
StorageTierUnlimited
,
TierAIPremium
},
oauthType
:
"google_one"
,
{
"Unlimited tier (100TB+)"
,
StorageTierUnlimited
+
1
,
TierGoogleOneUnlimited
},
wantClientID
:
geminicli
.
GeminiCLIOAuthClientID
,
{
"Unlimited tier (101TB+)"
,
101
*
TB
,
TierGoogleOneUnlimited
},
wantRedirect
:
geminicli
.
GeminiCLIRedirectURI
,
{
"Very large storage"
,
1000
*
TB
,
TierGoogleOneUnlimited
},
wantScope
:
geminicli
.
DefaultCodeAssistScopes
,
wantProjectID
:
""
,
},
{
name
:
"google_one uses custom client when configured and redirects to localhost"
,
cfg
:
&
config
.
Config
{
Gemini
:
config
.
GeminiConfig
{
OAuth
:
config
.
GeminiOAuthConfig
{
ClientID
:
"custom-client-id"
,
ClientSecret
:
"custom-client-secret"
,
},
},
},
oauthType
:
"google_one"
,
wantClientID
:
"custom-client-id"
,
wantRedirect
:
geminicli
.
AIStudioOAuthRedirectURI
,
wantScope
:
geminicli
.
DefaultGoogleOneScopes
,
wantProjectID
:
""
,
},
{
name
:
"code_assist always forces built-in client even when custom client configured"
,
cfg
:
&
config
.
Config
{
Gemini
:
config
.
GeminiConfig
{
OAuth
:
config
.
GeminiOAuthConfig
{
ClientID
:
"custom-client-id"
,
ClientSecret
:
"custom-client-secret"
,
},
},
},
oauthType
:
"code_assist"
,
projectID
:
"my-gcp-project"
,
wantClientID
:
geminicli
.
GeminiCLIOAuthClientID
,
wantRedirect
:
geminicli
.
GeminiCLIRedirectURI
,
wantScope
:
geminicli
.
DefaultCodeAssistScopes
,
wantProjectID
:
"my-gcp-project"
,
},
{
name
:
"ai_studio requires custom client"
,
cfg
:
&
config
.
Config
{
Gemini
:
config
.
GeminiConfig
{
OAuth
:
config
.
GeminiOAuthConfig
{},
},
},
oauthType
:
"ai_studio"
,
wantErrSubstr
:
"AI Studio OAuth requires a custom OAuth Client"
,
},
}
}
for
_
,
tt
:=
range
tests
{
for
_
,
tt
:=
range
tests
{
tt
:=
tt
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
inferGoogleOneTier
(
tt
.
storageBytes
)
t
.
Parallel
()
if
result
!=
tt
.
expectedTier
{
t
.
Errorf
(
"inferGoogleOneTier(%d) = %s, want %s"
,
svc
:=
NewGeminiOAuthService
(
nil
,
nil
,
nil
,
tt
.
cfg
)
tt
.
storageBytes
,
result
,
tt
.
expectedTier
)
got
,
err
:=
svc
.
GenerateAuthURL
(
context
.
Background
(),
nil
,
"https://example.com/auth/callback"
,
tt
.
projectID
,
tt
.
oauthType
,
""
)
if
tt
.
wantErrSubstr
!=
""
{
if
err
==
nil
{
t
.
Fatalf
(
"expected error containing %q, got nil"
,
tt
.
wantErrSubstr
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
tt
.
wantErrSubstr
)
{
t
.
Fatalf
(
"expected error containing %q, got: %v"
,
tt
.
wantErrSubstr
,
err
)
}
return
}
if
err
!=
nil
{
t
.
Fatalf
(
"GenerateAuthURL returned error: %v"
,
err
)
}
parsed
,
err
:=
url
.
Parse
(
got
.
AuthURL
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to parse auth_url: %v"
,
err
)
}
q
:=
parsed
.
Query
()
if
gotState
:=
q
.
Get
(
"state"
);
gotState
!=
got
.
State
{
t
.
Fatalf
(
"state mismatch: query=%q result=%q"
,
gotState
,
got
.
State
)
}
if
gotClientID
:=
q
.
Get
(
"client_id"
);
gotClientID
!=
tt
.
wantClientID
{
t
.
Fatalf
(
"client_id mismatch: got=%q want=%q"
,
gotClientID
,
tt
.
wantClientID
)
}
if
gotRedirect
:=
q
.
Get
(
"redirect_uri"
);
gotRedirect
!=
tt
.
wantRedirect
{
t
.
Fatalf
(
"redirect_uri mismatch: got=%q want=%q"
,
gotRedirect
,
tt
.
wantRedirect
)
}
if
gotScope
:=
q
.
Get
(
"scope"
);
gotScope
!=
tt
.
wantScope
{
t
.
Fatalf
(
"scope mismatch: got=%q want=%q"
,
gotScope
,
tt
.
wantScope
)
}
if
gotProjectID
:=
q
.
Get
(
"project_id"
);
gotProjectID
!=
tt
.
wantProjectID
{
t
.
Fatalf
(
"project_id mismatch: got=%q want=%q"
,
gotProjectID
,
tt
.
wantProjectID
)
}
}
})
})
}
}
...
...
backend/internal/service/gemini_quota.go
View file @
7dddd065
...
@@ -20,13 +20,24 @@ const (
...
@@ -20,13 +20,24 @@ const (
geminiModelFlash
geminiModelClass
=
"flash"
geminiModelFlash
geminiModelClass
=
"flash"
)
)
type
GeminiDailyQuota
struct
{
type
GeminiQuota
struct
{
ProRPD
int64
// SharedRPD is a shared requests-per-day pool across models.
FlashRPD
int64
// When SharedRPD > 0, callers should treat ProRPD/FlashRPD as not applicable for daily quota checks.
SharedRPD
int64
`json:"shared_rpd,omitempty"`
// SharedRPM is a shared requests-per-minute pool across models.
// When SharedRPM > 0, callers should treat ProRPM/FlashRPM as not applicable for minute quota checks.
SharedRPM
int64
`json:"shared_rpm,omitempty"`
// Per-model quotas (AI Studio / API key).
// A value of -1 means "unlimited" (pay-as-you-go).
ProRPD
int64
`json:"pro_rpd,omitempty"`
ProRPM
int64
`json:"pro_rpm,omitempty"`
FlashRPD
int64
`json:"flash_rpd,omitempty"`
FlashRPM
int64
`json:"flash_rpm,omitempty"`
}
}
type
GeminiTierPolicy
struct
{
type
GeminiTierPolicy
struct
{
Quota
Gemini
Daily
Quota
Quota
GeminiQuota
Cooldown
time
.
Duration
Cooldown
time
.
Duration
}
}
...
@@ -45,10 +56,27 @@ type GeminiUsageTotals struct {
...
@@ -45,10 +56,27 @@ type GeminiUsageTotals struct {
const
geminiQuotaCacheTTL
=
time
.
Minute
const
geminiQuotaCacheTTL
=
time
.
Minute
type
geminiQuotaOverrides
struct
{
type
geminiQuotaOverrides
V1
struct
{
Tiers
map
[
string
]
config
.
GeminiTierQuotaConfig
`json:"tiers"`
Tiers
map
[
string
]
config
.
GeminiTierQuotaConfig
`json:"tiers"`
}
}
type
geminiQuotaOverridesV2
struct
{
QuotaRules
map
[
string
]
geminiQuotaRuleOverride
`json:"quota_rules"`
}
type
geminiQuotaRuleOverride
struct
{
SharedRPD
*
int64
`json:"shared_rpd,omitempty"`
SharedRPM
*
int64
`json:"rpm,omitempty"`
GeminiPro
*
geminiModelQuotaOverride
`json:"gemini_pro,omitempty"`
GeminiFlash
*
geminiModelQuotaOverride
`json:"gemini_flash,omitempty"`
Desc
*
string
`json:"desc,omitempty"`
}
type
geminiModelQuotaOverride
struct
{
RPD
*
int64
`json:"rpd,omitempty"`
RPM
*
int64
`json:"rpm,omitempty"`
}
type
GeminiQuotaService
struct
{
type
GeminiQuotaService
struct
{
cfg
*
config
.
Config
cfg
*
config
.
Config
settingRepo
SettingRepository
settingRepo
SettingRepository
...
@@ -82,11 +110,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
...
@@ -82,11 +110,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if
s
.
cfg
!=
nil
{
if
s
.
cfg
!=
nil
{
policy
.
ApplyOverrides
(
s
.
cfg
.
Gemini
.
Quota
.
Tiers
)
policy
.
ApplyOverrides
(
s
.
cfg
.
Gemini
.
Quota
.
Tiers
)
if
strings
.
TrimSpace
(
s
.
cfg
.
Gemini
.
Quota
.
Policy
)
!=
""
{
if
strings
.
TrimSpace
(
s
.
cfg
.
Gemini
.
Quota
.
Policy
)
!=
""
{
var
overrides
geminiQuotaOverrides
raw
:=
[]
byte
(
s
.
cfg
.
Gemini
.
Quota
.
Policy
)
if
err
:=
json
.
Unmarshal
([]
byte
(
s
.
cfg
.
Gemini
.
Quota
.
Policy
),
&
overrides
);
err
!=
nil
{
var
overridesV2
geminiQuotaOverridesV2
log
.
Printf
(
"gemini quota: parse config policy failed: %v"
,
err
)
if
err
:=
json
.
Unmarshal
(
raw
,
&
overridesV2
);
err
==
nil
&&
len
(
overridesV2
.
QuotaRules
)
>
0
{
policy
.
ApplyQuotaRulesOverrides
(
overridesV2
.
QuotaRules
)
}
else
{
}
else
{
policy
.
ApplyOverrides
(
overrides
.
Tiers
)
var
overridesV1
geminiQuotaOverridesV1
if
err
:=
json
.
Unmarshal
(
raw
,
&
overridesV1
);
err
!=
nil
{
log
.
Printf
(
"gemini quota: parse config policy failed: %v"
,
err
)
}
else
{
policy
.
ApplyOverrides
(
overridesV1
.
Tiers
)
}
}
}
}
}
}
}
...
@@ -96,11 +130,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
...
@@ -96,11 +130,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
log
.
Printf
(
"gemini quota: load setting failed: %v"
,
err
)
log
.
Printf
(
"gemini quota: load setting failed: %v"
,
err
)
}
else
if
strings
.
TrimSpace
(
value
)
!=
""
{
}
else
if
strings
.
TrimSpace
(
value
)
!=
""
{
var
overrides
geminiQuotaOverrides
raw
:=
[]
byte
(
value
)
if
err
:=
json
.
Unmarshal
([]
byte
(
value
),
&
overrides
);
err
!=
nil
{
var
overridesV2
geminiQuotaOverridesV2
log
.
Printf
(
"gemini quota: parse setting failed: %v"
,
err
)
if
err
:=
json
.
Unmarshal
(
raw
,
&
overridesV2
);
err
==
nil
&&
len
(
overridesV2
.
QuotaRules
)
>
0
{
policy
.
ApplyQuotaRulesOverrides
(
overridesV2
.
QuotaRules
)
}
else
{
}
else
{
policy
.
ApplyOverrides
(
overrides
.
Tiers
)
var
overridesV1
geminiQuotaOverridesV1
if
err
:=
json
.
Unmarshal
(
raw
,
&
overridesV1
);
err
!=
nil
{
log
.
Printf
(
"gemini quota: parse setting failed: %v"
,
err
)
}
else
{
policy
.
ApplyOverrides
(
overridesV1
.
Tiers
)
}
}
}
}
}
}
}
...
@@ -113,12 +153,20 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
...
@@ -113,12 +153,20 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
return
policy
return
policy
}
}
func
(
s
*
GeminiQuotaService
)
QuotaForAccount
(
ctx
context
.
Context
,
account
*
Account
)
(
GeminiDailyQuota
,
bool
)
{
func
(
s
*
GeminiQuotaService
)
QuotaForAccount
(
ctx
context
.
Context
,
account
*
Account
)
(
GeminiQuota
,
bool
)
{
if
account
==
nil
||
!
account
.
IsGeminiCodeAssist
()
{
if
account
==
nil
||
account
.
Platform
!=
PlatformGemini
{
return
GeminiDailyQuota
{},
false
return
GeminiQuota
{},
false
}
// Map (oauth_type + tier_id) to a canonical policy tier key.
// This keeps the policy table stable even if upstream tier_id strings vary.
tierKey
:=
geminiQuotaTierKeyForAccount
(
account
)
if
tierKey
==
""
{
return
GeminiQuota
{},
false
}
}
policy
:=
s
.
Policy
(
ctx
)
policy
:=
s
.
Policy
(
ctx
)
return
policy
.
QuotaForTier
(
account
.
GeminiTierID
()
)
return
policy
.
QuotaForTier
(
tierKey
)
}
}
func
(
s
*
GeminiQuotaService
)
CooldownForTier
(
ctx
context
.
Context
,
tierID
string
)
time
.
Duration
{
func
(
s
*
GeminiQuotaService
)
CooldownForTier
(
ctx
context
.
Context
,
tierID
string
)
time
.
Duration
{
...
@@ -126,12 +174,36 @@ func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string)
...
@@ -126,12 +174,36 @@ func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string)
return
policy
.
CooldownForTier
(
tierID
)
return
policy
.
CooldownForTier
(
tierID
)
}
}
func
(
s
*
GeminiQuotaService
)
CooldownForAccount
(
ctx
context
.
Context
,
account
*
Account
)
time
.
Duration
{
if
s
==
nil
||
account
==
nil
||
account
.
Platform
!=
PlatformGemini
{
return
5
*
time
.
Minute
}
tierKey
:=
geminiQuotaTierKeyForAccount
(
account
)
if
strings
.
TrimSpace
(
tierKey
)
==
""
{
return
5
*
time
.
Minute
}
return
s
.
CooldownForTier
(
ctx
,
tierKey
)
}
func
newGeminiQuotaPolicy
()
*
GeminiQuotaPolicy
{
func
newGeminiQuotaPolicy
()
*
GeminiQuotaPolicy
{
return
&
GeminiQuotaPolicy
{
return
&
GeminiQuotaPolicy
{
tiers
:
map
[
string
]
GeminiTierPolicy
{
tiers
:
map
[
string
]
GeminiTierPolicy
{
"LEGACY"
:
{
Quota
:
GeminiDailyQuota
{
ProRPD
:
50
,
FlashRPD
:
1500
},
Cooldown
:
30
*
time
.
Minute
},
// --- AI Studio / API Key (per-model) ---
"PRO"
:
{
Quota
:
GeminiDailyQuota
{
ProRPD
:
1500
,
FlashRPD
:
4000
},
Cooldown
:
5
*
time
.
Minute
},
// aistudio_free:
"ULTRA"
:
{
Quota
:
GeminiDailyQuota
{
ProRPD
:
2000
,
FlashRPD
:
0
},
Cooldown
:
5
*
time
.
Minute
},
// - gemini_pro: 50 RPD / 2 RPM
// - gemini_flash: 1500 RPD / 15 RPM
GeminiTierAIStudioFree
:
{
Quota
:
GeminiQuota
{
ProRPD
:
50
,
ProRPM
:
2
,
FlashRPD
:
1500
,
FlashRPM
:
15
},
Cooldown
:
30
*
time
.
Minute
},
// aistudio_paid: -1 means "unlimited/pay-as-you-go" for RPD.
GeminiTierAIStudioPaid
:
{
Quota
:
GeminiQuota
{
ProRPD
:
-
1
,
ProRPM
:
1000
,
FlashRPD
:
-
1
,
FlashRPM
:
2000
},
Cooldown
:
5
*
time
.
Minute
},
// --- Google One (shared pool) ---
GeminiTierGoogleOneFree
:
{
Quota
:
GeminiQuota
{
SharedRPD
:
1000
,
SharedRPM
:
60
},
Cooldown
:
30
*
time
.
Minute
},
GeminiTierGoogleAIPro
:
{
Quota
:
GeminiQuota
{
SharedRPD
:
1500
,
SharedRPM
:
120
},
Cooldown
:
5
*
time
.
Minute
},
GeminiTierGoogleAIUltra
:
{
Quota
:
GeminiQuota
{
SharedRPD
:
2000
,
SharedRPM
:
120
},
Cooldown
:
5
*
time
.
Minute
},
// --- GCP Code Assist (shared pool) ---
GeminiTierGCPStandard
:
{
Quota
:
GeminiQuota
{
SharedRPD
:
1500
,
SharedRPM
:
120
},
Cooldown
:
5
*
time
.
Minute
},
GeminiTierGCPEnterprise
:
{
Quota
:
GeminiQuota
{
SharedRPD
:
2000
,
SharedRPM
:
120
},
Cooldown
:
5
*
time
.
Minute
},
},
},
}
}
}
}
...
@@ -149,11 +221,22 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
...
@@ -149,11 +221,22 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
if
!
ok
{
if
!
ok
{
policy
=
GeminiTierPolicy
{
Cooldown
:
5
*
time
.
Minute
}
policy
=
GeminiTierPolicy
{
Cooldown
:
5
*
time
.
Minute
}
}
}
// Backward-compatible overrides:
// - If the tier uses shared quota, interpret pro_rpd as shared_rpd.
// - Otherwise apply per-model overrides.
if
override
.
ProRPD
!=
nil
{
if
override
.
ProRPD
!=
nil
{
policy
.
Quota
.
ProRPD
=
clampGeminiQuotaInt64
(
*
override
.
ProRPD
)
if
policy
.
Quota
.
SharedRPD
>
0
{
policy
.
Quota
.
SharedRPD
=
clampGeminiQuotaInt64WithUnlimited
(
*
override
.
ProRPD
)
}
else
{
policy
.
Quota
.
ProRPD
=
clampGeminiQuotaInt64WithUnlimited
(
*
override
.
ProRPD
)
}
}
}
if
override
.
FlashRPD
!=
nil
{
if
override
.
FlashRPD
!=
nil
{
policy
.
Quota
.
FlashRPD
=
clampGeminiQuotaInt64
(
*
override
.
FlashRPD
)
if
policy
.
Quota
.
SharedRPD
>
0
{
// No separate flash RPD for shared tiers.
}
else
{
policy
.
Quota
.
FlashRPD
=
clampGeminiQuotaInt64WithUnlimited
(
*
override
.
FlashRPD
)
}
}
}
if
override
.
CooldownMinutes
!=
nil
{
if
override
.
CooldownMinutes
!=
nil
{
minutes
:=
clampGeminiQuotaInt
(
*
override
.
CooldownMinutes
)
minutes
:=
clampGeminiQuotaInt
(
*
override
.
CooldownMinutes
)
...
@@ -163,10 +246,51 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
...
@@ -163,10 +246,51 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
}
}
}
}
func
(
p
*
GeminiQuotaPolicy
)
QuotaForTier
(
tierID
string
)
(
GeminiDailyQuota
,
bool
)
{
func
(
p
*
GeminiQuotaPolicy
)
ApplyQuotaRulesOverrides
(
rules
map
[
string
]
geminiQuotaRuleOverride
)
{
if
p
==
nil
||
len
(
rules
)
==
0
{
return
}
for
rawID
,
override
:=
range
rules
{
tierID
:=
normalizeGeminiTierID
(
rawID
)
if
tierID
==
""
{
continue
}
policy
,
ok
:=
p
.
tiers
[
tierID
]
if
!
ok
{
policy
=
GeminiTierPolicy
{
Cooldown
:
5
*
time
.
Minute
}
}
if
override
.
SharedRPD
!=
nil
{
policy
.
Quota
.
SharedRPD
=
clampGeminiQuotaInt64WithUnlimited
(
*
override
.
SharedRPD
)
}
if
override
.
SharedRPM
!=
nil
{
policy
.
Quota
.
SharedRPM
=
clampGeminiQuotaRPM
(
*
override
.
SharedRPM
)
}
if
override
.
GeminiPro
!=
nil
{
if
override
.
GeminiPro
.
RPD
!=
nil
{
policy
.
Quota
.
ProRPD
=
clampGeminiQuotaInt64WithUnlimited
(
*
override
.
GeminiPro
.
RPD
)
}
if
override
.
GeminiPro
.
RPM
!=
nil
{
policy
.
Quota
.
ProRPM
=
clampGeminiQuotaRPM
(
*
override
.
GeminiPro
.
RPM
)
}
}
if
override
.
GeminiFlash
!=
nil
{
if
override
.
GeminiFlash
.
RPD
!=
nil
{
policy
.
Quota
.
FlashRPD
=
clampGeminiQuotaInt64WithUnlimited
(
*
override
.
GeminiFlash
.
RPD
)
}
if
override
.
GeminiFlash
.
RPM
!=
nil
{
policy
.
Quota
.
FlashRPM
=
clampGeminiQuotaRPM
(
*
override
.
GeminiFlash
.
RPM
)
}
}
p
.
tiers
[
tierID
]
=
policy
}
}
func
(
p
*
GeminiQuotaPolicy
)
QuotaForTier
(
tierID
string
)
(
GeminiQuota
,
bool
)
{
policy
,
ok
:=
p
.
policyForTier
(
tierID
)
policy
,
ok
:=
p
.
policyForTier
(
tierID
)
if
!
ok
{
if
!
ok
{
return
Gemini
Daily
Quota
{},
false
return
GeminiQuota
{},
false
}
}
return
policy
.
Quota
,
true
return
policy
.
Quota
,
true
}
}
...
@@ -184,22 +308,43 @@ func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool
...
@@ -184,22 +308,43 @@ func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool
return
GeminiTierPolicy
{},
false
return
GeminiTierPolicy
{},
false
}
}
normalized
:=
normalizeGeminiTierID
(
tierID
)
normalized
:=
normalizeGeminiTierID
(
tierID
)
if
normalized
==
""
{
normalized
=
"LEGACY"
}
if
policy
,
ok
:=
p
.
tiers
[
normalized
];
ok
{
if
policy
,
ok
:=
p
.
tiers
[
normalized
];
ok
{
return
policy
,
true
return
policy
,
true
}
}
policy
,
ok
:=
p
.
tiers
[
"LEGACY"
]
return
GeminiTierPolicy
{},
false
return
policy
,
ok
}
}
func
normalizeGeminiTierID
(
tierID
string
)
string
{
func
normalizeGeminiTierID
(
tierID
string
)
string
{
return
strings
.
ToUpper
(
strings
.
TrimSpace
(
tierID
))
tierID
=
strings
.
TrimSpace
(
tierID
)
if
tierID
==
""
{
return
""
}
// Prefer canonical mapping (handles legacy tier strings).
if
canonical
:=
canonicalGeminiTierID
(
tierID
);
canonical
!=
""
{
return
canonical
}
// Accept older policy keys that used uppercase names.
switch
strings
.
ToUpper
(
tierID
)
{
case
"AISTUDIO_FREE"
:
return
GeminiTierAIStudioFree
case
"AISTUDIO_PAID"
:
return
GeminiTierAIStudioPaid
case
"GOOGLE_ONE_FREE"
:
return
GeminiTierGoogleOneFree
case
"GOOGLE_AI_PRO"
:
return
GeminiTierGoogleAIPro
case
"GOOGLE_AI_ULTRA"
:
return
GeminiTierGoogleAIUltra
case
"GCP_STANDARD"
:
return
GeminiTierGCPStandard
case
"GCP_ENTERPRISE"
:
return
GeminiTierGCPEnterprise
}
return
strings
.
ToLower
(
tierID
)
}
}
func
clampGeminiQuotaInt64
(
value
int64
)
int64
{
func
clampGeminiQuotaInt64
WithUnlimited
(
value
int64
)
int64
{
if
value
<
0
{
if
value
<
-
1
{
return
0
return
0
}
}
return
value
return
value
...
@@ -212,11 +357,46 @@ func clampGeminiQuotaInt(value int) int {
...
@@ -212,11 +357,46 @@ func clampGeminiQuotaInt(value int) int {
return
value
return
value
}
}
func
clampGeminiQuotaRPM
(
value
int64
)
int64
{
if
value
<
0
{
return
0
}
return
value
}
func
geminiCooldownForTier
(
tierID
string
)
time
.
Duration
{
func
geminiCooldownForTier
(
tierID
string
)
time
.
Duration
{
policy
:=
newGeminiQuotaPolicy
()
policy
:=
newGeminiQuotaPolicy
()
return
policy
.
CooldownForTier
(
tierID
)
return
policy
.
CooldownForTier
(
tierID
)
}
}
func
geminiQuotaTierKeyForAccount
(
account
*
Account
)
string
{
if
account
==
nil
||
account
.
Platform
!=
PlatformGemini
{
return
""
}
// Note: GeminiOAuthType() already defaults legacy (project_id present) to code_assist.
oauthType
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
account
.
GeminiOAuthType
()))
rawTier
:=
strings
.
TrimSpace
(
account
.
GeminiTierID
())
// Prefer the canonical tier stored in credentials.
if
tierID
:=
canonicalGeminiTierIDForOAuthType
(
oauthType
,
rawTier
);
tierID
!=
""
&&
tierID
!=
GeminiTierGoogleOneUnknown
{
return
tierID
}
// Fallback defaults when tier_id is missing or unknown.
switch
oauthType
{
case
"google_one"
:
return
GeminiTierGoogleOneFree
case
"code_assist"
:
return
GeminiTierGCPStandard
case
"ai_studio"
:
return
GeminiTierAIStudioFree
default
:
// API Key accounts (type=apikey) have empty oauth_type and are treated as AI Studio.
return
GeminiTierAIStudioFree
}
}
func
geminiModelClassFromName
(
model
string
)
geminiModelClass
{
func
geminiModelClassFromName
(
model
string
)
geminiModelClass
{
name
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
model
))
name
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
model
))
if
strings
.
Contains
(
name
,
"flash"
)
||
strings
.
Contains
(
name
,
"lite"
)
{
if
strings
.
Contains
(
name
,
"flash"
)
||
strings
.
Contains
(
name
,
"lite"
)
{
...
...
backend/internal/service/openai_gateway_service.go
View file @
7dddd065
...
@@ -490,7 +490,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
...
@@ -490,7 +490,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
return
""
,
""
,
errors
.
New
(
"access_token not found in credentials"
)
return
""
,
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
}
return
accessToken
,
"oauth"
,
nil
return
accessToken
,
"oauth"
,
nil
case
AccountTypeA
pi
Key
:
case
AccountTypeA
PI
Key
:
apiKey
:=
account
.
GetOpenAIApiKey
()
apiKey
:=
account
.
GetOpenAIApiKey
()
if
apiKey
==
""
{
if
apiKey
==
""
{
return
""
,
""
,
errors
.
New
(
"api_key not found in credentials"
)
return
""
,
""
,
errors
.
New
(
"api_key not found in credentials"
)
...
@@ -630,7 +630,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
...
@@ -630,7 +630,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case
AccountTypeOAuth
:
case
AccountTypeOAuth
:
// OAuth accounts use ChatGPT internal API
// OAuth accounts use ChatGPT internal API
targetURL
=
chatgptCodexURL
targetURL
=
chatgptCodexURL
case
AccountTypeA
pi
Key
:
case
AccountTypeA
PI
Key
:
// API Key accounts use Platform API or custom base URL
// API Key accounts use Platform API or custom base URL
baseURL
:=
account
.
GetOpenAIBaseURL
()
baseURL
:=
account
.
GetOpenAIBaseURL
()
if
baseURL
==
""
{
if
baseURL
==
""
{
...
@@ -710,7 +710,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
...
@@ -710,7 +710,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
}
}
// Handle upstream error (mark account status)
// Handle upstream error (mark account status)
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
body
)
shouldDisable
:=
false
if
s
.
rateLimitService
!=
nil
{
shouldDisable
=
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
body
)
}
if
shouldDisable
{
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
// Return appropriate error response
// Return appropriate error response
var
errType
,
errMsg
string
var
errType
,
errMsg
string
...
@@ -1065,7 +1071,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
...
@@ -1065,7 +1071,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
// OpenAIRecordUsageInput input for recording usage
type
OpenAIRecordUsageInput
struct
{
type
OpenAIRecordUsageInput
struct
{
Result
*
OpenAIForwardResult
Result
*
OpenAIForwardResult
A
pi
Key
*
A
pi
Key
A
PI
Key
*
A
PI
Key
User
*
User
User
*
User
Account
*
Account
Account
*
Account
Subscription
*
UserSubscription
Subscription
*
UserSubscription
...
@@ -1074,7 +1080,7 @@ type OpenAIRecordUsageInput struct {
...
@@ -1074,7 +1080,7 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
// RecordUsage records usage and deducts balance
func
(
s
*
OpenAIGatewayService
)
RecordUsage
(
ctx
context
.
Context
,
input
*
OpenAIRecordUsageInput
)
error
{
func
(
s
*
OpenAIGatewayService
)
RecordUsage
(
ctx
context
.
Context
,
input
*
OpenAIRecordUsageInput
)
error
{
result
:=
input
.
Result
result
:=
input
.
Result
apiKey
:=
input
.
A
pi
Key
apiKey
:=
input
.
A
PI
Key
user
:=
input
.
User
user
:=
input
.
User
account
:=
input
.
Account
account
:=
input
.
Account
subscription
:=
input
.
Subscription
subscription
:=
input
.
Subscription
...
@@ -1116,7 +1122,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
...
@@ -1116,7 +1122,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs
:=
int
(
result
.
Duration
.
Milliseconds
())
durationMs
:=
int
(
result
.
Duration
.
Milliseconds
())
usageLog
:=
&
UsageLog
{
usageLog
:=
&
UsageLog
{
UserID
:
user
.
ID
,
UserID
:
user
.
ID
,
A
pi
KeyID
:
apiKey
.
ID
,
A
PI
KeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
result
.
RequestID
,
RequestID
:
result
.
RequestID
,
Model
:
result
.
Model
,
Model
:
result
.
Model
,
...
@@ -1145,22 +1151,23 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
...
@@ -1145,22 +1151,23 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog
.
SubscriptionID
=
&
subscription
.
ID
usageLog
.
SubscriptionID
=
&
subscription
.
ID
}
}
_
=
s
.
usageLogRepo
.
Create
(
ctx
,
usageLog
)
inserted
,
err
:=
s
.
usageLogRepo
.
Create
(
ctx
,
usageLog
)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
log
.
Printf
(
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
log
.
Printf
(
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
return
nil
return
nil
}
}
shouldBill
:=
inserted
||
err
!=
nil
// Deduct based on billing type
// Deduct based on billing type
if
isSubscriptionBilling
{
if
isSubscriptionBilling
{
if
cost
.
TotalCost
>
0
{
if
shouldBill
&&
cost
.
TotalCost
>
0
{
_
=
s
.
userSubRepo
.
IncrementUsage
(
ctx
,
subscription
.
ID
,
cost
.
TotalCost
)
_
=
s
.
userSubRepo
.
IncrementUsage
(
ctx
,
subscription
.
ID
,
cost
.
TotalCost
)
s
.
billingCacheService
.
QueueUpdateSubscriptionUsage
(
user
.
ID
,
*
apiKey
.
GroupID
,
cost
.
TotalCost
)
s
.
billingCacheService
.
QueueUpdateSubscriptionUsage
(
user
.
ID
,
*
apiKey
.
GroupID
,
cost
.
TotalCost
)
}
}
}
else
{
}
else
{
if
cost
.
ActualCost
>
0
{
if
shouldBill
&&
cost
.
ActualCost
>
0
{
_
=
s
.
userRepo
.
DeductBalance
(
ctx
,
user
.
ID
,
cost
.
ActualCost
)
_
=
s
.
userRepo
.
DeductBalance
(
ctx
,
user
.
ID
,
cost
.
ActualCost
)
s
.
billingCacheService
.
QueueDeductBalance
(
user
.
ID
,
cost
.
ActualCost
)
s
.
billingCacheService
.
QueueDeductBalance
(
user
.
ID
,
cost
.
ActualCost
)
}
}
...
...
backend/internal/service/ratelimit_service.go
View file @
7dddd065
...
@@ -2,6 +2,7 @@ package service
...
@@ -2,6 +2,7 @@ package service
import
(
import
(
"context"
"context"
"encoding/json"
"log"
"log"
"net/http"
"net/http"
"strconv"
"strconv"
...
@@ -18,6 +19,7 @@ type RateLimitService struct {
...
@@ -18,6 +19,7 @@ type RateLimitService struct {
usageRepo
UsageLogRepository
usageRepo
UsageLogRepository
cfg
*
config
.
Config
cfg
*
config
.
Config
geminiQuotaService
*
GeminiQuotaService
geminiQuotaService
*
GeminiQuotaService
tempUnschedCache
TempUnschedCache
usageCacheMu
sync
.
RWMutex
usageCacheMu
sync
.
RWMutex
usageCache
map
[
int64
]
*
geminiUsageCacheEntry
usageCache
map
[
int64
]
*
geminiUsageCacheEntry
}
}
...
@@ -31,12 +33,13 @@ type geminiUsageCacheEntry struct {
...
@@ -31,12 +33,13 @@ type geminiUsageCacheEntry struct {
const
geminiPrecheckCacheTTL
=
time
.
Minute
const
geminiPrecheckCacheTTL
=
time
.
Minute
// NewRateLimitService 创建RateLimitService实例
// NewRateLimitService 创建RateLimitService实例
func
NewRateLimitService
(
accountRepo
AccountRepository
,
usageRepo
UsageLogRepository
,
cfg
*
config
.
Config
,
geminiQuotaService
*
GeminiQuotaService
)
*
RateLimitService
{
func
NewRateLimitService
(
accountRepo
AccountRepository
,
usageRepo
UsageLogRepository
,
cfg
*
config
.
Config
,
geminiQuotaService
*
GeminiQuotaService
,
tempUnschedCache
TempUnschedCache
)
*
RateLimitService
{
return
&
RateLimitService
{
return
&
RateLimitService
{
accountRepo
:
accountRepo
,
accountRepo
:
accountRepo
,
usageRepo
:
usageRepo
,
usageRepo
:
usageRepo
,
cfg
:
cfg
,
cfg
:
cfg
,
geminiQuotaService
:
geminiQuotaService
,
geminiQuotaService
:
geminiQuotaService
,
tempUnschedCache
:
tempUnschedCache
,
usageCache
:
make
(
map
[
int64
]
*
geminiUsageCacheEntry
),
usageCache
:
make
(
map
[
int64
]
*
geminiUsageCacheEntry
),
}
}
}
}
...
@@ -51,38 +54,45 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
...
@@ -51,38 +54,45 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return
false
return
false
}
}
tempMatched
:=
s
.
tryTempUnschedulable
(
ctx
,
account
,
statusCode
,
responseBody
)
switch
statusCode
{
switch
statusCode
{
case
401
:
case
401
:
// 认证失败:停止调度,记录错误
// 认证失败:停止调度,记录错误
s
.
handleAuthError
(
ctx
,
account
,
"Authentication failed (401): invalid or expired credentials"
)
s
.
handleAuthError
(
ctx
,
account
,
"Authentication failed (401): invalid or expired credentials"
)
return
true
shouldDisable
=
true
case
402
:
case
402
:
// 支付要求:余额不足或计费问题,停止调度
// 支付要求:余额不足或计费问题,停止调度
s
.
handleAuthError
(
ctx
,
account
,
"Payment required (402): insufficient balance or billing issue"
)
s
.
handleAuthError
(
ctx
,
account
,
"Payment required (402): insufficient balance or billing issue"
)
return
true
shouldDisable
=
true
case
403
:
case
403
:
// 禁止访问:停止调度,记录错误
// 禁止访问:停止调度,记录错误
s
.
handleAuthError
(
ctx
,
account
,
"Access forbidden (403): account may be suspended or lack permissions"
)
s
.
handleAuthError
(
ctx
,
account
,
"Access forbidden (403): account may be suspended or lack permissions"
)
return
true
shouldDisable
=
true
case
429
:
case
429
:
s
.
handle429
(
ctx
,
account
,
headers
)
s
.
handle429
(
ctx
,
account
,
headers
)
return
false
shouldDisable
=
false
case
529
:
case
529
:
s
.
handle529
(
ctx
,
account
)
s
.
handle529
(
ctx
,
account
)
return
false
shouldDisable
=
false
default
:
default
:
// 其他5xx错误:记录但不停止调度
// 其他5xx错误:记录但不停止调度
if
statusCode
>=
500
{
if
statusCode
>=
500
{
log
.
Printf
(
"Account %d received upstream error %d"
,
account
.
ID
,
statusCode
)
log
.
Printf
(
"Account %d received upstream error %d"
,
account
.
ID
,
statusCode
)
}
}
return
false
shouldDisable
=
false
}
}
if
tempMatched
{
return
true
}
return
shouldDisable
}
}
// PreCheckUsage proactively checks local quota before dispatching a request.
// PreCheckUsage proactively checks local quota before dispatching a request.
// Returns false when the account should be skipped.
// Returns false when the account should be skipped.
func
(
s
*
RateLimitService
)
PreCheckUsage
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
string
)
(
bool
,
error
)
{
func
(
s
*
RateLimitService
)
PreCheckUsage
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
string
)
(
bool
,
error
)
{
if
account
==
nil
||
!
account
.
IsGeminiCodeAssist
()
||
strings
.
TrimSpace
(
requestedModel
)
==
""
{
if
account
==
nil
||
account
.
Platform
!=
PlatformGemini
{
return
true
,
nil
return
true
,
nil
}
}
if
s
.
usageRepo
==
nil
||
s
.
geminiQuotaService
==
nil
{
if
s
.
usageRepo
==
nil
||
s
.
geminiQuotaService
==
nil
{
...
@@ -94,44 +104,99 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
...
@@ -94,44 +104,99 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
return
true
,
nil
return
true
,
nil
}
}
var
limit
int64
switch
geminiModelClassFromName
(
requestedModel
)
{
case
geminiModelFlash
:
limit
=
quota
.
FlashRPD
default
:
limit
=
quota
.
ProRPD
}
if
limit
<=
0
{
return
true
,
nil
}
now
:=
time
.
Now
()
now
:=
time
.
Now
()
start
:=
geminiDailyWindowStart
(
now
)
modelClass
:=
geminiModelClassFromName
(
requestedModel
)
totals
,
ok
:=
s
.
getGeminiUsageTotals
(
account
.
ID
,
start
,
now
)
if
!
ok
{
// 1) Daily quota precheck (RPD; resets at PST midnight)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
)
{
if
err
!=
nil
{
var
limit
int64
return
true
,
err
if
quota
.
SharedRPD
>
0
{
limit
=
quota
.
SharedRPD
}
else
{
switch
modelClass
{
case
geminiModelFlash
:
limit
=
quota
.
FlashRPD
default
:
limit
=
quota
.
ProRPD
}
}
}
totals
=
geminiAggregateUsage
(
stats
)
s
.
setGeminiUsageTotals
(
account
.
ID
,
start
,
now
,
totals
)
}
var
used
int64
if
limit
>
0
{
switch
geminiModelClassFromName
(
requestedModel
)
{
start
:=
geminiDailyWindowStart
(
now
)
case
geminiModelFlash
:
totals
,
ok
:=
s
.
getGeminiUsageTotals
(
account
.
ID
,
start
,
now
)
used
=
totals
.
FlashRequests
if
!
ok
{
default
:
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
)
used
=
totals
.
ProRequests
if
err
!=
nil
{
return
true
,
err
}
totals
=
geminiAggregateUsage
(
stats
)
s
.
setGeminiUsageTotals
(
account
.
ID
,
start
,
now
,
totals
)
}
var
used
int64
if
quota
.
SharedRPD
>
0
{
used
=
totals
.
ProRequests
+
totals
.
FlashRequests
}
else
{
switch
modelClass
{
case
geminiModelFlash
:
used
=
totals
.
FlashRequests
default
:
used
=
totals
.
ProRequests
}
}
if
used
>=
limit
{
resetAt
:=
geminiDailyResetTime
(
now
)
// NOTE:
// - This is a local precheck to reduce upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
log
.
Printf
(
"[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v"
,
account
.
ID
,
used
,
limit
,
resetAt
)
return
false
,
nil
}
}
}
}
if
used
>=
limit
{
// 2) Minute quota precheck (RPM; fixed window current minute)
resetAt
:=
geminiDailyResetTime
(
now
)
{
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
var
limit
int64
log
.
Printf
(
"SetRateLimited failed for account %d: %v"
,
account
.
ID
,
err
)
if
quota
.
SharedRPM
>
0
{
limit
=
quota
.
SharedRPM
}
else
{
switch
modelClass
{
case
geminiModelFlash
:
limit
=
quota
.
FlashRPM
default
:
limit
=
quota
.
ProRPM
}
}
if
limit
>
0
{
start
:=
now
.
Truncate
(
time
.
Minute
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
)
if
err
!=
nil
{
return
true
,
err
}
totals
:=
geminiAggregateUsage
(
stats
)
var
used
int64
if
quota
.
SharedRPM
>
0
{
used
=
totals
.
ProRequests
+
totals
.
FlashRequests
}
else
{
switch
modelClass
{
case
geminiModelFlash
:
used
=
totals
.
FlashRequests
default
:
used
=
totals
.
ProRequests
}
}
if
used
>=
limit
{
resetAt
:=
start
.
Add
(
time
.
Minute
)
// Do not persist "rate limited" status from local precheck. See note above.
log
.
Printf
(
"[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v"
,
account
.
ID
,
used
,
limit
,
resetAt
)
return
false
,
nil
}
}
}
log
.
Printf
(
"[Gemini PreCheck] Account %d reached daily quota (%d/%d), rate limited until %v"
,
account
.
ID
,
used
,
limit
,
resetAt
)
return
false
,
nil
}
}
return
true
,
nil
return
true
,
nil
...
@@ -176,7 +241,10 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
...
@@ -176,7 +241,10 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
if
account
==
nil
{
if
account
==
nil
{
return
5
*
time
.
Minute
return
5
*
time
.
Minute
}
}
return
s
.
geminiQuotaService
.
CooldownForTier
(
ctx
,
account
.
GeminiTierID
())
if
s
.
geminiQuotaService
==
nil
{
return
5
*
time
.
Minute
}
return
s
.
geminiQuotaService
.
CooldownForAccount
(
ctx
,
account
)
}
}
// handleAuthError 处理认证类错误(401/403),停止账号调度
// handleAuthError 处理认证类错误(401/403),停止账号调度
...
@@ -287,3 +355,183 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
...
@@ -287,3 +355,183 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
func
(
s
*
RateLimitService
)
ClearRateLimit
(
ctx
context
.
Context
,
accountID
int64
)
error
{
func
(
s
*
RateLimitService
)
ClearRateLimit
(
ctx
context
.
Context
,
accountID
int64
)
error
{
return
s
.
accountRepo
.
ClearRateLimit
(
ctx
,
accountID
)
return
s
.
accountRepo
.
ClearRateLimit
(
ctx
,
accountID
)
}
}
func
(
s
*
RateLimitService
)
ClearTempUnschedulable
(
ctx
context
.
Context
,
accountID
int64
)
error
{
if
err
:=
s
.
accountRepo
.
ClearTempUnschedulable
(
ctx
,
accountID
);
err
!=
nil
{
return
err
}
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
DeleteTempUnsched
(
ctx
,
accountID
);
err
!=
nil
{
log
.
Printf
(
"DeleteTempUnsched failed for account %d: %v"
,
accountID
,
err
)
}
}
return
nil
}
func
(
s
*
RateLimitService
)
GetTempUnschedStatus
(
ctx
context
.
Context
,
accountID
int64
)
(
*
TempUnschedState
,
error
)
{
now
:=
time
.
Now
()
.
Unix
()
if
s
.
tempUnschedCache
!=
nil
{
state
,
err
:=
s
.
tempUnschedCache
.
GetTempUnsched
(
ctx
,
accountID
)
if
err
!=
nil
{
return
nil
,
err
}
if
state
!=
nil
&&
state
.
UntilUnix
>
now
{
return
state
,
nil
}
}
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
if
err
!=
nil
{
return
nil
,
err
}
if
account
.
TempUnschedulableUntil
==
nil
{
return
nil
,
nil
}
if
account
.
TempUnschedulableUntil
.
Unix
()
<=
now
{
return
nil
,
nil
}
state
:=
&
TempUnschedState
{
UntilUnix
:
account
.
TempUnschedulableUntil
.
Unix
(),
}
if
account
.
TempUnschedulableReason
!=
""
{
var
parsed
TempUnschedState
if
err
:=
json
.
Unmarshal
([]
byte
(
account
.
TempUnschedulableReason
),
&
parsed
);
err
==
nil
{
if
parsed
.
UntilUnix
==
0
{
parsed
.
UntilUnix
=
state
.
UntilUnix
}
state
=
&
parsed
}
else
{
state
.
ErrorMessage
=
account
.
TempUnschedulableReason
}
}
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
accountID
,
state
);
err
!=
nil
{
log
.
Printf
(
"SetTempUnsched failed for account %d: %v"
,
accountID
,
err
)
}
}
return
state
,
nil
}
func
(
s
*
RateLimitService
)
HandleTempUnschedulable
(
ctx
context
.
Context
,
account
*
Account
,
statusCode
int
,
responseBody
[]
byte
)
bool
{
if
account
==
nil
{
return
false
}
if
!
account
.
ShouldHandleErrorCode
(
statusCode
)
{
return
false
}
return
s
.
tryTempUnschedulable
(
ctx
,
account
,
statusCode
,
responseBody
)
}
const
tempUnschedBodyMaxBytes
=
64
<<
10
const
tempUnschedMessageMaxBytes
=
2048
func
(
s
*
RateLimitService
)
tryTempUnschedulable
(
ctx
context
.
Context
,
account
*
Account
,
statusCode
int
,
responseBody
[]
byte
)
bool
{
if
account
==
nil
{
return
false
}
if
!
account
.
IsTempUnschedulableEnabled
()
{
return
false
}
rules
:=
account
.
GetTempUnschedulableRules
()
if
len
(
rules
)
==
0
{
return
false
}
if
statusCode
<=
0
||
len
(
responseBody
)
==
0
{
return
false
}
body
:=
responseBody
if
len
(
body
)
>
tempUnschedBodyMaxBytes
{
body
=
body
[
:
tempUnschedBodyMaxBytes
]
}
bodyLower
:=
strings
.
ToLower
(
string
(
body
))
for
idx
,
rule
:=
range
rules
{
if
rule
.
ErrorCode
!=
statusCode
||
len
(
rule
.
Keywords
)
==
0
{
continue
}
matchedKeyword
:=
matchTempUnschedKeyword
(
bodyLower
,
rule
.
Keywords
)
if
matchedKeyword
==
""
{
continue
}
if
s
.
triggerTempUnschedulable
(
ctx
,
account
,
rule
,
idx
,
statusCode
,
matchedKeyword
,
responseBody
)
{
return
true
}
}
return
false
}
func
matchTempUnschedKeyword
(
bodyLower
string
,
keywords
[]
string
)
string
{
if
bodyLower
==
""
{
return
""
}
for
_
,
keyword
:=
range
keywords
{
k
:=
strings
.
TrimSpace
(
keyword
)
if
k
==
""
{
continue
}
if
strings
.
Contains
(
bodyLower
,
strings
.
ToLower
(
k
))
{
return
k
}
}
return
""
}
func
(
s
*
RateLimitService
)
triggerTempUnschedulable
(
ctx
context
.
Context
,
account
*
Account
,
rule
TempUnschedulableRule
,
ruleIndex
int
,
statusCode
int
,
matchedKeyword
string
,
responseBody
[]
byte
)
bool
{
if
account
==
nil
{
return
false
}
if
rule
.
DurationMinutes
<=
0
{
return
false
}
now
:=
time
.
Now
()
until
:=
now
.
Add
(
time
.
Duration
(
rule
.
DurationMinutes
)
*
time
.
Minute
)
state
:=
&
TempUnschedState
{
UntilUnix
:
until
.
Unix
(),
TriggeredAtUnix
:
now
.
Unix
(),
StatusCode
:
statusCode
,
MatchedKeyword
:
matchedKeyword
,
RuleIndex
:
ruleIndex
,
ErrorMessage
:
truncateTempUnschedMessage
(
responseBody
,
tempUnschedMessageMaxBytes
),
}
reason
:=
""
if
raw
,
err
:=
json
.
Marshal
(
state
);
err
==
nil
{
reason
=
string
(
raw
)
}
if
reason
==
""
{
reason
=
strings
.
TrimSpace
(
state
.
ErrorMessage
)
}
if
err
:=
s
.
accountRepo
.
SetTempUnschedulable
(
ctx
,
account
.
ID
,
until
,
reason
);
err
!=
nil
{
log
.
Printf
(
"SetTempUnschedulable failed for account %d: %v"
,
account
.
ID
,
err
)
return
false
}
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
account
.
ID
,
state
);
err
!=
nil
{
log
.
Printf
(
"SetTempUnsched cache failed for account %d: %v"
,
account
.
ID
,
err
)
}
}
log
.
Printf
(
"Account %d temp unschedulable until %v (rule %d, code %d)"
,
account
.
ID
,
until
,
ruleIndex
,
statusCode
)
return
true
}
func
truncateTempUnschedMessage
(
body
[]
byte
,
maxBytes
int
)
string
{
if
maxBytes
<=
0
||
len
(
body
)
==
0
{
return
""
}
if
len
(
body
)
>
maxBytes
{
body
=
body
[
:
maxBytes
]
}
return
strings
.
TrimSpace
(
string
(
body
))
}
backend/internal/service/setting_service.go
View file @
7dddd065
...
@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
...
@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeySiteName
,
SettingKeySiteName
,
SettingKeySiteLogo
,
SettingKeySiteLogo
,
SettingKeySiteSubtitle
,
SettingKeySiteSubtitle
,
SettingKeyA
pi
BaseU
rl
,
SettingKeyA
PI
BaseU
RL
,
SettingKeyContactInfo
,
SettingKeyContactInfo
,
SettingKeyDocU
rl
,
SettingKeyDocU
RL
,
}
}
settings
,
err
:=
s
.
settingRepo
.
GetMultiple
(
ctx
,
keys
)
settings
,
err
:=
s
.
settingRepo
.
GetMultiple
(
ctx
,
keys
)
...
@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
...
@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
SiteLogo
:
settings
[
SettingKeySiteLogo
],
SiteLogo
:
settings
[
SettingKeySiteLogo
],
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
A
pi
BaseU
rl
:
settings
[
SettingKeyA
pi
BaseU
rl
],
A
PI
BaseU
RL
:
settings
[
SettingKeyA
PI
BaseU
RL
],
ContactInfo
:
settings
[
SettingKeyContactInfo
],
ContactInfo
:
settings
[
SettingKeyContactInfo
],
DocU
rl
:
settings
[
SettingKeyDocU
rl
],
DocU
RL
:
settings
[
SettingKeyDocU
RL
],
},
nil
},
nil
}
}
...
@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
...
@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeyEmailVerifyEnabled
]
=
strconv
.
FormatBool
(
settings
.
EmailVerifyEnabled
)
updates
[
SettingKeyEmailVerifyEnabled
]
=
strconv
.
FormatBool
(
settings
.
EmailVerifyEnabled
)
// 邮件服务设置(只有非空才更新密码)
// 邮件服务设置(只有非空才更新密码)
updates
[
SettingKeyS
mtp
Host
]
=
settings
.
S
mtp
Host
updates
[
SettingKeyS
MTP
Host
]
=
settings
.
S
MTP
Host
updates
[
SettingKeyS
mtp
Port
]
=
strconv
.
Itoa
(
settings
.
S
mtp
Port
)
updates
[
SettingKeyS
MTP
Port
]
=
strconv
.
Itoa
(
settings
.
S
MTP
Port
)
updates
[
SettingKeyS
mtp
Username
]
=
settings
.
S
mtp
Username
updates
[
SettingKeyS
MTP
Username
]
=
settings
.
S
MTP
Username
if
settings
.
S
mtp
Password
!=
""
{
if
settings
.
S
MTP
Password
!=
""
{
updates
[
SettingKeyS
mtp
Password
]
=
settings
.
S
mtp
Password
updates
[
SettingKeyS
MTP
Password
]
=
settings
.
S
MTP
Password
}
}
updates
[
SettingKeyS
mtp
From
]
=
settings
.
S
mtp
From
updates
[
SettingKeyS
MTP
From
]
=
settings
.
S
MTP
From
updates
[
SettingKeyS
mtp
FromName
]
=
settings
.
S
mtp
FromName
updates
[
SettingKeyS
MTP
FromName
]
=
settings
.
S
MTP
FromName
updates
[
SettingKeyS
mtp
UseTLS
]
=
strconv
.
FormatBool
(
settings
.
S
mtp
UseTLS
)
updates
[
SettingKeyS
MTP
UseTLS
]
=
strconv
.
FormatBool
(
settings
.
S
MTP
UseTLS
)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
updates
[
SettingKeyTurnstileEnabled
]
=
strconv
.
FormatBool
(
settings
.
TurnstileEnabled
)
updates
[
SettingKeyTurnstileEnabled
]
=
strconv
.
FormatBool
(
settings
.
TurnstileEnabled
)
...
@@ -115,14 +115,21 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
...
@@ -115,14 +115,21 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeySiteName
]
=
settings
.
SiteName
updates
[
SettingKeySiteName
]
=
settings
.
SiteName
updates
[
SettingKeySiteLogo
]
=
settings
.
SiteLogo
updates
[
SettingKeySiteLogo
]
=
settings
.
SiteLogo
updates
[
SettingKeySiteSubtitle
]
=
settings
.
SiteSubtitle
updates
[
SettingKeySiteSubtitle
]
=
settings
.
SiteSubtitle
updates
[
SettingKeyA
pi
BaseU
rl
]
=
settings
.
A
pi
BaseU
rl
updates
[
SettingKeyA
PI
BaseU
RL
]
=
settings
.
A
PI
BaseU
RL
updates
[
SettingKeyContactInfo
]
=
settings
.
ContactInfo
updates
[
SettingKeyContactInfo
]
=
settings
.
ContactInfo
updates
[
SettingKeyDocU
rl
]
=
settings
.
DocU
rl
updates
[
SettingKeyDocU
RL
]
=
settings
.
DocU
RL
// 默认配置
// 默认配置
updates
[
SettingKeyDefaultConcurrency
]
=
strconv
.
Itoa
(
settings
.
DefaultConcurrency
)
updates
[
SettingKeyDefaultConcurrency
]
=
strconv
.
Itoa
(
settings
.
DefaultConcurrency
)
updates
[
SettingKeyDefaultBalance
]
=
strconv
.
FormatFloat
(
settings
.
DefaultBalance
,
'f'
,
8
,
64
)
updates
[
SettingKeyDefaultBalance
]
=
strconv
.
FormatFloat
(
settings
.
DefaultBalance
,
'f'
,
8
,
64
)
// Model fallback configuration
updates
[
SettingKeyEnableModelFallback
]
=
strconv
.
FormatBool
(
settings
.
EnableModelFallback
)
updates
[
SettingKeyFallbackModelAnthropic
]
=
settings
.
FallbackModelAnthropic
updates
[
SettingKeyFallbackModelOpenAI
]
=
settings
.
FallbackModelOpenAI
updates
[
SettingKeyFallbackModelGemini
]
=
settings
.
FallbackModelGemini
updates
[
SettingKeyFallbackModelAntigravity
]
=
settings
.
FallbackModelAntigravity
return
s
.
settingRepo
.
SetMultiple
(
ctx
,
updates
)
return
s
.
settingRepo
.
SetMultiple
(
ctx
,
updates
)
}
}
...
@@ -198,8 +205,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
...
@@ -198,8 +205,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo
:
""
,
SettingKeySiteLogo
:
""
,
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
SettingKeyDefaultBalance
:
strconv
.
FormatFloat
(
s
.
cfg
.
Default
.
UserBalance
,
'f'
,
8
,
64
),
SettingKeyDefaultBalance
:
strconv
.
FormatFloat
(
s
.
cfg
.
Default
.
UserBalance
,
'f'
,
8
,
64
),
SettingKeySmtpPort
:
"587"
,
SettingKeySMTPPort
:
"587"
,
SettingKeySmtpUseTLS
:
"false"
,
SettingKeySMTPUseTLS
:
"false"
,
// Model fallback defaults
SettingKeyEnableModelFallback
:
"false"
,
SettingKeyFallbackModelAnthropic
:
"claude-3-5-sonnet-20241022"
,
SettingKeyFallbackModelOpenAI
:
"gpt-4o"
,
SettingKeyFallbackModelGemini
:
"gemini-2.5-pro"
,
SettingKeyFallbackModelAntigravity
:
"gemini-2.5-pro"
,
}
}
return
s
.
settingRepo
.
SetMultiple
(
ctx
,
defaults
)
return
s
.
settingRepo
.
SetMultiple
(
ctx
,
defaults
)
...
@@ -208,30 +221,30 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
...
@@ -208,30 +221,30 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体
// parseSettings 解析设置到结构体
func
(
s
*
SettingService
)
parseSettings
(
settings
map
[
string
]
string
)
*
SystemSettings
{
func
(
s
*
SettingService
)
parseSettings
(
settings
map
[
string
]
string
)
*
SystemSettings
{
result
:=
&
SystemSettings
{
result
:=
&
SystemSettings
{
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
,
S
mtp
Host
:
settings
[
SettingKeyS
mtp
Host
],
S
MTP
Host
:
settings
[
SettingKeyS
MTP
Host
],
S
mtp
Username
:
settings
[
SettingKeyS
mtp
Username
],
S
MTP
Username
:
settings
[
SettingKeyS
MTP
Username
],
S
mtp
From
:
settings
[
SettingKeyS
mtp
From
],
S
MTP
From
:
settings
[
SettingKeyS
MTP
From
],
S
mtp
FromName
:
settings
[
SettingKeyS
mtp
FromName
],
S
MTP
FromName
:
settings
[
SettingKeyS
MTP
FromName
],
S
mtp
UseTLS
:
settings
[
SettingKeyS
mtp
UseTLS
]
==
"true"
,
S
MTP
UseTLS
:
settings
[
SettingKeyS
MTP
UseTLS
]
==
"true"
,
S
mtp
PasswordConfigured
:
settings
[
SettingKeyS
mtp
Password
]
!=
""
,
S
MTP
PasswordConfigured
:
settings
[
SettingKeyS
MTP
Password
]
!=
""
,
TurnstileEnabled
:
settings
[
SettingKeyTurnstileEnabled
]
==
"true"
,
TurnstileEnabled
:
settings
[
SettingKeyTurnstileEnabled
]
==
"true"
,
TurnstileSiteKey
:
settings
[
SettingKeyTurnstileSiteKey
],
TurnstileSiteKey
:
settings
[
SettingKeyTurnstileSiteKey
],
TurnstileSecretKeyConfigured
:
settings
[
SettingKeyTurnstileSecretKey
]
!=
""
,
TurnstileSecretKeyConfigured
:
settings
[
SettingKeyTurnstileSecretKey
]
!=
""
,
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
SiteLogo
:
settings
[
SettingKeySiteLogo
],
SiteLogo
:
settings
[
SettingKeySiteLogo
],
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
A
pi
BaseU
rl
:
settings
[
SettingKeyA
pi
BaseU
rl
],
A
PI
BaseU
RL
:
settings
[
SettingKeyA
PI
BaseU
RL
],
ContactInfo
:
settings
[
SettingKeyContactInfo
],
ContactInfo
:
settings
[
SettingKeyContactInfo
],
DocU
rl
:
settings
[
SettingKeyDocU
rl
],
DocU
RL
:
settings
[
SettingKeyDocU
RL
],
}
}
// 解析整数类型
// 解析整数类型
if
port
,
err
:=
strconv
.
Atoi
(
settings
[
SettingKeyS
mtp
Port
]);
err
==
nil
{
if
port
,
err
:=
strconv
.
Atoi
(
settings
[
SettingKeyS
MTP
Port
]);
err
==
nil
{
result
.
S
mtp
Port
=
port
result
.
S
MTP
Port
=
port
}
else
{
}
else
{
result
.
S
mtp
Port
=
587
result
.
S
MTP
Port
=
587
}
}
if
concurrency
,
err
:=
strconv
.
Atoi
(
settings
[
SettingKeyDefaultConcurrency
]);
err
==
nil
{
if
concurrency
,
err
:=
strconv
.
Atoi
(
settings
[
SettingKeyDefaultConcurrency
]);
err
==
nil
{
...
@@ -247,6 +260,17 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
...
@@ -247,6 +260,17 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result
.
DefaultBalance
=
s
.
cfg
.
Default
.
UserBalance
result
.
DefaultBalance
=
s
.
cfg
.
Default
.
UserBalance
}
}
// 敏感信息直接返回,方便测试连接时使用
result
.
SMTPPassword
=
settings
[
SettingKeySMTPPassword
]
result
.
TurnstileSecretKey
=
settings
[
SettingKeyTurnstileSecretKey
]
// Model fallback settings
result
.
EnableModelFallback
=
settings
[
SettingKeyEnableModelFallback
]
==
"true"
result
.
FallbackModelAnthropic
=
s
.
getStringOrDefault
(
settings
,
SettingKeyFallbackModelAnthropic
,
"claude-3-5-sonnet-20241022"
)
result
.
FallbackModelOpenAI
=
s
.
getStringOrDefault
(
settings
,
SettingKeyFallbackModelOpenAI
,
"gpt-4o"
)
result
.
FallbackModelGemini
=
s
.
getStringOrDefault
(
settings
,
SettingKeyFallbackModelGemini
,
"gemini-2.5-pro"
)
result
.
FallbackModelAntigravity
=
s
.
getStringOrDefault
(
settings
,
SettingKeyFallbackModelAntigravity
,
"gemini-2.5-pro"
)
return
result
return
result
}
}
...
@@ -276,28 +300,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
...
@@ -276,28 +300,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return
value
return
value
}
}
// GenerateAdminA
pi
Key 生成新的管理员 API Key
// GenerateAdminA
PI
Key 生成新的管理员 API Key
func
(
s
*
SettingService
)
GenerateAdminA
pi
Key
(
ctx
context
.
Context
)
(
string
,
error
)
{
func
(
s
*
SettingService
)
GenerateAdminA
PI
Key
(
ctx
context
.
Context
)
(
string
,
error
)
{
// 生成 32 字节随机数 = 64 位十六进制字符
// 生成 32 字节随机数 = 64 位十六进制字符
bytes
:=
make
([]
byte
,
32
)
bytes
:=
make
([]
byte
,
32
)
if
_
,
err
:=
rand
.
Read
(
bytes
);
err
!=
nil
{
if
_
,
err
:=
rand
.
Read
(
bytes
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"generate random bytes: %w"
,
err
)
return
""
,
fmt
.
Errorf
(
"generate random bytes: %w"
,
err
)
}
}
key
:=
AdminA
pi
KeyPrefix
+
hex
.
EncodeToString
(
bytes
)
key
:=
AdminA
PI
KeyPrefix
+
hex
.
EncodeToString
(
bytes
)
// 存储到 settings 表
// 存储到 settings 表
if
err
:=
s
.
settingRepo
.
Set
(
ctx
,
SettingKeyAdminA
pi
Key
,
key
);
err
!=
nil
{
if
err
:=
s
.
settingRepo
.
Set
(
ctx
,
SettingKeyAdminA
PI
Key
,
key
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"save admin api key: %w"
,
err
)
return
""
,
fmt
.
Errorf
(
"save admin api key: %w"
,
err
)
}
}
return
key
,
nil
return
key
,
nil
}
}
// GetAdminA
pi
KeyStatus 获取管理员 API Key 状态
// GetAdminA
PI
KeyStatus 获取管理员 API Key 状态
// 返回脱敏的 key、是否存在、错误
// 返回脱敏的 key、是否存在、错误
func
(
s
*
SettingService
)
GetAdminA
pi
KeyStatus
(
ctx
context
.
Context
)
(
maskedKey
string
,
exists
bool
,
err
error
)
{
func
(
s
*
SettingService
)
GetAdminA
PI
KeyStatus
(
ctx
context
.
Context
)
(
maskedKey
string
,
exists
bool
,
err
error
)
{
key
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyAdminA
pi
Key
)
key
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyAdminA
PI
Key
)
if
err
!=
nil
{
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
if
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
return
""
,
false
,
nil
return
""
,
false
,
nil
...
@@ -318,10 +342,10 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
...
@@ -318,10 +342,10 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
return
maskedKey
,
true
,
nil
return
maskedKey
,
true
,
nil
}
}
// GetAdminA
pi
Key 获取完整的管理员 API Key(仅供内部验证使用)
// GetAdminA
PI
Key 获取完整的管理员 API Key(仅供内部验证使用)
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
func
(
s
*
SettingService
)
GetAdminA
pi
Key
(
ctx
context
.
Context
)
(
string
,
error
)
{
func
(
s
*
SettingService
)
GetAdminA
PI
Key
(
ctx
context
.
Context
)
(
string
,
error
)
{
key
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyAdminA
pi
Key
)
key
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyAdminA
PI
Key
)
if
err
!=
nil
{
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
if
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
return
""
,
nil
// 未配置,返回空字符串
return
""
,
nil
// 未配置,返回空字符串
...
@@ -331,7 +355,45 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
...
@@ -331,7 +355,45 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
return
key
,
nil
return
key
,
nil
}
}
// DeleteAdminApiKey 删除管理员 API Key
// DeleteAdminAPIKey 删除管理员 API Key
func
(
s
*
SettingService
)
DeleteAdminApiKey
(
ctx
context
.
Context
)
error
{
func
(
s
*
SettingService
)
DeleteAdminAPIKey
(
ctx
context
.
Context
)
error
{
return
s
.
settingRepo
.
Delete
(
ctx
,
SettingKeyAdminApiKey
)
return
s
.
settingRepo
.
Delete
(
ctx
,
SettingKeyAdminAPIKey
)
}
// IsModelFallbackEnabled 检查是否启用模型兜底机制
func
(
s
*
SettingService
)
IsModelFallbackEnabled
(
ctx
context
.
Context
)
bool
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyEnableModelFallback
)
if
err
!=
nil
{
return
false
// Default: disabled
}
return
value
==
"true"
}
// GetFallbackModel 获取指定平台的兜底模型
func
(
s
*
SettingService
)
GetFallbackModel
(
ctx
context
.
Context
,
platform
string
)
string
{
var
key
string
var
defaultModel
string
switch
platform
{
case
PlatformAnthropic
:
key
=
SettingKeyFallbackModelAnthropic
defaultModel
=
"claude-3-5-sonnet-20241022"
case
PlatformOpenAI
:
key
=
SettingKeyFallbackModelOpenAI
defaultModel
=
"gpt-4o"
case
PlatformGemini
:
key
=
SettingKeyFallbackModelGemini
defaultModel
=
"gemini-2.5-pro"
case
PlatformAntigravity
:
key
=
SettingKeyFallbackModelAntigravity
defaultModel
=
"gemini-2.5-pro"
default
:
return
""
}
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
key
)
if
err
!=
nil
||
value
==
""
{
return
defaultModel
}
return
value
}
}
backend/internal/service/settings_view.go
View file @
7dddd065
...
@@ -4,29 +4,36 @@ type SystemSettings struct {
...
@@ -4,29 +4,36 @@ type SystemSettings struct {
RegistrationEnabled
bool
RegistrationEnabled
bool
EmailVerifyEnabled
bool
EmailVerifyEnabled
bool
S
mtp
Host
string
S
MTP
Host
string
S
mtp
Port
int
S
MTP
Port
int
S
mtp
Username
string
S
MTP
Username
string
S
mtp
Password
string
S
MTP
Password
string
S
mtp
PasswordConfigured
bool
S
MTP
PasswordConfigured
bool
S
mtp
From
string
S
MTP
From
string
S
mtp
FromName
string
S
MTP
FromName
string
S
mtp
UseTLS
bool
S
MTP
UseTLS
bool
TurnstileEnabled
bool
TurnstileEnabled
bool
TurnstileSiteKey
string
TurnstileSiteKey
string
TurnstileSecretKey
string
TurnstileSecretKey
string
TurnstileSecretKeyConfigured
bool
TurnstileSecretKeyConfigured
bool
SiteName
string
SiteName
string
SiteLogo
string
SiteLogo
string
SiteSubtitle
string
SiteSubtitle
string
A
pi
BaseU
rl
string
A
PI
BaseU
RL
string
ContactInfo
string
ContactInfo
string
DocU
rl
string
DocU
RL
string
DefaultConcurrency
int
DefaultConcurrency
int
DefaultBalance
float64
DefaultBalance
float64
// Model fallback configuration
EnableModelFallback
bool
`json:"enable_model_fallback"`
FallbackModelAnthropic
string
`json:"fallback_model_anthropic"`
FallbackModelOpenAI
string
`json:"fallback_model_openai"`
FallbackModelGemini
string
`json:"fallback_model_gemini"`
FallbackModelAntigravity
string
`json:"fallback_model_antigravity"`
}
}
type
PublicSettings
struct
{
type
PublicSettings
struct
{
...
@@ -37,8 +44,8 @@ type PublicSettings struct {
...
@@ -37,8 +44,8 @@ type PublicSettings struct {
SiteName
string
SiteName
string
SiteLogo
string
SiteLogo
string
SiteSubtitle
string
SiteSubtitle
string
A
pi
BaseU
rl
string
A
PI
BaseU
RL
string
ContactInfo
string
ContactInfo
string
DocU
rl
string
DocU
RL
string
Version
string
Version
string
}
}
backend/internal/service/temp_unsched.go
0 → 100644
View file @
7dddd065
package
service
import
(
"context"
)
// TempUnschedState 临时不可调度状态
type
TempUnschedState
struct
{
UntilUnix
int64
`json:"until_unix"`
// 解除时间(Unix 时间戳)
TriggeredAtUnix
int64
`json:"triggered_at_unix"`
// 触发时间(Unix 时间戳)
StatusCode
int
`json:"status_code"`
// 触发的错误码
MatchedKeyword
string
`json:"matched_keyword"`
// 匹配的关键词
RuleIndex
int
`json:"rule_index"`
// 触发的规则索引
ErrorMessage
string
`json:"error_message"`
// 错误消息
}
// TempUnschedCache 临时不可调度缓存接口
type
TempUnschedCache
interface
{
SetTempUnsched
(
ctx
context
.
Context
,
accountID
int64
,
state
*
TempUnschedState
)
error
GetTempUnsched
(
ctx
context
.
Context
,
accountID
int64
)
(
*
TempUnschedState
,
error
)
DeleteTempUnsched
(
ctx
context
.
Context
,
accountID
int64
)
error
}
backend/internal/service/token_refresher_test.go
View file @
7dddd065
...
@@ -197,7 +197,7 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
...
@@ -197,7 +197,7 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
{
{
name
:
"anthropic api-key - cannot refresh"
,
name
:
"anthropic api-key - cannot refresh"
,
platform
:
PlatformAnthropic
,
platform
:
PlatformAnthropic
,
accType
:
AccountTypeA
pi
Key
,
accType
:
AccountTypeA
PI
Key
,
want
:
false
,
want
:
false
,
},
},
{
{
...
...
backend/internal/service/update_service.go
View file @
7dddd065
...
@@ -79,7 +79,7 @@ type ReleaseInfo struct {
...
@@ -79,7 +79,7 @@ type ReleaseInfo struct {
Name
string
`json:"name"`
Name
string
`json:"name"`
Body
string
`json:"body"`
Body
string
`json:"body"`
PublishedAt
string
`json:"published_at"`
PublishedAt
string
`json:"published_at"`
H
tml
URL
string
`json:"html_url"`
H
TML
URL
string
`json:"html_url"`
Assets
[]
Asset
`json:"assets,omitempty"`
Assets
[]
Asset
`json:"assets,omitempty"`
}
}
...
@@ -96,13 +96,13 @@ type GitHubRelease struct {
...
@@ -96,13 +96,13 @@ type GitHubRelease struct {
Name
string
`json:"name"`
Name
string
`json:"name"`
Body
string
`json:"body"`
Body
string
`json:"body"`
PublishedAt
string
`json:"published_at"`
PublishedAt
string
`json:"published_at"`
H
tmlUrl
string
`json:"html_url"`
H
TMLURL
string
`json:"html_url"`
Assets
[]
GitHubAsset
`json:"assets"`
Assets
[]
GitHubAsset
`json:"assets"`
}
}
type
GitHubAsset
struct
{
type
GitHubAsset
struct
{
Name
string
`json:"name"`
Name
string
`json:"name"`
BrowserDownloadU
rl
string
`json:"browser_download_url"`
BrowserDownloadU
RL
string
`json:"browser_download_url"`
Size
int64
`json:"size"`
Size
int64
`json:"size"`
}
}
...
@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
...
@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
for
i
,
a
:=
range
release
.
Assets
{
for
i
,
a
:=
range
release
.
Assets
{
assets
[
i
]
=
Asset
{
assets
[
i
]
=
Asset
{
Name
:
a
.
Name
,
Name
:
a
.
Name
,
DownloadURL
:
a
.
BrowserDownloadU
rl
,
DownloadURL
:
a
.
BrowserDownloadU
RL
,
Size
:
a
.
Size
,
Size
:
a
.
Size
,
}
}
}
}
...
@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
...
@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
Name
:
release
.
Name
,
Name
:
release
.
Name
,
Body
:
release
.
Body
,
Body
:
release
.
Body
,
PublishedAt
:
release
.
PublishedAt
,
PublishedAt
:
release
.
PublishedAt
,
H
tml
URL
:
release
.
H
tmlUrl
,
H
TML
URL
:
release
.
H
TMLURL
,
Assets
:
assets
,
Assets
:
assets
,
},
},
Cached
:
false
,
Cached
:
false
,
...
...
backend/internal/service/usage_log.go
View file @
7dddd065
...
@@ -10,7 +10,7 @@ const (
...
@@ -10,7 +10,7 @@ const (
type
UsageLog
struct
{
type
UsageLog
struct
{
ID
int64
ID
int64
UserID
int64
UserID
int64
A
pi
KeyID
int64
A
PI
KeyID
int64
AccountID
int64
AccountID
int64
RequestID
string
RequestID
string
Model
string
Model
string
...
@@ -42,7 +42,7 @@ type UsageLog struct {
...
@@ -42,7 +42,7 @@ type UsageLog struct {
CreatedAt
time
.
Time
CreatedAt
time
.
Time
User
*
User
User
*
User
A
pi
Key
*
A
pi
Key
A
PI
Key
*
A
PI
Key
Account
*
Account
Account
*
Account
Group
*
Group
Group
*
Group
Subscription
*
UserSubscription
Subscription
*
UserSubscription
...
...
backend/internal/service/usage_service.go
View file @
7dddd065
...
@@ -2,9 +2,11 @@ package service
...
@@ -2,9 +2,11 @@ package service
import
(
import
(
"context"
"context"
"errors"
"fmt"
"fmt"
"time"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
...
@@ -17,7 +19,7 @@ var (
...
@@ -17,7 +19,7 @@ var (
// CreateUsageLogRequest 创建使用日志请求
// CreateUsageLogRequest 创建使用日志请求
type
CreateUsageLogRequest
struct
{
type
CreateUsageLogRequest
struct
{
UserID
int64
`json:"user_id"`
UserID
int64
`json:"user_id"`
A
pi
KeyID
int64
`json:"api_key_id"`
A
PI
KeyID
int64
`json:"api_key_id"`
AccountID
int64
`json:"account_id"`
AccountID
int64
`json:"account_id"`
RequestID
string
`json:"request_id"`
RequestID
string
`json:"request_id"`
Model
string
`json:"model"`
Model
string
`json:"model"`
...
@@ -54,20 +56,34 @@ type UsageStats struct {
...
@@ -54,20 +56,34 @@ type UsageStats struct {
type
UsageService
struct
{
type
UsageService
struct
{
usageRepo
UsageLogRepository
usageRepo
UsageLogRepository
userRepo
UserRepository
userRepo
UserRepository
entClient
*
dbent
.
Client
}
}
// NewUsageService 创建使用统计服务实例
// NewUsageService 创建使用统计服务实例
func
NewUsageService
(
usageRepo
UsageLogRepository
,
userRepo
UserRepository
)
*
UsageService
{
func
NewUsageService
(
usageRepo
UsageLogRepository
,
userRepo
UserRepository
,
entClient
*
dbent
.
Client
)
*
UsageService
{
return
&
UsageService
{
return
&
UsageService
{
usageRepo
:
usageRepo
,
usageRepo
:
usageRepo
,
userRepo
:
userRepo
,
userRepo
:
userRepo
,
entClient
:
entClient
,
}
}
}
}
// Create 创建使用日志
// Create 创建使用日志
func
(
s
*
UsageService
)
Create
(
ctx
context
.
Context
,
req
CreateUsageLogRequest
)
(
*
UsageLog
,
error
)
{
func
(
s
*
UsageService
)
Create
(
ctx
context
.
Context
,
req
CreateUsageLogRequest
)
(
*
UsageLog
,
error
)
{
// 使用数据库事务保证「使用日志插入」与「扣费」的原子性,避免重复扣费或漏扣风险。
tx
,
err
:=
s
.
entClient
.
Tx
(
ctx
)
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
dbent
.
ErrTxStarted
)
{
return
nil
,
fmt
.
Errorf
(
"begin transaction: %w"
,
err
)
}
txCtx
:=
ctx
if
err
==
nil
{
defer
func
()
{
_
=
tx
.
Rollback
()
}()
txCtx
=
dbent
.
NewTxContext
(
ctx
,
tx
)
}
// 验证用户存在
// 验证用户存在
_
,
err
:
=
s
.
userRepo
.
GetByID
(
c
tx
,
req
.
UserID
)
_
,
err
=
s
.
userRepo
.
GetByID
(
txC
tx
,
req
.
UserID
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
}
...
@@ -75,7 +91,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
...
@@ -75,7 +91,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
// 创建使用日志
// 创建使用日志
usageLog
:=
&
UsageLog
{
usageLog
:=
&
UsageLog
{
UserID
:
req
.
UserID
,
UserID
:
req
.
UserID
,
A
pi
KeyID
:
req
.
A
pi
KeyID
,
A
PI
KeyID
:
req
.
A
PI
KeyID
,
AccountID
:
req
.
AccountID
,
AccountID
:
req
.
AccountID
,
RequestID
:
req
.
RequestID
,
RequestID
:
req
.
RequestID
,
Model
:
req
.
Model
,
Model
:
req
.
Model
,
...
@@ -96,17 +112,24 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
...
@@ -96,17 +112,24 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
DurationMs
:
req
.
DurationMs
,
DurationMs
:
req
.
DurationMs
,
}
}
if
err
:=
s
.
usageRepo
.
Create
(
ctx
,
usageLog
);
err
!=
nil
{
inserted
,
err
:=
s
.
usageRepo
.
Create
(
txCtx
,
usageLog
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create usage log: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"create usage log: %w"
,
err
)
}
}
// 扣除用户余额
// 扣除用户余额
if
req
.
ActualCost
>
0
{
if
inserted
&&
req
.
ActualCost
>
0
{
if
err
:=
s
.
userRepo
.
UpdateBalance
(
c
tx
,
req
.
UserID
,
-
req
.
ActualCost
);
err
!=
nil
{
if
err
:=
s
.
userRepo
.
UpdateBalance
(
txC
tx
,
req
.
UserID
,
-
req
.
ActualCost
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update user balance: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"update user balance: %w"
,
err
)
}
}
}
}
if
tx
!=
nil
{
if
err
:=
tx
.
Commit
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"commit transaction: %w"
,
err
)
}
}
return
usageLog
,
nil
return
usageLog
,
nil
}
}
...
@@ -128,9 +151,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
...
@@ -128,9 +151,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
return
logs
,
pagination
,
nil
return
logs
,
pagination
,
nil
}
}
// ListByA
pi
Key 获取API Key的使用日志列表
// ListByA
PI
Key 获取API Key的使用日志列表
func
(
s
*
UsageService
)
ListByA
pi
Key
(
ctx
context
.
Context
,
apiKeyID
int64
,
params
pagination
.
PaginationParams
)
([]
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
UsageService
)
ListByA
PI
Key
(
ctx
context
.
Context
,
apiKeyID
int64
,
params
pagination
.
PaginationParams
)
([]
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
logs
,
pagination
,
err
:=
s
.
usageRepo
.
ListByA
pi
Key
(
ctx
,
apiKeyID
,
params
)
logs
,
pagination
,
err
:=
s
.
usageRepo
.
ListByA
PI
Key
(
ctx
,
apiKeyID
,
params
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list usage logs: %w"
,
err
)
return
nil
,
nil
,
fmt
.
Errorf
(
"list usage logs: %w"
,
err
)
}
}
...
@@ -165,9 +188,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi
...
@@ -165,9 +188,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi
},
nil
},
nil
}
}
// GetStatsByA
pi
Key 获取API Key的使用统计
// GetStatsByA
PI
Key 获取API Key的使用统计
func
(
s
*
UsageService
)
GetStatsByA
pi
Key
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
UsageStats
,
error
)
{
func
(
s
*
UsageService
)
GetStatsByA
PI
Key
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
UsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetA
pi
KeyStatsAggregated
(
ctx
,
apiKeyID
,
startTime
,
endTime
)
stats
,
err
:=
s
.
usageRepo
.
GetA
PI
KeyStatsAggregated
(
ctx
,
apiKeyID
,
startTime
,
endTime
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get api key stats: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get api key stats: %w"
,
err
)
}
}
...
@@ -270,9 +293,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
...
@@ -270,9 +293,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
return
stats
,
nil
return
stats
,
nil
}
}
// GetBatchA
pi
KeyUsageStats returns today/total actual_cost for given api keys.
// GetBatchA
PI
KeyUsageStats returns today/total actual_cost for given api keys.
func
(
s
*
UsageService
)
GetBatchA
pi
KeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchA
pi
KeyUsageStats
,
error
)
{
func
(
s
*
UsageService
)
GetBatchA
PI
KeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchA
PI
KeyUsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetBatchA
pi
KeyUsageStats
(
ctx
,
apiKeyIDs
)
stats
,
err
:=
s
.
usageRepo
.
GetBatchA
PI
KeyUsageStats
(
ctx
,
apiKeyIDs
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get batch api key usage stats: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get batch api key usage stats: %w"
,
err
)
}
}
...
...
backend/internal/service/user.go
View file @
7dddd065
...
@@ -21,7 +21,7 @@ type User struct {
...
@@ -21,7 +21,7 @@ type User struct {
CreatedAt
time
.
Time
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
UpdatedAt
time
.
Time
A
pi
Keys
[]
A
pi
Key
A
PI
Keys
[]
A
PI
Key
Subscriptions
[]
UserSubscription
Subscriptions
[]
UserSubscription
}
}
...
...
backend/internal/service/user_attribute_service.go
View file @
7dddd065
...
@@ -56,6 +56,10 @@ func (s *UserAttributeService) CreateDefinition(ctx context.Context, input Creat
...
@@ -56,6 +56,10 @@ func (s *UserAttributeService) CreateDefinition(ctx context.Context, input Creat
Enabled
:
input
.
Enabled
,
Enabled
:
input
.
Enabled
,
}
}
if
err
:=
validateDefinitionPattern
(
def
);
err
!=
nil
{
return
nil
,
err
}
if
err
:=
s
.
defRepo
.
Create
(
ctx
,
def
);
err
!=
nil
{
if
err
:=
s
.
defRepo
.
Create
(
ctx
,
def
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create definition: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"create definition: %w"
,
err
)
}
}
...
@@ -108,6 +112,10 @@ func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, i
...
@@ -108,6 +112,10 @@ func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, i
def
.
Enabled
=
*
input
.
Enabled
def
.
Enabled
=
*
input
.
Enabled
}
}
if
err
:=
validateDefinitionPattern
(
def
);
err
!=
nil
{
return
nil
,
err
}
if
err
:=
s
.
defRepo
.
Update
(
ctx
,
def
);
err
!=
nil
{
if
err
:=
s
.
defRepo
.
Update
(
ctx
,
def
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update definition: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"update definition: %w"
,
err
)
}
}
...
@@ -231,7 +239,10 @@ func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value
...
@@ -231,7 +239,10 @@ func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value
// Pattern validation
// Pattern validation
if
v
.
Pattern
!=
nil
&&
*
v
.
Pattern
!=
""
&&
value
!=
""
{
if
v
.
Pattern
!=
nil
&&
*
v
.
Pattern
!=
""
&&
value
!=
""
{
re
,
err
:=
regexp
.
Compile
(
*
v
.
Pattern
)
re
,
err
:=
regexp
.
Compile
(
*
v
.
Pattern
)
if
err
==
nil
&&
!
re
.
MatchString
(
value
)
{
if
err
!=
nil
{
return
validationError
(
def
.
Name
+
" has an invalid pattern"
)
}
if
!
re
.
MatchString
(
value
)
{
msg
:=
def
.
Name
+
" format is invalid"
msg
:=
def
.
Name
+
" format is invalid"
if
v
.
Message
!=
nil
&&
*
v
.
Message
!=
""
{
if
v
.
Message
!=
nil
&&
*
v
.
Message
!=
""
{
msg
=
*
v
.
Message
msg
=
*
v
.
Message
...
@@ -293,3 +304,20 @@ func isValidAttributeType(t UserAttributeType) bool {
...
@@ -293,3 +304,20 @@ func isValidAttributeType(t UserAttributeType) bool {
}
}
return
false
return
false
}
}
func
validateDefinitionPattern
(
def
*
UserAttributeDefinition
)
error
{
if
def
==
nil
{
return
nil
}
if
def
.
Validation
.
Pattern
==
nil
{
return
nil
}
pattern
:=
strings
.
TrimSpace
(
*
def
.
Validation
.
Pattern
)
if
pattern
==
""
{
return
nil
}
if
_
,
err
:=
regexp
.
Compile
(
pattern
);
err
!=
nil
{
return
infraerrors
.
BadRequest
(
"INVALID_ATTRIBUTE_PATTERN"
,
fmt
.
Sprintf
(
"invalid pattern for %s: %v"
,
def
.
Name
,
err
))
}
return
nil
}
backend/internal/service/wire.go
View file @
7dddd065
...
@@ -75,7 +75,7 @@ var ProviderSet = wire.NewSet(
...
@@ -75,7 +75,7 @@ var ProviderSet = wire.NewSet(
// Core services
// Core services
NewAuthService
,
NewAuthService
,
NewUserService
,
NewUserService
,
NewA
pi
KeyService
,
NewA
PI
KeyService
,
NewGroupService
,
NewGroupService
,
NewAccountService
,
NewAccountService
,
NewProxyService
,
NewProxyService
,
...
...
backend/internal/setup/cli.go
View file @
7dddd065
// Package setup provides CLI commands and application initialization helpers.
package
setup
package
setup
import
(
import
(
...
...
backend/internal/setup/setup.go
View file @
7dddd065
...
@@ -352,7 +352,7 @@ func writeConfigFile(cfg *SetupConfig) error {
...
@@ -352,7 +352,7 @@ func writeConfigFile(cfg *SetupConfig) error {
Default
struct
{
Default
struct
{
UserConcurrency
int
`yaml:"user_concurrency"`
UserConcurrency
int
`yaml:"user_concurrency"`
UserBalance
float64
`yaml:"user_balance"`
UserBalance
float64
`yaml:"user_balance"`
A
pi
KeyPrefix
string
`yaml:"api_key_prefix"`
A
PI
KeyPrefix
string
`yaml:"api_key_prefix"`
RateMultiplier
float64
`yaml:"rate_multiplier"`
RateMultiplier
float64
`yaml:"rate_multiplier"`
}
`yaml:"default"`
}
`yaml:"default"`
RateLimit
struct
{
RateLimit
struct
{
...
@@ -374,12 +374,12 @@ func writeConfigFile(cfg *SetupConfig) error {
...
@@ -374,12 +374,12 @@ func writeConfigFile(cfg *SetupConfig) error {
Default
:
struct
{
Default
:
struct
{
UserConcurrency
int
`yaml:"user_concurrency"`
UserConcurrency
int
`yaml:"user_concurrency"`
UserBalance
float64
`yaml:"user_balance"`
UserBalance
float64
`yaml:"user_balance"`
A
pi
KeyPrefix
string
`yaml:"api_key_prefix"`
A
PI
KeyPrefix
string
`yaml:"api_key_prefix"`
RateMultiplier
float64
`yaml:"rate_multiplier"`
RateMultiplier
float64
`yaml:"rate_multiplier"`
}{
}{
UserConcurrency
:
5
,
UserConcurrency
:
5
,
UserBalance
:
0
,
UserBalance
:
0
,
A
pi
KeyPrefix
:
"sk-"
,
A
PI
KeyPrefix
:
"sk-"
,
RateMultiplier
:
1.0
,
RateMultiplier
:
1.0
,
},
},
RateLimit
:
struct
{
RateLimit
:
struct
{
...
...
backend/internal/web/embed_off.go
View file @
7dddd065
//go:build !embed
//go:build !embed
// Package web provides embedded web assets for the application.
package
web
package
web
import
(
import
(
...
...
backend/migrations/020_add_temp_unschedulable.sql
0 → 100644
View file @
7dddd065
-- 020_add_temp_unschedulable.sql
-- 添加临时不可调度功能相关字段
-- 添加临时不可调度状态解除时间字段
ALTER
TABLE
accounts
ADD
COLUMN
IF
NOT
EXISTS
temp_unschedulable_until
timestamptz
;
-- 添加临时不可调度原因字段(用于排障和审计)
ALTER
TABLE
accounts
ADD
COLUMN
IF
NOT
EXISTS
temp_unschedulable_reason
text
;
-- 添加索引以优化调度查询性能
CREATE
INDEX
IF
NOT
EXISTS
idx_accounts_temp_unschedulable_until
ON
accounts
(
temp_unschedulable_until
)
WHERE
deleted_at
IS
NULL
;
-- 添加注释说明字段用途
COMMENT
ON
COLUMN
accounts
.
temp_unschedulable_until
IS
'临时不可调度状态解除时间,当触发临时不可调度规则时设置(基于错误码或错误描述关键词)'
;
COMMENT
ON
COLUMN
accounts
.
temp_unschedulable_reason
IS
'临时不可调度原因,记录触发临时不可调度的具体原因(用于排障和审计)'
;
backend/migrations/026_ops_metrics_aggregation_tables.sql
0 → 100644
View file @
7dddd065
-- Ops monitoring: pre-aggregation tables for dashboard queries
--
-- Problem:
-- The ops dashboard currently runs percentile_cont + GROUP BY queries over large raw tables
-- (usage_logs, ops_error_logs). These will get slower as data grows.
--
-- This migration adds schema-only aggregation tables that can be populated by a future background job.
-- No triggers/functions/jobs are created here (schema only).
-- ============================================
-- Hourly aggregates (per provider/platform)
-- ============================================
CREATE
TABLE
IF
NOT
EXISTS
ops_metrics_hourly
(
-- Start of the hour bucket (recommended: UTC).
bucket_start
TIMESTAMPTZ
NOT
NULL
,
-- Provider/platform label (e.g. anthropic/openai/gemini). Mirrors ops_* queries that GROUP BY platform.
platform
VARCHAR
(
50
)
NOT
NULL
,
-- Traffic counts (use these to compute rates reliably across ranges).
request_count
BIGINT
NOT
NULL
DEFAULT
0
,
success_count
BIGINT
NOT
NULL
DEFAULT
0
,
error_count
BIGINT
NOT
NULL
DEFAULT
0
,
-- Error breakdown used by provider health UI.
error_4xx_count
BIGINT
NOT
NULL
DEFAULT
0
,
error_5xx_count
BIGINT
NOT
NULL
DEFAULT
0
,
timeout_count
BIGINT
NOT
NULL
DEFAULT
0
,
-- Latency aggregates (ms).
avg_latency_ms
DOUBLE
PRECISION
,
p99_latency_ms
DOUBLE
PRECISION
,
-- Convenience rate (percentage, 0-100). Still keep counts as source of truth.
error_rate
DOUBLE
PRECISION
NOT
NULL
DEFAULT
0
,
-- When this row was last (re)computed by the background job.
computed_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
PRIMARY
KEY
(
bucket_start
,
platform
)
);
CREATE
INDEX
IF
NOT
EXISTS
idx_ops_metrics_hourly_platform_bucket_start
ON
ops_metrics_hourly
(
platform
,
bucket_start
DESC
);
COMMENT
ON
TABLE
ops_metrics_hourly
IS
'Pre-aggregated hourly ops metrics by provider/platform to speed up dashboard queries.'
;
COMMENT
ON
COLUMN
ops_metrics_hourly
.
bucket_start
IS
'Start timestamp of the hour bucket (recommended UTC).'
;
COMMENT
ON
COLUMN
ops_metrics_hourly
.
platform
IS
'Provider/platform label (anthropic/openai/gemini, etc).'
;
COMMENT
ON
COLUMN
ops_metrics_hourly
.
error_rate
IS
'Error rate percentage for the bucket (0-100). Counts remain the source of truth.'
;
COMMENT
ON
COLUMN
ops_metrics_hourly
.
computed_at
IS
'When the row was last computed/refreshed.'
;
-- ============================================
-- Daily aggregates (per provider/platform)
-- ============================================
CREATE
TABLE
IF
NOT
EXISTS
ops_metrics_daily
(
-- Day bucket (recommended: UTC date).
bucket_date
DATE
NOT
NULL
,
platform
VARCHAR
(
50
)
NOT
NULL
,
request_count
BIGINT
NOT
NULL
DEFAULT
0
,
success_count
BIGINT
NOT
NULL
DEFAULT
0
,
error_count
BIGINT
NOT
NULL
DEFAULT
0
,
error_4xx_count
BIGINT
NOT
NULL
DEFAULT
0
,
error_5xx_count
BIGINT
NOT
NULL
DEFAULT
0
,
timeout_count
BIGINT
NOT
NULL
DEFAULT
0
,
avg_latency_ms
DOUBLE
PRECISION
,
p99_latency_ms
DOUBLE
PRECISION
,
error_rate
DOUBLE
PRECISION
NOT
NULL
DEFAULT
0
,
computed_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
(),
PRIMARY
KEY
(
bucket_date
,
platform
)
);
CREATE
INDEX
IF
NOT
EXISTS
idx_ops_metrics_daily_platform_bucket_date
ON
ops_metrics_daily
(
platform
,
bucket_date
DESC
);
COMMENT
ON
TABLE
ops_metrics_daily
IS
'Pre-aggregated daily ops metrics by provider/platform for longer-term trends.'
;
COMMENT
ON
COLUMN
ops_metrics_daily
.
bucket_date
IS
'UTC date of the day bucket (recommended).'
;
-- ============================================
-- Population strategy (future background job)
-- ============================================
--
-- Suggested approach:
-- 1) Compute hourly buckets from raw logs using UTC time-bucketing, then UPSERT into ops_metrics_hourly.
-- 2) Compute daily buckets either directly from raw logs or by rolling up ops_metrics_hourly.
--
-- Notes:
-- - Ensure the job uses a consistent timezone (recommended: SET TIME ZONE ''UTC'') to avoid bucket drift.
-- - Derive the provider/platform similarly to existing dashboard queries:
-- usage_logs: COALESCE(NULLIF(groups.platform, ''), accounts.platform, '')
-- ops_error_logs: COALESCE(NULLIF(ops_error_logs.platform, ''), groups.platform, accounts.platform, '')
-- - Keep request_count/success_count/error_count as the authoritative values; compute error_rate from counts.
--
-- Example (hourly) shape (pseudo-SQL):
-- INSERT INTO ops_metrics_hourly (...)
-- SELECT date_trunc('hour', created_at) AS bucket_start, platform, ...
-- FROM (/* aggregate usage_logs + ops_error_logs */) s
-- ON CONFLICT (bucket_start, platform) DO UPDATE SET ...;
backend/migrations/027_usage_billing_consistency.sql
0 → 100644
View file @
7dddd065
-- 027_usage_billing_consistency.sql
-- Ensure usage_logs idempotency (request_id, api_key_id) and add reconciliation infrastructure.
-- -----------------------------------------------------------------------------
-- 1) Normalize legacy request_id values
-- -----------------------------------------------------------------------------
-- Historically request_id may be inserted as empty string. Convert it to NULL so
-- the upcoming unique index does not break on repeated "" values.
UPDATE
usage_logs
SET
request_id
=
NULL
WHERE
request_id
=
''
;
-- If duplicates already exist for the same (request_id, api_key_id), keep the
-- first row and NULL-out request_id for the rest so the unique index can be
-- created without deleting historical logs.
WITH
ranked
AS
(
SELECT
id
,
ROW_NUMBER
()
OVER
(
PARTITION
BY
api_key_id
,
request_id
ORDER
BY
id
)
AS
rn
FROM
usage_logs
WHERE
request_id
IS
NOT
NULL
)
UPDATE
usage_logs
ul
SET
request_id
=
NULL
FROM
ranked
r
WHERE
ul
.
id
=
r
.
id
AND
r
.
rn
>
1
;
-- -----------------------------------------------------------------------------
-- 2) Idempotency constraint for usage_logs
-- -----------------------------------------------------------------------------
CREATE
UNIQUE
INDEX
IF
NOT
EXISTS
idx_usage_logs_request_id_api_key_unique
ON
usage_logs
(
request_id
,
api_key_id
);
-- -----------------------------------------------------------------------------
-- 3) Reconciliation infrastructure: billing ledger for usage charges
-- -----------------------------------------------------------------------------
CREATE
TABLE
IF
NOT
EXISTS
billing_usage_entries
(
id
BIGSERIAL
PRIMARY
KEY
,
usage_log_id
BIGINT
NOT
NULL
REFERENCES
usage_logs
(
id
)
ON
DELETE
CASCADE
,
user_id
BIGINT
NOT
NULL
REFERENCES
users
(
id
)
ON
DELETE
CASCADE
,
api_key_id
BIGINT
NOT
NULL
REFERENCES
api_keys
(
id
)
ON
DELETE
CASCADE
,
subscription_id
BIGINT
REFERENCES
user_subscriptions
(
id
)
ON
DELETE
SET
NULL
,
billing_type
SMALLINT
NOT
NULL
,
applied
BOOLEAN
NOT
NULL
DEFAULT
TRUE
,
delta_usd
DECIMAL
(
20
,
10
)
NOT
NULL
DEFAULT
0
,
created_at
TIMESTAMPTZ
NOT
NULL
DEFAULT
NOW
()
);
CREATE
UNIQUE
INDEX
IF
NOT
EXISTS
billing_usage_entries_usage_log_id_unique
ON
billing_usage_entries
(
usage_log_id
);
CREATE
INDEX
IF
NOT
EXISTS
idx_billing_usage_entries_user_time
ON
billing_usage_entries
(
user_id
,
created_at
);
CREATE
INDEX
IF
NOT
EXISTS
idx_billing_usage_entries_created_at
ON
billing_usage_entries
(
created_at
);
Prev
1
…
4
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment