Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
7dddd065
Commit
7dddd065
authored
Jan 04, 2026
by
yangjianbo
Browse files
Merge branch 'main' of
https://github.com/mt21625457/aicodex2api
parents
25a0d49a
e78c8646
Changes
183
Show whitespace changes
Inline
Side-by-side
backend/internal/service/gemini_oauth_service_test.go
View file @
7dddd065
package
service
import
"testing"
import
(
"context"
"net/url"
"strings"
"testing"
func
TestInferGoogleOneTier
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
)
func
TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy
(
t
*
testing
.
T
)
{
t
.
Parallel
()
type
testCase
struct
{
name
string
storageBytes
int64
expectedTier
string
}{
{
"Negative storage"
,
-
1
,
TierGoogleOneUnknown
},
{
"Zero storage"
,
0
,
TierGoogleOneUnknown
},
// Free tier boundary (15GB)
{
"Below free tier"
,
10
*
GB
,
TierGoogleOneUnknown
},
{
"Just below free tier"
,
StorageTierFree
-
1
,
TierGoogleOneUnknown
},
{
"Free tier (15GB)"
,
StorageTierFree
,
TierFree
},
// Basic tier boundary (100GB)
{
"Between free and basic"
,
50
*
GB
,
TierFree
},
{
"Just below basic tier"
,
StorageTierBasic
-
1
,
TierFree
},
{
"Basic tier (100GB)"
,
StorageTierBasic
,
TierGoogleOneBasic
},
// Standard tier boundary (200GB)
{
"Between basic and standard"
,
150
*
GB
,
TierGoogleOneBasic
},
{
"Just below standard tier"
,
StorageTierStandard
-
1
,
TierGoogleOneBasic
},
{
"Standard tier (200GB)"
,
StorageTierStandard
,
TierGoogleOneStandard
},
// AI Premium tier boundary (2TB)
{
"Between standard and premium"
,
1
*
TB
,
TierGoogleOneStandard
},
{
"Just below AI Premium tier"
,
StorageTierAIPremium
-
1
,
TierGoogleOneStandard
},
{
"AI Premium tier (2TB)"
,
StorageTierAIPremium
,
TierAIPremium
},
// Unlimited tier boundary (> 100TB)
{
"Between premium and unlimited"
,
50
*
TB
,
TierAIPremium
},
{
"At unlimited threshold (100TB)"
,
StorageTierUnlimited
,
TierAIPremium
},
{
"Unlimited tier (100TB+)"
,
StorageTierUnlimited
+
1
,
TierGoogleOneUnlimited
},
{
"Unlimited tier (101TB+)"
,
101
*
TB
,
TierGoogleOneUnlimited
},
{
"Very large storage"
,
1000
*
TB
,
TierGoogleOneUnlimited
},
cfg
*
config
.
Config
oauthType
string
projectID
string
wantClientID
string
wantRedirect
string
wantScope
string
wantProjectID
string
wantErrSubstr
string
}
tests
:=
[]
testCase
{
{
name
:
"google_one uses built-in client when not configured and redirects to upstream"
,
cfg
:
&
config
.
Config
{
Gemini
:
config
.
GeminiConfig
{
OAuth
:
config
.
GeminiOAuthConfig
{},
},
},
oauthType
:
"google_one"
,
wantClientID
:
geminicli
.
GeminiCLIOAuthClientID
,
wantRedirect
:
geminicli
.
GeminiCLIRedirectURI
,
wantScope
:
geminicli
.
DefaultCodeAssistScopes
,
wantProjectID
:
""
,
},
{
name
:
"google_one uses custom client when configured and redirects to localhost"
,
cfg
:
&
config
.
Config
{
Gemini
:
config
.
GeminiConfig
{
OAuth
:
config
.
GeminiOAuthConfig
{
ClientID
:
"custom-client-id"
,
ClientSecret
:
"custom-client-secret"
,
},
},
},
oauthType
:
"google_one"
,
wantClientID
:
"custom-client-id"
,
wantRedirect
:
geminicli
.
AIStudioOAuthRedirectURI
,
wantScope
:
geminicli
.
DefaultGoogleOneScopes
,
wantProjectID
:
""
,
},
{
name
:
"code_assist always forces built-in client even when custom client configured"
,
cfg
:
&
config
.
Config
{
Gemini
:
config
.
GeminiConfig
{
OAuth
:
config
.
GeminiOAuthConfig
{
ClientID
:
"custom-client-id"
,
ClientSecret
:
"custom-client-secret"
,
},
},
},
oauthType
:
"code_assist"
,
projectID
:
"my-gcp-project"
,
wantClientID
:
geminicli
.
GeminiCLIOAuthClientID
,
wantRedirect
:
geminicli
.
GeminiCLIRedirectURI
,
wantScope
:
geminicli
.
DefaultCodeAssistScopes
,
wantProjectID
:
"my-gcp-project"
,
},
{
name
:
"ai_studio requires custom client"
,
cfg
:
&
config
.
Config
{
Gemini
:
config
.
GeminiConfig
{
OAuth
:
config
.
GeminiOAuthConfig
{},
},
},
oauthType
:
"ai_studio"
,
wantErrSubstr
:
"AI Studio OAuth requires a custom OAuth Client"
,
},
}
for
_
,
tt
:=
range
tests
{
tt
:=
tt
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
inferGoogleOneTier
(
tt
.
storageBytes
)
if
result
!=
tt
.
expectedTier
{
t
.
Errorf
(
"inferGoogleOneTier(%d) = %s, want %s"
,
tt
.
storageBytes
,
result
,
tt
.
expectedTier
)
t
.
Parallel
()
svc
:=
NewGeminiOAuthService
(
nil
,
nil
,
nil
,
tt
.
cfg
)
got
,
err
:=
svc
.
GenerateAuthURL
(
context
.
Background
(),
nil
,
"https://example.com/auth/callback"
,
tt
.
projectID
,
tt
.
oauthType
,
""
)
if
tt
.
wantErrSubstr
!=
""
{
if
err
==
nil
{
t
.
Fatalf
(
"expected error containing %q, got nil"
,
tt
.
wantErrSubstr
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
tt
.
wantErrSubstr
)
{
t
.
Fatalf
(
"expected error containing %q, got: %v"
,
tt
.
wantErrSubstr
,
err
)
}
return
}
if
err
!=
nil
{
t
.
Fatalf
(
"GenerateAuthURL returned error: %v"
,
err
)
}
parsed
,
err
:=
url
.
Parse
(
got
.
AuthURL
)
if
err
!=
nil
{
t
.
Fatalf
(
"failed to parse auth_url: %v"
,
err
)
}
q
:=
parsed
.
Query
()
if
gotState
:=
q
.
Get
(
"state"
);
gotState
!=
got
.
State
{
t
.
Fatalf
(
"state mismatch: query=%q result=%q"
,
gotState
,
got
.
State
)
}
if
gotClientID
:=
q
.
Get
(
"client_id"
);
gotClientID
!=
tt
.
wantClientID
{
t
.
Fatalf
(
"client_id mismatch: got=%q want=%q"
,
gotClientID
,
tt
.
wantClientID
)
}
if
gotRedirect
:=
q
.
Get
(
"redirect_uri"
);
gotRedirect
!=
tt
.
wantRedirect
{
t
.
Fatalf
(
"redirect_uri mismatch: got=%q want=%q"
,
gotRedirect
,
tt
.
wantRedirect
)
}
if
gotScope
:=
q
.
Get
(
"scope"
);
gotScope
!=
tt
.
wantScope
{
t
.
Fatalf
(
"scope mismatch: got=%q want=%q"
,
gotScope
,
tt
.
wantScope
)
}
if
gotProjectID
:=
q
.
Get
(
"project_id"
);
gotProjectID
!=
tt
.
wantProjectID
{
t
.
Fatalf
(
"project_id mismatch: got=%q want=%q"
,
gotProjectID
,
tt
.
wantProjectID
)
}
})
}
...
...
backend/internal/service/gemini_quota.go
View file @
7dddd065
...
...
@@ -20,13 +20,24 @@ const (
geminiModelFlash
geminiModelClass
=
"flash"
)
type
GeminiDailyQuota
struct
{
ProRPD
int64
FlashRPD
int64
type
GeminiQuota
struct
{
// SharedRPD is a shared requests-per-day pool across models.
// When SharedRPD > 0, callers should treat ProRPD/FlashRPD as not applicable for daily quota checks.
SharedRPD
int64
`json:"shared_rpd,omitempty"`
// SharedRPM is a shared requests-per-minute pool across models.
// When SharedRPM > 0, callers should treat ProRPM/FlashRPM as not applicable for minute quota checks.
SharedRPM
int64
`json:"shared_rpm,omitempty"`
// Per-model quotas (AI Studio / API key).
// A value of -1 means "unlimited" (pay-as-you-go).
ProRPD
int64
`json:"pro_rpd,omitempty"`
ProRPM
int64
`json:"pro_rpm,omitempty"`
FlashRPD
int64
`json:"flash_rpd,omitempty"`
FlashRPM
int64
`json:"flash_rpm,omitempty"`
}
type
GeminiTierPolicy
struct
{
Quota
Gemini
Daily
Quota
Quota
GeminiQuota
Cooldown
time
.
Duration
}
...
...
@@ -45,10 +56,27 @@ type GeminiUsageTotals struct {
const
geminiQuotaCacheTTL
=
time
.
Minute
type
geminiQuotaOverrides
struct
{
type
geminiQuotaOverrides
V1
struct
{
Tiers
map
[
string
]
config
.
GeminiTierQuotaConfig
`json:"tiers"`
}
type
geminiQuotaOverridesV2
struct
{
QuotaRules
map
[
string
]
geminiQuotaRuleOverride
`json:"quota_rules"`
}
type
geminiQuotaRuleOverride
struct
{
SharedRPD
*
int64
`json:"shared_rpd,omitempty"`
SharedRPM
*
int64
`json:"rpm,omitempty"`
GeminiPro
*
geminiModelQuotaOverride
`json:"gemini_pro,omitempty"`
GeminiFlash
*
geminiModelQuotaOverride
`json:"gemini_flash,omitempty"`
Desc
*
string
`json:"desc,omitempty"`
}
type
geminiModelQuotaOverride
struct
{
RPD
*
int64
`json:"rpd,omitempty"`
RPM
*
int64
`json:"rpm,omitempty"`
}
type
GeminiQuotaService
struct
{
cfg
*
config
.
Config
settingRepo
SettingRepository
...
...
@@ -82,11 +110,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if
s
.
cfg
!=
nil
{
policy
.
ApplyOverrides
(
s
.
cfg
.
Gemini
.
Quota
.
Tiers
)
if
strings
.
TrimSpace
(
s
.
cfg
.
Gemini
.
Quota
.
Policy
)
!=
""
{
var
overrides
geminiQuotaOverrides
if
err
:=
json
.
Unmarshal
([]
byte
(
s
.
cfg
.
Gemini
.
Quota
.
Policy
),
&
overrides
);
err
!=
nil
{
raw
:=
[]
byte
(
s
.
cfg
.
Gemini
.
Quota
.
Policy
)
var
overridesV2
geminiQuotaOverridesV2
if
err
:=
json
.
Unmarshal
(
raw
,
&
overridesV2
);
err
==
nil
&&
len
(
overridesV2
.
QuotaRules
)
>
0
{
policy
.
ApplyQuotaRulesOverrides
(
overridesV2
.
QuotaRules
)
}
else
{
var
overridesV1
geminiQuotaOverridesV1
if
err
:=
json
.
Unmarshal
(
raw
,
&
overridesV1
);
err
!=
nil
{
log
.
Printf
(
"gemini quota: parse config policy failed: %v"
,
err
)
}
else
{
policy
.
ApplyOverrides
(
overrides
.
Tiers
)
policy
.
ApplyOverrides
(
overridesV1
.
Tiers
)
}
}
}
}
...
...
@@ -96,11 +130,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
log
.
Printf
(
"gemini quota: load setting failed: %v"
,
err
)
}
else
if
strings
.
TrimSpace
(
value
)
!=
""
{
var
overrides
geminiQuotaOverrides
if
err
:=
json
.
Unmarshal
([]
byte
(
value
),
&
overrides
);
err
!=
nil
{
raw
:=
[]
byte
(
value
)
var
overridesV2
geminiQuotaOverridesV2
if
err
:=
json
.
Unmarshal
(
raw
,
&
overridesV2
);
err
==
nil
&&
len
(
overridesV2
.
QuotaRules
)
>
0
{
policy
.
ApplyQuotaRulesOverrides
(
overridesV2
.
QuotaRules
)
}
else
{
var
overridesV1
geminiQuotaOverridesV1
if
err
:=
json
.
Unmarshal
(
raw
,
&
overridesV1
);
err
!=
nil
{
log
.
Printf
(
"gemini quota: parse setting failed: %v"
,
err
)
}
else
{
policy
.
ApplyOverrides
(
overrides
.
Tiers
)
policy
.
ApplyOverrides
(
overridesV1
.
Tiers
)
}
}
}
}
...
...
@@ -113,12 +153,20 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
return
policy
}
func
(
s
*
GeminiQuotaService
)
QuotaForAccount
(
ctx
context
.
Context
,
account
*
Account
)
(
GeminiDailyQuota
,
bool
)
{
if
account
==
nil
||
!
account
.
IsGeminiCodeAssist
()
{
return
GeminiDailyQuota
{},
false
func
(
s
*
GeminiQuotaService
)
QuotaForAccount
(
ctx
context
.
Context
,
account
*
Account
)
(
GeminiQuota
,
bool
)
{
if
account
==
nil
||
account
.
Platform
!=
PlatformGemini
{
return
GeminiQuota
{},
false
}
// Map (oauth_type + tier_id) to a canonical policy tier key.
// This keeps the policy table stable even if upstream tier_id strings vary.
tierKey
:=
geminiQuotaTierKeyForAccount
(
account
)
if
tierKey
==
""
{
return
GeminiQuota
{},
false
}
policy
:=
s
.
Policy
(
ctx
)
return
policy
.
QuotaForTier
(
account
.
GeminiTierID
()
)
return
policy
.
QuotaForTier
(
tierKey
)
}
func
(
s
*
GeminiQuotaService
)
CooldownForTier
(
ctx
context
.
Context
,
tierID
string
)
time
.
Duration
{
...
...
@@ -126,12 +174,36 @@ func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string)
return
policy
.
CooldownForTier
(
tierID
)
}
func
(
s
*
GeminiQuotaService
)
CooldownForAccount
(
ctx
context
.
Context
,
account
*
Account
)
time
.
Duration
{
if
s
==
nil
||
account
==
nil
||
account
.
Platform
!=
PlatformGemini
{
return
5
*
time
.
Minute
}
tierKey
:=
geminiQuotaTierKeyForAccount
(
account
)
if
strings
.
TrimSpace
(
tierKey
)
==
""
{
return
5
*
time
.
Minute
}
return
s
.
CooldownForTier
(
ctx
,
tierKey
)
}
func
newGeminiQuotaPolicy
()
*
GeminiQuotaPolicy
{
return
&
GeminiQuotaPolicy
{
tiers
:
map
[
string
]
GeminiTierPolicy
{
"LEGACY"
:
{
Quota
:
GeminiDailyQuota
{
ProRPD
:
50
,
FlashRPD
:
1500
},
Cooldown
:
30
*
time
.
Minute
},
"PRO"
:
{
Quota
:
GeminiDailyQuota
{
ProRPD
:
1500
,
FlashRPD
:
4000
},
Cooldown
:
5
*
time
.
Minute
},
"ULTRA"
:
{
Quota
:
GeminiDailyQuota
{
ProRPD
:
2000
,
FlashRPD
:
0
},
Cooldown
:
5
*
time
.
Minute
},
// --- AI Studio / API Key (per-model) ---
// aistudio_free:
// - gemini_pro: 50 RPD / 2 RPM
// - gemini_flash: 1500 RPD / 15 RPM
GeminiTierAIStudioFree
:
{
Quota
:
GeminiQuota
{
ProRPD
:
50
,
ProRPM
:
2
,
FlashRPD
:
1500
,
FlashRPM
:
15
},
Cooldown
:
30
*
time
.
Minute
},
// aistudio_paid: -1 means "unlimited/pay-as-you-go" for RPD.
GeminiTierAIStudioPaid
:
{
Quota
:
GeminiQuota
{
ProRPD
:
-
1
,
ProRPM
:
1000
,
FlashRPD
:
-
1
,
FlashRPM
:
2000
},
Cooldown
:
5
*
time
.
Minute
},
// --- Google One (shared pool) ---
GeminiTierGoogleOneFree
:
{
Quota
:
GeminiQuota
{
SharedRPD
:
1000
,
SharedRPM
:
60
},
Cooldown
:
30
*
time
.
Minute
},
GeminiTierGoogleAIPro
:
{
Quota
:
GeminiQuota
{
SharedRPD
:
1500
,
SharedRPM
:
120
},
Cooldown
:
5
*
time
.
Minute
},
GeminiTierGoogleAIUltra
:
{
Quota
:
GeminiQuota
{
SharedRPD
:
2000
,
SharedRPM
:
120
},
Cooldown
:
5
*
time
.
Minute
},
// --- GCP Code Assist (shared pool) ---
GeminiTierGCPStandard
:
{
Quota
:
GeminiQuota
{
SharedRPD
:
1500
,
SharedRPM
:
120
},
Cooldown
:
5
*
time
.
Minute
},
GeminiTierGCPEnterprise
:
{
Quota
:
GeminiQuota
{
SharedRPD
:
2000
,
SharedRPM
:
120
},
Cooldown
:
5
*
time
.
Minute
},
},
}
}
...
...
@@ -149,11 +221,22 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
if
!
ok
{
policy
=
GeminiTierPolicy
{
Cooldown
:
5
*
time
.
Minute
}
}
// Backward-compatible overrides:
// - If the tier uses shared quota, interpret pro_rpd as shared_rpd.
// - Otherwise apply per-model overrides.
if
override
.
ProRPD
!=
nil
{
policy
.
Quota
.
ProRPD
=
clampGeminiQuotaInt64
(
*
override
.
ProRPD
)
if
policy
.
Quota
.
SharedRPD
>
0
{
policy
.
Quota
.
SharedRPD
=
clampGeminiQuotaInt64WithUnlimited
(
*
override
.
ProRPD
)
}
else
{
policy
.
Quota
.
ProRPD
=
clampGeminiQuotaInt64WithUnlimited
(
*
override
.
ProRPD
)
}
}
if
override
.
FlashRPD
!=
nil
{
policy
.
Quota
.
FlashRPD
=
clampGeminiQuotaInt64
(
*
override
.
FlashRPD
)
if
policy
.
Quota
.
SharedRPD
>
0
{
// No separate flash RPD for shared tiers.
}
else
{
policy
.
Quota
.
FlashRPD
=
clampGeminiQuotaInt64WithUnlimited
(
*
override
.
FlashRPD
)
}
}
if
override
.
CooldownMinutes
!=
nil
{
minutes
:=
clampGeminiQuotaInt
(
*
override
.
CooldownMinutes
)
...
...
@@ -163,10 +246,51 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
}
}
func
(
p
*
GeminiQuotaPolicy
)
QuotaForTier
(
tierID
string
)
(
GeminiDailyQuota
,
bool
)
{
func
(
p
*
GeminiQuotaPolicy
)
ApplyQuotaRulesOverrides
(
rules
map
[
string
]
geminiQuotaRuleOverride
)
{
if
p
==
nil
||
len
(
rules
)
==
0
{
return
}
for
rawID
,
override
:=
range
rules
{
tierID
:=
normalizeGeminiTierID
(
rawID
)
if
tierID
==
""
{
continue
}
policy
,
ok
:=
p
.
tiers
[
tierID
]
if
!
ok
{
policy
=
GeminiTierPolicy
{
Cooldown
:
5
*
time
.
Minute
}
}
if
override
.
SharedRPD
!=
nil
{
policy
.
Quota
.
SharedRPD
=
clampGeminiQuotaInt64WithUnlimited
(
*
override
.
SharedRPD
)
}
if
override
.
SharedRPM
!=
nil
{
policy
.
Quota
.
SharedRPM
=
clampGeminiQuotaRPM
(
*
override
.
SharedRPM
)
}
if
override
.
GeminiPro
!=
nil
{
if
override
.
GeminiPro
.
RPD
!=
nil
{
policy
.
Quota
.
ProRPD
=
clampGeminiQuotaInt64WithUnlimited
(
*
override
.
GeminiPro
.
RPD
)
}
if
override
.
GeminiPro
.
RPM
!=
nil
{
policy
.
Quota
.
ProRPM
=
clampGeminiQuotaRPM
(
*
override
.
GeminiPro
.
RPM
)
}
}
if
override
.
GeminiFlash
!=
nil
{
if
override
.
GeminiFlash
.
RPD
!=
nil
{
policy
.
Quota
.
FlashRPD
=
clampGeminiQuotaInt64WithUnlimited
(
*
override
.
GeminiFlash
.
RPD
)
}
if
override
.
GeminiFlash
.
RPM
!=
nil
{
policy
.
Quota
.
FlashRPM
=
clampGeminiQuotaRPM
(
*
override
.
GeminiFlash
.
RPM
)
}
}
p
.
tiers
[
tierID
]
=
policy
}
}
func
(
p
*
GeminiQuotaPolicy
)
QuotaForTier
(
tierID
string
)
(
GeminiQuota
,
bool
)
{
policy
,
ok
:=
p
.
policyForTier
(
tierID
)
if
!
ok
{
return
Gemini
Daily
Quota
{},
false
return
GeminiQuota
{},
false
}
return
policy
.
Quota
,
true
}
...
...
@@ -184,22 +308,43 @@ func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool
return
GeminiTierPolicy
{},
false
}
normalized
:=
normalizeGeminiTierID
(
tierID
)
if
normalized
==
""
{
normalized
=
"LEGACY"
}
if
policy
,
ok
:=
p
.
tiers
[
normalized
];
ok
{
return
policy
,
true
}
policy
,
ok
:=
p
.
tiers
[
"LEGACY"
]
return
policy
,
ok
return
GeminiTierPolicy
{},
false
}
func
normalizeGeminiTierID
(
tierID
string
)
string
{
return
strings
.
ToUpper
(
strings
.
TrimSpace
(
tierID
))
tierID
=
strings
.
TrimSpace
(
tierID
)
if
tierID
==
""
{
return
""
}
// Prefer canonical mapping (handles legacy tier strings).
if
canonical
:=
canonicalGeminiTierID
(
tierID
);
canonical
!=
""
{
return
canonical
}
// Accept older policy keys that used uppercase names.
switch
strings
.
ToUpper
(
tierID
)
{
case
"AISTUDIO_FREE"
:
return
GeminiTierAIStudioFree
case
"AISTUDIO_PAID"
:
return
GeminiTierAIStudioPaid
case
"GOOGLE_ONE_FREE"
:
return
GeminiTierGoogleOneFree
case
"GOOGLE_AI_PRO"
:
return
GeminiTierGoogleAIPro
case
"GOOGLE_AI_ULTRA"
:
return
GeminiTierGoogleAIUltra
case
"GCP_STANDARD"
:
return
GeminiTierGCPStandard
case
"GCP_ENTERPRISE"
:
return
GeminiTierGCPEnterprise
}
return
strings
.
ToLower
(
tierID
)
}
func
clampGeminiQuotaInt64
(
value
int64
)
int64
{
if
value
<
0
{
func
clampGeminiQuotaInt64
WithUnlimited
(
value
int64
)
int64
{
if
value
<
-
1
{
return
0
}
return
value
...
...
@@ -212,11 +357,46 @@ func clampGeminiQuotaInt(value int) int {
return
value
}
func
clampGeminiQuotaRPM
(
value
int64
)
int64
{
if
value
<
0
{
return
0
}
return
value
}
func
geminiCooldownForTier
(
tierID
string
)
time
.
Duration
{
policy
:=
newGeminiQuotaPolicy
()
return
policy
.
CooldownForTier
(
tierID
)
}
func
geminiQuotaTierKeyForAccount
(
account
*
Account
)
string
{
if
account
==
nil
||
account
.
Platform
!=
PlatformGemini
{
return
""
}
// Note: GeminiOAuthType() already defaults legacy (project_id present) to code_assist.
oauthType
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
account
.
GeminiOAuthType
()))
rawTier
:=
strings
.
TrimSpace
(
account
.
GeminiTierID
())
// Prefer the canonical tier stored in credentials.
if
tierID
:=
canonicalGeminiTierIDForOAuthType
(
oauthType
,
rawTier
);
tierID
!=
""
&&
tierID
!=
GeminiTierGoogleOneUnknown
{
return
tierID
}
// Fallback defaults when tier_id is missing or unknown.
switch
oauthType
{
case
"google_one"
:
return
GeminiTierGoogleOneFree
case
"code_assist"
:
return
GeminiTierGCPStandard
case
"ai_studio"
:
return
GeminiTierAIStudioFree
default
:
// API Key accounts (type=apikey) have empty oauth_type and are treated as AI Studio.
return
GeminiTierAIStudioFree
}
}
func
geminiModelClassFromName
(
model
string
)
geminiModelClass
{
name
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
model
))
if
strings
.
Contains
(
name
,
"flash"
)
||
strings
.
Contains
(
name
,
"lite"
)
{
...
...
backend/internal/service/openai_gateway_service.go
View file @
7dddd065
...
...
@@ -490,7 +490,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
return
""
,
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
return
accessToken
,
"oauth"
,
nil
case
AccountTypeA
pi
Key
:
case
AccountTypeA
PI
Key
:
apiKey
:=
account
.
GetOpenAIApiKey
()
if
apiKey
==
""
{
return
""
,
""
,
errors
.
New
(
"api_key not found in credentials"
)
...
...
@@ -630,7 +630,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case
AccountTypeOAuth
:
// OAuth accounts use ChatGPT internal API
targetURL
=
chatgptCodexURL
case
AccountTypeA
pi
Key
:
case
AccountTypeA
PI
Key
:
// API Key accounts use Platform API or custom base URL
baseURL
:=
account
.
GetOpenAIBaseURL
()
if
baseURL
==
""
{
...
...
@@ -710,7 +710,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
}
// Handle upstream error (mark account status)
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
body
)
shouldDisable
:=
false
if
s
.
rateLimitService
!=
nil
{
shouldDisable
=
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
body
)
}
if
shouldDisable
{
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
// Return appropriate error response
var
errType
,
errMsg
string
...
...
@@ -1065,7 +1071,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type
OpenAIRecordUsageInput
struct
{
Result
*
OpenAIForwardResult
A
pi
Key
*
A
pi
Key
A
PI
Key
*
A
PI
Key
User
*
User
Account
*
Account
Subscription
*
UserSubscription
...
...
@@ -1074,7 +1080,7 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
func
(
s
*
OpenAIGatewayService
)
RecordUsage
(
ctx
context
.
Context
,
input
*
OpenAIRecordUsageInput
)
error
{
result
:=
input
.
Result
apiKey
:=
input
.
A
pi
Key
apiKey
:=
input
.
A
PI
Key
user
:=
input
.
User
account
:=
input
.
Account
subscription
:=
input
.
Subscription
...
...
@@ -1116,7 +1122,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs
:=
int
(
result
.
Duration
.
Milliseconds
())
usageLog
:=
&
UsageLog
{
UserID
:
user
.
ID
,
A
pi
KeyID
:
apiKey
.
ID
,
A
PI
KeyID
:
apiKey
.
ID
,
AccountID
:
account
.
ID
,
RequestID
:
result
.
RequestID
,
Model
:
result
.
Model
,
...
...
@@ -1145,22 +1151,23 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog
.
SubscriptionID
=
&
subscription
.
ID
}
_
=
s
.
usageLogRepo
.
Create
(
ctx
,
usageLog
)
inserted
,
err
:=
s
.
usageLogRepo
.
Create
(
ctx
,
usageLog
)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
log
.
Printf
(
"[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d"
,
usageLog
.
UserID
,
usageLog
.
TotalTokens
())
s
.
deferredService
.
ScheduleLastUsedUpdate
(
account
.
ID
)
return
nil
}
shouldBill
:=
inserted
||
err
!=
nil
// Deduct based on billing type
if
isSubscriptionBilling
{
if
cost
.
TotalCost
>
0
{
if
shouldBill
&&
cost
.
TotalCost
>
0
{
_
=
s
.
userSubRepo
.
IncrementUsage
(
ctx
,
subscription
.
ID
,
cost
.
TotalCost
)
s
.
billingCacheService
.
QueueUpdateSubscriptionUsage
(
user
.
ID
,
*
apiKey
.
GroupID
,
cost
.
TotalCost
)
}
}
else
{
if
cost
.
ActualCost
>
0
{
if
shouldBill
&&
cost
.
ActualCost
>
0
{
_
=
s
.
userRepo
.
DeductBalance
(
ctx
,
user
.
ID
,
cost
.
ActualCost
)
s
.
billingCacheService
.
QueueDeductBalance
(
user
.
ID
,
cost
.
ActualCost
)
}
...
...
backend/internal/service/ratelimit_service.go
View file @
7dddd065
...
...
@@ -2,6 +2,7 @@ package service
import
(
"context"
"encoding/json"
"log"
"net/http"
"strconv"
...
...
@@ -18,6 +19,7 @@ type RateLimitService struct {
usageRepo
UsageLogRepository
cfg
*
config
.
Config
geminiQuotaService
*
GeminiQuotaService
tempUnschedCache
TempUnschedCache
usageCacheMu
sync
.
RWMutex
usageCache
map
[
int64
]
*
geminiUsageCacheEntry
}
...
...
@@ -31,12 +33,13 @@ type geminiUsageCacheEntry struct {
const
geminiPrecheckCacheTTL
=
time
.
Minute
// NewRateLimitService 创建RateLimitService实例
func
NewRateLimitService
(
accountRepo
AccountRepository
,
usageRepo
UsageLogRepository
,
cfg
*
config
.
Config
,
geminiQuotaService
*
GeminiQuotaService
)
*
RateLimitService
{
func
NewRateLimitService
(
accountRepo
AccountRepository
,
usageRepo
UsageLogRepository
,
cfg
*
config
.
Config
,
geminiQuotaService
*
GeminiQuotaService
,
tempUnschedCache
TempUnschedCache
)
*
RateLimitService
{
return
&
RateLimitService
{
accountRepo
:
accountRepo
,
usageRepo
:
usageRepo
,
cfg
:
cfg
,
geminiQuotaService
:
geminiQuotaService
,
tempUnschedCache
:
tempUnschedCache
,
usageCache
:
make
(
map
[
int64
]
*
geminiUsageCacheEntry
),
}
}
...
...
@@ -51,38 +54,45 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return
false
}
tempMatched
:=
s
.
tryTempUnschedulable
(
ctx
,
account
,
statusCode
,
responseBody
)
switch
statusCode
{
case
401
:
// 认证失败:停止调度,记录错误
s
.
handleAuthError
(
ctx
,
account
,
"Authentication failed (401): invalid or expired credentials"
)
return
true
shouldDisable
=
true
case
402
:
// 支付要求:余额不足或计费问题,停止调度
s
.
handleAuthError
(
ctx
,
account
,
"Payment required (402): insufficient balance or billing issue"
)
return
true
shouldDisable
=
true
case
403
:
// 禁止访问:停止调度,记录错误
s
.
handleAuthError
(
ctx
,
account
,
"Access forbidden (403): account may be suspended or lack permissions"
)
return
true
shouldDisable
=
true
case
429
:
s
.
handle429
(
ctx
,
account
,
headers
)
return
false
shouldDisable
=
false
case
529
:
s
.
handle529
(
ctx
,
account
)
return
false
shouldDisable
=
false
default
:
// 其他5xx错误:记录但不停止调度
if
statusCode
>=
500
{
log
.
Printf
(
"Account %d received upstream error %d"
,
account
.
ID
,
statusCode
)
}
return
false
shouldDisable
=
false
}
if
tempMatched
{
return
true
}
return
shouldDisable
}
// PreCheckUsage proactively checks local quota before dispatching a request.
// Returns false when the account should be skipped.
func
(
s
*
RateLimitService
)
PreCheckUsage
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
string
)
(
bool
,
error
)
{
if
account
==
nil
||
!
account
.
IsGeminiCodeAssist
()
||
strings
.
TrimSpace
(
requestedModel
)
==
""
{
if
account
==
nil
||
account
.
Platform
!=
PlatformGemini
{
return
true
,
nil
}
if
s
.
usageRepo
==
nil
||
s
.
geminiQuotaService
==
nil
{
...
...
@@ -94,18 +104,24 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
return
true
,
nil
}
now
:=
time
.
Now
()
modelClass
:=
geminiModelClassFromName
(
requestedModel
)
// 1) Daily quota precheck (RPD; resets at PST midnight)
{
var
limit
int64
switch
geminiModelClassFromName
(
requestedModel
)
{
if
quota
.
SharedRPD
>
0
{
limit
=
quota
.
SharedRPD
}
else
{
switch
modelClass
{
case
geminiModelFlash
:
limit
=
quota
.
FlashRPD
default
:
limit
=
quota
.
ProRPD
}
if
limit
<=
0
{
return
true
,
nil
}
now
:=
time
.
Now
()
if
limit
>
0
{
start
:=
geminiDailyWindowStart
(
now
)
totals
,
ok
:=
s
.
getGeminiUsageTotals
(
account
.
ID
,
start
,
now
)
if
!
ok
{
...
...
@@ -118,21 +134,70 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
}
var
used
int64
switch
geminiModelClassFromName
(
requestedModel
)
{
if
quota
.
SharedRPD
>
0
{
used
=
totals
.
ProRequests
+
totals
.
FlashRequests
}
else
{
switch
modelClass
{
case
geminiModelFlash
:
used
=
totals
.
FlashRequests
default
:
used
=
totals
.
ProRequests
}
}
if
used
>=
limit
{
resetAt
:=
geminiDailyResetTime
(
now
)
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
log
.
Printf
(
"SetRateLimited failed for account %d: %v"
,
account
.
ID
,
err
)
// NOTE:
// - This is a local precheck to reduce upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
log
.
Printf
(
"[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v"
,
account
.
ID
,
used
,
limit
,
resetAt
)
return
false
,
nil
}
}
}
// 2) Minute quota precheck (RPM; fixed window current minute)
{
var
limit
int64
if
quota
.
SharedRPM
>
0
{
limit
=
quota
.
SharedRPM
}
else
{
switch
modelClass
{
case
geminiModelFlash
:
limit
=
quota
.
FlashRPM
default
:
limit
=
quota
.
ProRPM
}
}
log
.
Printf
(
"[Gemini PreCheck] Account %d reached daily quota (%d/%d), rate limited until %v"
,
account
.
ID
,
used
,
limit
,
resetAt
)
if
limit
>
0
{
start
:=
now
.
Truncate
(
time
.
Minute
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
)
if
err
!=
nil
{
return
true
,
err
}
totals
:=
geminiAggregateUsage
(
stats
)
var
used
int64
if
quota
.
SharedRPM
>
0
{
used
=
totals
.
ProRequests
+
totals
.
FlashRequests
}
else
{
switch
modelClass
{
case
geminiModelFlash
:
used
=
totals
.
FlashRequests
default
:
used
=
totals
.
ProRequests
}
}
if
used
>=
limit
{
resetAt
:=
start
.
Add
(
time
.
Minute
)
// Do not persist "rate limited" status from local precheck. See note above.
log
.
Printf
(
"[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v"
,
account
.
ID
,
used
,
limit
,
resetAt
)
return
false
,
nil
}
}
}
return
true
,
nil
}
...
...
@@ -176,7 +241,10 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
if
account
==
nil
{
return
5
*
time
.
Minute
}
return
s
.
geminiQuotaService
.
CooldownForTier
(
ctx
,
account
.
GeminiTierID
())
if
s
.
geminiQuotaService
==
nil
{
return
5
*
time
.
Minute
}
return
s
.
geminiQuotaService
.
CooldownForAccount
(
ctx
,
account
)
}
// handleAuthError 处理认证类错误(401/403),停止账号调度
...
...
@@ -287,3 +355,183 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
func
(
s
*
RateLimitService
)
ClearRateLimit
(
ctx
context
.
Context
,
accountID
int64
)
error
{
return
s
.
accountRepo
.
ClearRateLimit
(
ctx
,
accountID
)
}
func
(
s
*
RateLimitService
)
ClearTempUnschedulable
(
ctx
context
.
Context
,
accountID
int64
)
error
{
if
err
:=
s
.
accountRepo
.
ClearTempUnschedulable
(
ctx
,
accountID
);
err
!=
nil
{
return
err
}
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
DeleteTempUnsched
(
ctx
,
accountID
);
err
!=
nil
{
log
.
Printf
(
"DeleteTempUnsched failed for account %d: %v"
,
accountID
,
err
)
}
}
return
nil
}
func
(
s
*
RateLimitService
)
GetTempUnschedStatus
(
ctx
context
.
Context
,
accountID
int64
)
(
*
TempUnschedState
,
error
)
{
now
:=
time
.
Now
()
.
Unix
()
if
s
.
tempUnschedCache
!=
nil
{
state
,
err
:=
s
.
tempUnschedCache
.
GetTempUnsched
(
ctx
,
accountID
)
if
err
!=
nil
{
return
nil
,
err
}
if
state
!=
nil
&&
state
.
UntilUnix
>
now
{
return
state
,
nil
}
}
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
if
err
!=
nil
{
return
nil
,
err
}
if
account
.
TempUnschedulableUntil
==
nil
{
return
nil
,
nil
}
if
account
.
TempUnschedulableUntil
.
Unix
()
<=
now
{
return
nil
,
nil
}
state
:=
&
TempUnschedState
{
UntilUnix
:
account
.
TempUnschedulableUntil
.
Unix
(),
}
if
account
.
TempUnschedulableReason
!=
""
{
var
parsed
TempUnschedState
if
err
:=
json
.
Unmarshal
([]
byte
(
account
.
TempUnschedulableReason
),
&
parsed
);
err
==
nil
{
if
parsed
.
UntilUnix
==
0
{
parsed
.
UntilUnix
=
state
.
UntilUnix
}
state
=
&
parsed
}
else
{
state
.
ErrorMessage
=
account
.
TempUnschedulableReason
}
}
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
accountID
,
state
);
err
!=
nil
{
log
.
Printf
(
"SetTempUnsched failed for account %d: %v"
,
accountID
,
err
)
}
}
return
state
,
nil
}
func
(
s
*
RateLimitService
)
HandleTempUnschedulable
(
ctx
context
.
Context
,
account
*
Account
,
statusCode
int
,
responseBody
[]
byte
)
bool
{
if
account
==
nil
{
return
false
}
if
!
account
.
ShouldHandleErrorCode
(
statusCode
)
{
return
false
}
return
s
.
tryTempUnschedulable
(
ctx
,
account
,
statusCode
,
responseBody
)
}
const
tempUnschedBodyMaxBytes
=
64
<<
10
const
tempUnschedMessageMaxBytes
=
2048
func
(
s
*
RateLimitService
)
tryTempUnschedulable
(
ctx
context
.
Context
,
account
*
Account
,
statusCode
int
,
responseBody
[]
byte
)
bool
{
if
account
==
nil
{
return
false
}
if
!
account
.
IsTempUnschedulableEnabled
()
{
return
false
}
rules
:=
account
.
GetTempUnschedulableRules
()
if
len
(
rules
)
==
0
{
return
false
}
if
statusCode
<=
0
||
len
(
responseBody
)
==
0
{
return
false
}
body
:=
responseBody
if
len
(
body
)
>
tempUnschedBodyMaxBytes
{
body
=
body
[
:
tempUnschedBodyMaxBytes
]
}
bodyLower
:=
strings
.
ToLower
(
string
(
body
))
for
idx
,
rule
:=
range
rules
{
if
rule
.
ErrorCode
!=
statusCode
||
len
(
rule
.
Keywords
)
==
0
{
continue
}
matchedKeyword
:=
matchTempUnschedKeyword
(
bodyLower
,
rule
.
Keywords
)
if
matchedKeyword
==
""
{
continue
}
if
s
.
triggerTempUnschedulable
(
ctx
,
account
,
rule
,
idx
,
statusCode
,
matchedKeyword
,
responseBody
)
{
return
true
}
}
return
false
}
func
matchTempUnschedKeyword
(
bodyLower
string
,
keywords
[]
string
)
string
{
if
bodyLower
==
""
{
return
""
}
for
_
,
keyword
:=
range
keywords
{
k
:=
strings
.
TrimSpace
(
keyword
)
if
k
==
""
{
continue
}
if
strings
.
Contains
(
bodyLower
,
strings
.
ToLower
(
k
))
{
return
k
}
}
return
""
}
func
(
s
*
RateLimitService
)
triggerTempUnschedulable
(
ctx
context
.
Context
,
account
*
Account
,
rule
TempUnschedulableRule
,
ruleIndex
int
,
statusCode
int
,
matchedKeyword
string
,
responseBody
[]
byte
)
bool
{
if
account
==
nil
{
return
false
}
if
rule
.
DurationMinutes
<=
0
{
return
false
}
now
:=
time
.
Now
()
until
:=
now
.
Add
(
time
.
Duration
(
rule
.
DurationMinutes
)
*
time
.
Minute
)
state
:=
&
TempUnschedState
{
UntilUnix
:
until
.
Unix
(),
TriggeredAtUnix
:
now
.
Unix
(),
StatusCode
:
statusCode
,
MatchedKeyword
:
matchedKeyword
,
RuleIndex
:
ruleIndex
,
ErrorMessage
:
truncateTempUnschedMessage
(
responseBody
,
tempUnschedMessageMaxBytes
),
}
reason
:=
""
if
raw
,
err
:=
json
.
Marshal
(
state
);
err
==
nil
{
reason
=
string
(
raw
)
}
if
reason
==
""
{
reason
=
strings
.
TrimSpace
(
state
.
ErrorMessage
)
}
if
err
:=
s
.
accountRepo
.
SetTempUnschedulable
(
ctx
,
account
.
ID
,
until
,
reason
);
err
!=
nil
{
log
.
Printf
(
"SetTempUnschedulable failed for account %d: %v"
,
account
.
ID
,
err
)
return
false
}
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
account
.
ID
,
state
);
err
!=
nil
{
log
.
Printf
(
"SetTempUnsched cache failed for account %d: %v"
,
account
.
ID
,
err
)
}
}
log
.
Printf
(
"Account %d temp unschedulable until %v (rule %d, code %d)"
,
account
.
ID
,
until
,
ruleIndex
,
statusCode
)
return
true
}
func
truncateTempUnschedMessage
(
body
[]
byte
,
maxBytes
int
)
string
{
if
maxBytes
<=
0
||
len
(
body
)
==
0
{
return
""
}
if
len
(
body
)
>
maxBytes
{
body
=
body
[
:
maxBytes
]
}
return
strings
.
TrimSpace
(
string
(
body
))
}
backend/internal/service/setting_service.go
View file @
7dddd065
...
...
@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeySiteName
,
SettingKeySiteLogo
,
SettingKeySiteSubtitle
,
SettingKeyA
pi
BaseU
rl
,
SettingKeyA
PI
BaseU
RL
,
SettingKeyContactInfo
,
SettingKeyDocU
rl
,
SettingKeyDocU
RL
,
}
settings
,
err
:=
s
.
settingRepo
.
GetMultiple
(
ctx
,
keys
)
...
...
@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
SiteLogo
:
settings
[
SettingKeySiteLogo
],
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
A
pi
BaseU
rl
:
settings
[
SettingKeyA
pi
BaseU
rl
],
A
PI
BaseU
RL
:
settings
[
SettingKeyA
PI
BaseU
RL
],
ContactInfo
:
settings
[
SettingKeyContactInfo
],
DocU
rl
:
settings
[
SettingKeyDocU
rl
],
DocU
RL
:
settings
[
SettingKeyDocU
RL
],
},
nil
}
...
...
@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeyEmailVerifyEnabled
]
=
strconv
.
FormatBool
(
settings
.
EmailVerifyEnabled
)
// 邮件服务设置(只有非空才更新密码)
updates
[
SettingKeyS
mtp
Host
]
=
settings
.
S
mtp
Host
updates
[
SettingKeyS
mtp
Port
]
=
strconv
.
Itoa
(
settings
.
S
mtp
Port
)
updates
[
SettingKeyS
mtp
Username
]
=
settings
.
S
mtp
Username
if
settings
.
S
mtp
Password
!=
""
{
updates
[
SettingKeyS
mtp
Password
]
=
settings
.
S
mtp
Password
updates
[
SettingKeyS
MTP
Host
]
=
settings
.
S
MTP
Host
updates
[
SettingKeyS
MTP
Port
]
=
strconv
.
Itoa
(
settings
.
S
MTP
Port
)
updates
[
SettingKeyS
MTP
Username
]
=
settings
.
S
MTP
Username
if
settings
.
S
MTP
Password
!=
""
{
updates
[
SettingKeyS
MTP
Password
]
=
settings
.
S
MTP
Password
}
updates
[
SettingKeyS
mtp
From
]
=
settings
.
S
mtp
From
updates
[
SettingKeyS
mtp
FromName
]
=
settings
.
S
mtp
FromName
updates
[
SettingKeyS
mtp
UseTLS
]
=
strconv
.
FormatBool
(
settings
.
S
mtp
UseTLS
)
updates
[
SettingKeyS
MTP
From
]
=
settings
.
S
MTP
From
updates
[
SettingKeyS
MTP
FromName
]
=
settings
.
S
MTP
FromName
updates
[
SettingKeyS
MTP
UseTLS
]
=
strconv
.
FormatBool
(
settings
.
S
MTP
UseTLS
)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
updates
[
SettingKeyTurnstileEnabled
]
=
strconv
.
FormatBool
(
settings
.
TurnstileEnabled
)
...
...
@@ -115,14 +115,21 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeySiteName
]
=
settings
.
SiteName
updates
[
SettingKeySiteLogo
]
=
settings
.
SiteLogo
updates
[
SettingKeySiteSubtitle
]
=
settings
.
SiteSubtitle
updates
[
SettingKeyA
pi
BaseU
rl
]
=
settings
.
A
pi
BaseU
rl
updates
[
SettingKeyA
PI
BaseU
RL
]
=
settings
.
A
PI
BaseU
RL
updates
[
SettingKeyContactInfo
]
=
settings
.
ContactInfo
updates
[
SettingKeyDocU
rl
]
=
settings
.
DocU
rl
updates
[
SettingKeyDocU
RL
]
=
settings
.
DocU
RL
// 默认配置
updates
[
SettingKeyDefaultConcurrency
]
=
strconv
.
Itoa
(
settings
.
DefaultConcurrency
)
updates
[
SettingKeyDefaultBalance
]
=
strconv
.
FormatFloat
(
settings
.
DefaultBalance
,
'f'
,
8
,
64
)
// Model fallback configuration
updates
[
SettingKeyEnableModelFallback
]
=
strconv
.
FormatBool
(
settings
.
EnableModelFallback
)
updates
[
SettingKeyFallbackModelAnthropic
]
=
settings
.
FallbackModelAnthropic
updates
[
SettingKeyFallbackModelOpenAI
]
=
settings
.
FallbackModelOpenAI
updates
[
SettingKeyFallbackModelGemini
]
=
settings
.
FallbackModelGemini
updates
[
SettingKeyFallbackModelAntigravity
]
=
settings
.
FallbackModelAntigravity
return
s
.
settingRepo
.
SetMultiple
(
ctx
,
updates
)
}
...
...
@@ -198,8 +205,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo
:
""
,
SettingKeyDefaultConcurrency
:
strconv
.
Itoa
(
s
.
cfg
.
Default
.
UserConcurrency
),
SettingKeyDefaultBalance
:
strconv
.
FormatFloat
(
s
.
cfg
.
Default
.
UserBalance
,
'f'
,
8
,
64
),
SettingKeySmtpPort
:
"587"
,
SettingKeySmtpUseTLS
:
"false"
,
SettingKeySMTPPort
:
"587"
,
SettingKeySMTPUseTLS
:
"false"
,
// Model fallback defaults
SettingKeyEnableModelFallback
:
"false"
,
SettingKeyFallbackModelAnthropic
:
"claude-3-5-sonnet-20241022"
,
SettingKeyFallbackModelOpenAI
:
"gpt-4o"
,
SettingKeyFallbackModelGemini
:
"gemini-2.5-pro"
,
SettingKeyFallbackModelAntigravity
:
"gemini-2.5-pro"
,
}
return
s
.
settingRepo
.
SetMultiple
(
ctx
,
defaults
)
...
...
@@ -210,28 +223,28 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result
:=
&
SystemSettings
{
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
,
S
mtp
Host
:
settings
[
SettingKeyS
mtp
Host
],
S
mtp
Username
:
settings
[
SettingKeyS
mtp
Username
],
S
mtp
From
:
settings
[
SettingKeyS
mtp
From
],
S
mtp
FromName
:
settings
[
SettingKeyS
mtp
FromName
],
S
mtp
UseTLS
:
settings
[
SettingKeyS
mtp
UseTLS
]
==
"true"
,
S
mtp
PasswordConfigured
:
settings
[
SettingKeyS
mtp
Password
]
!=
""
,
S
MTP
Host
:
settings
[
SettingKeyS
MTP
Host
],
S
MTP
Username
:
settings
[
SettingKeyS
MTP
Username
],
S
MTP
From
:
settings
[
SettingKeyS
MTP
From
],
S
MTP
FromName
:
settings
[
SettingKeyS
MTP
FromName
],
S
MTP
UseTLS
:
settings
[
SettingKeyS
MTP
UseTLS
]
==
"true"
,
S
MTP
PasswordConfigured
:
settings
[
SettingKeyS
MTP
Password
]
!=
""
,
TurnstileEnabled
:
settings
[
SettingKeyTurnstileEnabled
]
==
"true"
,
TurnstileSiteKey
:
settings
[
SettingKeyTurnstileSiteKey
],
TurnstileSecretKeyConfigured
:
settings
[
SettingKeyTurnstileSecretKey
]
!=
""
,
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
SiteLogo
:
settings
[
SettingKeySiteLogo
],
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
A
pi
BaseU
rl
:
settings
[
SettingKeyA
pi
BaseU
rl
],
A
PI
BaseU
RL
:
settings
[
SettingKeyA
PI
BaseU
RL
],
ContactInfo
:
settings
[
SettingKeyContactInfo
],
DocU
rl
:
settings
[
SettingKeyDocU
rl
],
DocU
RL
:
settings
[
SettingKeyDocU
RL
],
}
// 解析整数类型
if
port
,
err
:=
strconv
.
Atoi
(
settings
[
SettingKeyS
mtp
Port
]);
err
==
nil
{
result
.
S
mtp
Port
=
port
if
port
,
err
:=
strconv
.
Atoi
(
settings
[
SettingKeyS
MTP
Port
]);
err
==
nil
{
result
.
S
MTP
Port
=
port
}
else
{
result
.
S
mtp
Port
=
587
result
.
S
MTP
Port
=
587
}
if
concurrency
,
err
:=
strconv
.
Atoi
(
settings
[
SettingKeyDefaultConcurrency
]);
err
==
nil
{
...
...
@@ -247,6 +260,17 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result
.
DefaultBalance
=
s
.
cfg
.
Default
.
UserBalance
}
// 敏感信息直接返回,方便测试连接时使用
result
.
SMTPPassword
=
settings
[
SettingKeySMTPPassword
]
result
.
TurnstileSecretKey
=
settings
[
SettingKeyTurnstileSecretKey
]
// Model fallback settings
result
.
EnableModelFallback
=
settings
[
SettingKeyEnableModelFallback
]
==
"true"
result
.
FallbackModelAnthropic
=
s
.
getStringOrDefault
(
settings
,
SettingKeyFallbackModelAnthropic
,
"claude-3-5-sonnet-20241022"
)
result
.
FallbackModelOpenAI
=
s
.
getStringOrDefault
(
settings
,
SettingKeyFallbackModelOpenAI
,
"gpt-4o"
)
result
.
FallbackModelGemini
=
s
.
getStringOrDefault
(
settings
,
SettingKeyFallbackModelGemini
,
"gemini-2.5-pro"
)
result
.
FallbackModelAntigravity
=
s
.
getStringOrDefault
(
settings
,
SettingKeyFallbackModelAntigravity
,
"gemini-2.5-pro"
)
return
result
}
...
...
@@ -276,28 +300,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return
value
}
// GenerateAdminA
pi
Key 生成新的管理员 API Key
func
(
s
*
SettingService
)
GenerateAdminA
pi
Key
(
ctx
context
.
Context
)
(
string
,
error
)
{
// GenerateAdminA
PI
Key 生成新的管理员 API Key
func
(
s
*
SettingService
)
GenerateAdminA
PI
Key
(
ctx
context
.
Context
)
(
string
,
error
)
{
// 生成 32 字节随机数 = 64 位十六进制字符
bytes
:=
make
([]
byte
,
32
)
if
_
,
err
:=
rand
.
Read
(
bytes
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"generate random bytes: %w"
,
err
)
}
key
:=
AdminA
pi
KeyPrefix
+
hex
.
EncodeToString
(
bytes
)
key
:=
AdminA
PI
KeyPrefix
+
hex
.
EncodeToString
(
bytes
)
// 存储到 settings 表
if
err
:=
s
.
settingRepo
.
Set
(
ctx
,
SettingKeyAdminA
pi
Key
,
key
);
err
!=
nil
{
if
err
:=
s
.
settingRepo
.
Set
(
ctx
,
SettingKeyAdminA
PI
Key
,
key
);
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"save admin api key: %w"
,
err
)
}
return
key
,
nil
}
// GetAdminA
pi
KeyStatus 获取管理员 API Key 状态
// GetAdminA
PI
KeyStatus 获取管理员 API Key 状态
// 返回脱敏的 key、是否存在、错误
func
(
s
*
SettingService
)
GetAdminA
pi
KeyStatus
(
ctx
context
.
Context
)
(
maskedKey
string
,
exists
bool
,
err
error
)
{
key
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyAdminA
pi
Key
)
func
(
s
*
SettingService
)
GetAdminA
PI
KeyStatus
(
ctx
context
.
Context
)
(
maskedKey
string
,
exists
bool
,
err
error
)
{
key
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyAdminA
PI
Key
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
return
""
,
false
,
nil
...
...
@@ -318,10 +342,10 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
return
maskedKey
,
true
,
nil
}
// GetAdminA
pi
Key 获取完整的管理员 API Key(仅供内部验证使用)
// GetAdminA
PI
Key 获取完整的管理员 API Key(仅供内部验证使用)
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
func
(
s
*
SettingService
)
GetAdminA
pi
Key
(
ctx
context
.
Context
)
(
string
,
error
)
{
key
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyAdminA
pi
Key
)
func
(
s
*
SettingService
)
GetAdminA
PI
Key
(
ctx
context
.
Context
)
(
string
,
error
)
{
key
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyAdminA
PI
Key
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrSettingNotFound
)
{
return
""
,
nil
// 未配置,返回空字符串
...
...
@@ -331,7 +355,45 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
return
key
,
nil
}
// DeleteAdminApiKey 删除管理员 API Key
func
(
s
*
SettingService
)
DeleteAdminApiKey
(
ctx
context
.
Context
)
error
{
return
s
.
settingRepo
.
Delete
(
ctx
,
SettingKeyAdminApiKey
)
// DeleteAdminAPIKey 删除管理员 API Key
func
(
s
*
SettingService
)
DeleteAdminAPIKey
(
ctx
context
.
Context
)
error
{
return
s
.
settingRepo
.
Delete
(
ctx
,
SettingKeyAdminAPIKey
)
}
// IsModelFallbackEnabled 检查是否启用模型兜底机制
func
(
s
*
SettingService
)
IsModelFallbackEnabled
(
ctx
context
.
Context
)
bool
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyEnableModelFallback
)
if
err
!=
nil
{
return
false
// Default: disabled
}
return
value
==
"true"
}
// GetFallbackModel 获取指定平台的兜底模型
func
(
s
*
SettingService
)
GetFallbackModel
(
ctx
context
.
Context
,
platform
string
)
string
{
var
key
string
var
defaultModel
string
switch
platform
{
case
PlatformAnthropic
:
key
=
SettingKeyFallbackModelAnthropic
defaultModel
=
"claude-3-5-sonnet-20241022"
case
PlatformOpenAI
:
key
=
SettingKeyFallbackModelOpenAI
defaultModel
=
"gpt-4o"
case
PlatformGemini
:
key
=
SettingKeyFallbackModelGemini
defaultModel
=
"gemini-2.5-pro"
case
PlatformAntigravity
:
key
=
SettingKeyFallbackModelAntigravity
defaultModel
=
"gemini-2.5-pro"
default
:
return
""
}
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
key
)
if
err
!=
nil
||
value
==
""
{
return
defaultModel
}
return
value
}
backend/internal/service/settings_view.go
View file @
7dddd065
...
...
@@ -4,14 +4,14 @@ type SystemSettings struct {
RegistrationEnabled
bool
EmailVerifyEnabled
bool
S
mtp
Host
string
S
mtp
Port
int
S
mtp
Username
string
S
mtp
Password
string
S
mtp
PasswordConfigured
bool
S
mtp
From
string
S
mtp
FromName
string
S
mtp
UseTLS
bool
S
MTP
Host
string
S
MTP
Port
int
S
MTP
Username
string
S
MTP
Password
string
S
MTP
PasswordConfigured
bool
S
MTP
From
string
S
MTP
FromName
string
S
MTP
UseTLS
bool
TurnstileEnabled
bool
TurnstileSiteKey
string
...
...
@@ -21,12 +21,19 @@ type SystemSettings struct {
SiteName
string
SiteLogo
string
SiteSubtitle
string
A
pi
BaseU
rl
string
A
PI
BaseU
RL
string
ContactInfo
string
DocU
rl
string
DocU
RL
string
DefaultConcurrency
int
DefaultBalance
float64
// Model fallback configuration
EnableModelFallback
bool
`json:"enable_model_fallback"`
FallbackModelAnthropic
string
`json:"fallback_model_anthropic"`
FallbackModelOpenAI
string
`json:"fallback_model_openai"`
FallbackModelGemini
string
`json:"fallback_model_gemini"`
FallbackModelAntigravity
string
`json:"fallback_model_antigravity"`
}
type
PublicSettings
struct
{
...
...
@@ -37,8 +44,8 @@ type PublicSettings struct {
SiteName
string
SiteLogo
string
SiteSubtitle
string
A
pi
BaseU
rl
string
A
PI
BaseU
RL
string
ContactInfo
string
DocU
rl
string
DocU
RL
string
Version
string
}
backend/internal/service/temp_unsched.go
0 → 100644
View file @
7dddd065
package
service
import
(
"context"
)
// TempUnschedState is the snapshot recorded when an account is put into
// the temporary-unschedulable state. It is JSON-marshaled for persistence
// and caching.
type TempUnschedState struct {
	UntilUnix       int64  `json:"until_unix"`        // when the state lifts (Unix seconds)
	TriggeredAtUnix int64  `json:"triggered_at_unix"` // when the state was triggered (Unix seconds)
	StatusCode      int    `json:"status_code"`       // upstream HTTP status code that triggered the rule
	MatchedKeyword  string `json:"matched_keyword"`   // keyword matched in the response body, if any
	RuleIndex       int    `json:"rule_index"`        // index of the rule that fired
	ErrorMessage    string `json:"error_message"`     // truncated upstream error body, for diagnostics
}
// TempUnschedCache caches temporary-unschedulable state per account so the
// scheduler can consult it without a database round trip. Implementations
// live elsewhere; error semantics are implementation-defined.
type TempUnschedCache interface {
	// SetTempUnsched stores the state snapshot for accountID.
	SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error
	// GetTempUnsched retrieves the cached state for accountID.
	GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error)
	// DeleteTempUnsched removes any cached state for accountID.
	DeleteTempUnsched(ctx context.Context, accountID int64) error
}
backend/internal/service/token_refresher_test.go
View file @
7dddd065
...
...
@@ -197,7 +197,7 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
{
name
:
"anthropic api-key - cannot refresh"
,
platform
:
PlatformAnthropic
,
accType
:
AccountTypeA
pi
Key
,
accType
:
AccountTypeA
PI
Key
,
want
:
false
,
},
{
...
...
backend/internal/service/update_service.go
View file @
7dddd065
...
...
@@ -79,7 +79,7 @@ type ReleaseInfo struct {
Name
string
`json:"name"`
Body
string
`json:"body"`
PublishedAt
string
`json:"published_at"`
H
tml
URL
string
`json:"html_url"`
H
TML
URL
string
`json:"html_url"`
Assets
[]
Asset
`json:"assets,omitempty"`
}
...
...
@@ -96,13 +96,13 @@ type GitHubRelease struct {
Name
string
`json:"name"`
Body
string
`json:"body"`
PublishedAt
string
`json:"published_at"`
H
tmlUrl
string
`json:"html_url"`
H
TMLURL
string
`json:"html_url"`
Assets
[]
GitHubAsset
`json:"assets"`
}
type
GitHubAsset
struct
{
Name
string
`json:"name"`
BrowserDownloadU
rl
string
`json:"browser_download_url"`
BrowserDownloadU
RL
string
`json:"browser_download_url"`
Size
int64
`json:"size"`
}
...
...
@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
for
i
,
a
:=
range
release
.
Assets
{
assets
[
i
]
=
Asset
{
Name
:
a
.
Name
,
DownloadURL
:
a
.
BrowserDownloadU
rl
,
DownloadURL
:
a
.
BrowserDownloadU
RL
,
Size
:
a
.
Size
,
}
}
...
...
@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
Name
:
release
.
Name
,
Body
:
release
.
Body
,
PublishedAt
:
release
.
PublishedAt
,
H
tml
URL
:
release
.
H
tmlUrl
,
H
TML
URL
:
release
.
H
TMLURL
,
Assets
:
assets
,
},
Cached
:
false
,
...
...
backend/internal/service/usage_log.go
View file @
7dddd065
...
...
@@ -10,7 +10,7 @@ const (
type
UsageLog
struct
{
ID
int64
UserID
int64
A
pi
KeyID
int64
A
PI
KeyID
int64
AccountID
int64
RequestID
string
Model
string
...
...
@@ -42,7 +42,7 @@ type UsageLog struct {
CreatedAt
time
.
Time
User
*
User
A
pi
Key
*
A
pi
Key
A
PI
Key
*
A
PI
Key
Account
*
Account
Group
*
Group
Subscription
*
UserSubscription
...
...
backend/internal/service/usage_service.go
View file @
7dddd065
...
...
@@ -2,9 +2,11 @@ package service
import
(
"context"
"errors"
"fmt"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
...
...
@@ -17,7 +19,7 @@ var (
// CreateUsageLogRequest 创建使用日志请求
type
CreateUsageLogRequest
struct
{
UserID
int64
`json:"user_id"`
A
pi
KeyID
int64
`json:"api_key_id"`
A
PI
KeyID
int64
`json:"api_key_id"`
AccountID
int64
`json:"account_id"`
RequestID
string
`json:"request_id"`
Model
string
`json:"model"`
...
...
@@ -54,20 +56,34 @@ type UsageStats struct {
type
UsageService
struct
{
usageRepo
UsageLogRepository
userRepo
UserRepository
entClient
*
dbent
.
Client
}
// NewUsageService 创建使用统计服务实例
func
NewUsageService
(
usageRepo
UsageLogRepository
,
userRepo
UserRepository
)
*
UsageService
{
func
NewUsageService
(
usageRepo
UsageLogRepository
,
userRepo
UserRepository
,
entClient
*
dbent
.
Client
)
*
UsageService
{
return
&
UsageService
{
usageRepo
:
usageRepo
,
userRepo
:
userRepo
,
entClient
:
entClient
,
}
}
// Create 创建使用日志
func
(
s
*
UsageService
)
Create
(
ctx
context
.
Context
,
req
CreateUsageLogRequest
)
(
*
UsageLog
,
error
)
{
// 使用数据库事务保证「使用日志插入」与「扣费」的原子性,避免重复扣费或漏扣风险。
tx
,
err
:=
s
.
entClient
.
Tx
(
ctx
)
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
dbent
.
ErrTxStarted
)
{
return
nil
,
fmt
.
Errorf
(
"begin transaction: %w"
,
err
)
}
txCtx
:=
ctx
if
err
==
nil
{
defer
func
()
{
_
=
tx
.
Rollback
()
}()
txCtx
=
dbent
.
NewTxContext
(
ctx
,
tx
)
}
// 验证用户存在
_
,
err
:
=
s
.
userRepo
.
GetByID
(
c
tx
,
req
.
UserID
)
_
,
err
=
s
.
userRepo
.
GetByID
(
txC
tx
,
req
.
UserID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
...
...
@@ -75,7 +91,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
// 创建使用日志
usageLog
:=
&
UsageLog
{
UserID
:
req
.
UserID
,
A
pi
KeyID
:
req
.
A
pi
KeyID
,
A
PI
KeyID
:
req
.
A
PI
KeyID
,
AccountID
:
req
.
AccountID
,
RequestID
:
req
.
RequestID
,
Model
:
req
.
Model
,
...
...
@@ -96,17 +112,24 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
DurationMs
:
req
.
DurationMs
,
}
if
err
:=
s
.
usageRepo
.
Create
(
ctx
,
usageLog
);
err
!=
nil
{
inserted
,
err
:=
s
.
usageRepo
.
Create
(
txCtx
,
usageLog
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create usage log: %w"
,
err
)
}
// 扣除用户余额
if
req
.
ActualCost
>
0
{
if
err
:=
s
.
userRepo
.
UpdateBalance
(
c
tx
,
req
.
UserID
,
-
req
.
ActualCost
);
err
!=
nil
{
if
inserted
&&
req
.
ActualCost
>
0
{
if
err
:=
s
.
userRepo
.
UpdateBalance
(
txC
tx
,
req
.
UserID
,
-
req
.
ActualCost
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update user balance: %w"
,
err
)
}
}
if
tx
!=
nil
{
if
err
:=
tx
.
Commit
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"commit transaction: %w"
,
err
)
}
}
return
usageLog
,
nil
}
...
...
@@ -128,9 +151,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
return
logs
,
pagination
,
nil
}
// ListByA
pi
Key 获取API Key的使用日志列表
func
(
s
*
UsageService
)
ListByA
pi
Key
(
ctx
context
.
Context
,
apiKeyID
int64
,
params
pagination
.
PaginationParams
)
([]
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
logs
,
pagination
,
err
:=
s
.
usageRepo
.
ListByA
pi
Key
(
ctx
,
apiKeyID
,
params
)
// ListByA
PI
Key 获取API Key的使用日志列表
func
(
s
*
UsageService
)
ListByA
PI
Key
(
ctx
context
.
Context
,
apiKeyID
int64
,
params
pagination
.
PaginationParams
)
([]
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
logs
,
pagination
,
err
:=
s
.
usageRepo
.
ListByA
PI
Key
(
ctx
,
apiKeyID
,
params
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"list usage logs: %w"
,
err
)
}
...
...
@@ -165,9 +188,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi
},
nil
}
// GetStatsByA
pi
Key 获取API Key的使用统计
func
(
s
*
UsageService
)
GetStatsByA
pi
Key
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
UsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetA
pi
KeyStatsAggregated
(
ctx
,
apiKeyID
,
startTime
,
endTime
)
// GetStatsByA
PI
Key 获取API Key的使用统计
func
(
s
*
UsageService
)
GetStatsByA
PI
Key
(
ctx
context
.
Context
,
apiKeyID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
UsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetA
PI
KeyStatsAggregated
(
ctx
,
apiKeyID
,
startTime
,
endTime
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get api key stats: %w"
,
err
)
}
...
...
@@ -270,9 +293,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
return
stats
,
nil
}
// GetBatchA
pi
KeyUsageStats returns today/total actual_cost for given api keys.
func
(
s
*
UsageService
)
GetBatchA
pi
KeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchA
pi
KeyUsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetBatchA
pi
KeyUsageStats
(
ctx
,
apiKeyIDs
)
// GetBatchA
PI
KeyUsageStats returns today/total actual_cost for given api keys.
func
(
s
*
UsageService
)
GetBatchA
PI
KeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchA
PI
KeyUsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetBatchA
PI
KeyUsageStats
(
ctx
,
apiKeyIDs
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get batch api key usage stats: %w"
,
err
)
}
...
...
backend/internal/service/user.go
View file @
7dddd065
...
...
@@ -21,7 +21,7 @@ type User struct {
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
A
pi
Keys
[]
A
pi
Key
A
PI
Keys
[]
A
PI
Key
Subscriptions
[]
UserSubscription
}
...
...
backend/internal/service/user_attribute_service.go
View file @
7dddd065
...
...
@@ -56,6 +56,10 @@ func (s *UserAttributeService) CreateDefinition(ctx context.Context, input Creat
Enabled
:
input
.
Enabled
,
}
if
err
:=
validateDefinitionPattern
(
def
);
err
!=
nil
{
return
nil
,
err
}
if
err
:=
s
.
defRepo
.
Create
(
ctx
,
def
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create definition: %w"
,
err
)
}
...
...
@@ -108,6 +112,10 @@ func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, i
def
.
Enabled
=
*
input
.
Enabled
}
if
err
:=
validateDefinitionPattern
(
def
);
err
!=
nil
{
return
nil
,
err
}
if
err
:=
s
.
defRepo
.
Update
(
ctx
,
def
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update definition: %w"
,
err
)
}
...
...
@@ -231,7 +239,10 @@ func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value
// Pattern validation
if
v
.
Pattern
!=
nil
&&
*
v
.
Pattern
!=
""
&&
value
!=
""
{
re
,
err
:=
regexp
.
Compile
(
*
v
.
Pattern
)
if
err
==
nil
&&
!
re
.
MatchString
(
value
)
{
if
err
!=
nil
{
return
validationError
(
def
.
Name
+
" has an invalid pattern"
)
}
if
!
re
.
MatchString
(
value
)
{
msg
:=
def
.
Name
+
" format is invalid"
if
v
.
Message
!=
nil
&&
*
v
.
Message
!=
""
{
msg
=
*
v
.
Message
...
...
@@ -293,3 +304,20 @@ func isValidAttributeType(t UserAttributeType) bool {
}
return
false
}
func
validateDefinitionPattern
(
def
*
UserAttributeDefinition
)
error
{
if
def
==
nil
{
return
nil
}
if
def
.
Validation
.
Pattern
==
nil
{
return
nil
}
pattern
:=
strings
.
TrimSpace
(
*
def
.
Validation
.
Pattern
)
if
pattern
==
""
{
return
nil
}
if
_
,
err
:=
regexp
.
Compile
(
pattern
);
err
!=
nil
{
return
infraerrors
.
BadRequest
(
"INVALID_ATTRIBUTE_PATTERN"
,
fmt
.
Sprintf
(
"invalid pattern for %s: %v"
,
def
.
Name
,
err
))
}
return
nil
}
backend/internal/service/wire.go
View file @
7dddd065
...
...
@@ -75,7 +75,7 @@ var ProviderSet = wire.NewSet(
// Core services
NewAuthService
,
NewUserService
,
NewA
pi
KeyService
,
NewA
PI
KeyService
,
NewGroupService
,
NewAccountService
,
NewProxyService
,
...
...
backend/internal/setup/cli.go
View file @
7dddd065
// Package setup provides CLI commands and application initialization helpers.
package
setup
import
(
...
...
backend/internal/setup/setup.go
View file @
7dddd065
...
...
@@ -352,7 +352,7 @@ func writeConfigFile(cfg *SetupConfig) error {
Default
struct
{
UserConcurrency
int
`yaml:"user_concurrency"`
UserBalance
float64
`yaml:"user_balance"`
A
pi
KeyPrefix
string
`yaml:"api_key_prefix"`
A
PI
KeyPrefix
string
`yaml:"api_key_prefix"`
RateMultiplier
float64
`yaml:"rate_multiplier"`
}
`yaml:"default"`
RateLimit
struct
{
...
...
@@ -374,12 +374,12 @@ func writeConfigFile(cfg *SetupConfig) error {
Default
:
struct
{
UserConcurrency
int
`yaml:"user_concurrency"`
UserBalance
float64
`yaml:"user_balance"`
A
pi
KeyPrefix
string
`yaml:"api_key_prefix"`
A
PI
KeyPrefix
string
`yaml:"api_key_prefix"`
RateMultiplier
float64
`yaml:"rate_multiplier"`
}{
UserConcurrency
:
5
,
UserBalance
:
0
,
A
pi
KeyPrefix
:
"sk-"
,
A
PI
KeyPrefix
:
"sk-"
,
RateMultiplier
:
1.0
,
},
RateLimit
:
struct
{
...
...
backend/internal/web/embed_off.go
View file @
7dddd065
//go:build !embed
// Package web provides embedded web assets for the application.
package
web
import
(
...
...
backend/migrations/020_add_temp_unschedulable.sql
0 → 100644
View file @
7dddd065
-- 020_add_temp_unschedulable.sql
-- Adds columns supporting the "temporarily unschedulable" feature for accounts.

-- Timestamp at which the temporarily-unschedulable state is lifted; set when a
-- temp-unschedulable rule fires (based on error code or error-message keyword).
ALTER TABLE accounts
    ADD COLUMN IF NOT EXISTS temp_unschedulable_until timestamptz;

-- Reason the account was marked temporarily unschedulable (for
-- troubleshooting and auditing).
ALTER TABLE accounts
    ADD COLUMN IF NOT EXISTS temp_unschedulable_reason text;

-- Partial index over non-deleted accounts to speed up scheduler queries that
-- filter on the release timestamp.
CREATE INDEX IF NOT EXISTS idx_accounts_temp_unschedulable_until
    ON accounts (temp_unschedulable_until)
    WHERE deleted_at IS NULL;

-- Column comments documenting field usage. NOTE: these IS '...' strings are
-- stored database metadata (not source comments) and are kept verbatim.
COMMENT ON COLUMN accounts.temp_unschedulable_until IS '临时不可调度状态解除时间,当触发临时不可调度规则时设置(基于错误码或错误描述关键词)';

COMMENT ON COLUMN accounts.temp_unschedulable_reason IS '临时不可调度原因,记录触发临时不可调度的具体原因(用于排障和审计)';
backend/migrations/026_ops_metrics_aggregation_tables.sql
0 → 100644
View file @
7dddd065
-- Ops monitoring: pre-aggregation tables for dashboard queries
--
-- Problem:
-- The ops dashboard currently runs percentile_cont + GROUP BY queries over large raw tables
-- (usage_logs, ops_error_logs). These will get slower as data grows.
--
-- This migration adds schema-only aggregation tables that can be populated by a future background job.
-- No triggers/functions/jobs are created here (schema only).

-- ============================================
-- Hourly aggregates (per provider/platform)
-- ============================================
CREATE TABLE IF NOT EXISTS ops_metrics_hourly (
    -- Start of the hour bucket (recommended: UTC).
    bucket_start    TIMESTAMPTZ NOT NULL,
    -- Provider/platform label (e.g. anthropic/openai/gemini). Mirrors ops_* queries that GROUP BY platform.
    platform        VARCHAR(50) NOT NULL,
    -- Traffic counts (use these to compute rates reliably across ranges).
    request_count   BIGINT NOT NULL DEFAULT 0,
    success_count   BIGINT NOT NULL DEFAULT 0,
    error_count     BIGINT NOT NULL DEFAULT 0,
    -- Error breakdown used by provider health UI.
    error_4xx_count BIGINT NOT NULL DEFAULT 0,
    error_5xx_count BIGINT NOT NULL DEFAULT 0,
    timeout_count   BIGINT NOT NULL DEFAULT 0,
    -- Latency aggregates (ms).
    avg_latency_ms  DOUBLE PRECISION,
    p99_latency_ms  DOUBLE PRECISION,
    -- Convenience rate (percentage, 0-100). Still keep counts as source of truth.
    error_rate      DOUBLE PRECISION NOT NULL DEFAULT 0,
    -- When this row was last (re)computed by the background job.
    computed_at     TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    -- One row per (hour, platform); the job UPSERTs on this key.
    PRIMARY KEY (bucket_start, platform)
);

-- Serves "recent buckets for one platform" dashboard lookups.
CREATE INDEX IF NOT EXISTS idx_ops_metrics_hourly_platform_bucket_start
    ON ops_metrics_hourly (platform, bucket_start DESC);

COMMENT ON TABLE ops_metrics_hourly IS 'Pre-aggregated hourly ops metrics by provider/platform to speed up dashboard queries.';

COMMENT ON COLUMN ops_metrics_hourly.bucket_start IS 'Start timestamp of the hour bucket (recommended UTC).';

COMMENT ON COLUMN ops_metrics_hourly.platform IS 'Provider/platform label (anthropic/openai/gemini, etc).';

COMMENT ON COLUMN ops_metrics_hourly.error_rate IS 'Error rate percentage for the bucket (0-100). Counts remain the source of truth.';

COMMENT ON COLUMN ops_metrics_hourly.computed_at IS 'When the row was last computed/refreshed.';

-- ============================================
-- Daily aggregates (per provider/platform)
-- ============================================
-- Same shape as ops_metrics_hourly, bucketed by day for longer-term trends.
CREATE TABLE IF NOT EXISTS ops_metrics_daily (
    -- Day bucket (recommended: UTC date).
    bucket_date     DATE NOT NULL,
    platform        VARCHAR(50) NOT NULL,
    request_count   BIGINT NOT NULL DEFAULT 0,
    success_count   BIGINT NOT NULL DEFAULT 0,
    error_count     BIGINT NOT NULL DEFAULT 0,
    error_4xx_count BIGINT NOT NULL DEFAULT 0,
    error_5xx_count BIGINT NOT NULL DEFAULT 0,
    timeout_count   BIGINT NOT NULL DEFAULT 0,
    avg_latency_ms  DOUBLE PRECISION,
    p99_latency_ms  DOUBLE PRECISION,
    error_rate      DOUBLE PRECISION NOT NULL DEFAULT 0,
    computed_at     TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    PRIMARY KEY (bucket_date, platform)
);

CREATE INDEX IF NOT EXISTS idx_ops_metrics_daily_platform_bucket_date
    ON ops_metrics_daily (platform, bucket_date DESC);

COMMENT ON TABLE ops_metrics_daily IS 'Pre-aggregated daily ops metrics by provider/platform for longer-term trends.';

COMMENT ON COLUMN ops_metrics_daily.bucket_date IS 'UTC date of the day bucket (recommended).';

-- ============================================
-- Population strategy (future background job)
-- ============================================
--
-- Suggested approach:
-- 1) Compute hourly buckets from raw logs using UTC time-bucketing, then UPSERT into ops_metrics_hourly.
-- 2) Compute daily buckets either directly from raw logs or by rolling up ops_metrics_hourly.
--
-- Notes:
-- - Ensure the job uses a consistent timezone (recommended: SET TIME ZONE ''UTC'') to avoid bucket drift.
-- - Derive the provider/platform similarly to existing dashboard queries:
--     usage_logs:     COALESCE(NULLIF(groups.platform, ''), accounts.platform, '')
--     ops_error_logs: COALESCE(NULLIF(ops_error_logs.platform, ''), groups.platform, accounts.platform, '')
-- - Keep request_count/success_count/error_count as the authoritative values; compute error_rate from counts.
--
-- Example (hourly) shape (pseudo-SQL):
--   INSERT INTO ops_metrics_hourly (...)
--   SELECT date_trunc('hour', created_at) AS bucket_start, platform, ...
--   FROM (/* aggregate usage_logs + ops_error_logs */) s
--   ON CONFLICT (bucket_start, platform) DO UPDATE SET ...;
backend/migrations/027_usage_billing_consistency.sql
0 → 100644
View file @
7dddd065
-- 027_usage_billing_consistency.sql
-- Ensure usage_logs idempotency (request_id, api_key_id) and add reconciliation infrastructure.

-- -----------------------------------------------------------------------------
-- 1) Normalize legacy request_id values
-- -----------------------------------------------------------------------------
-- Historically request_id may be inserted as empty string. Convert it to NULL so
-- the upcoming unique index does not break on repeated "" values.
UPDATE usage_logs
SET request_id = NULL
WHERE request_id = '';

-- If duplicates already exist for the same (request_id, api_key_id), keep the
-- first row and NULL-out request_id for the rest so the unique index can be
-- created without deleting historical logs.
WITH ranked AS (
    -- rn = 1 marks the earliest row (lowest id) within each duplicate group.
    SELECT id,
           ROW_NUMBER() OVER (PARTITION BY api_key_id, request_id ORDER BY id) AS rn
    FROM usage_logs
    WHERE request_id IS NOT NULL
)
UPDATE usage_logs ul
SET request_id = NULL
FROM ranked r
WHERE ul.id = r.id
  AND r.rn > 1;

-- -----------------------------------------------------------------------------
-- 2) Idempotency constraint for usage_logs
-- -----------------------------------------------------------------------------
-- NOTE(review): Postgres treats NULLs as distinct in unique indexes, so the
-- many rows normalized to request_id = NULL above will not conflict with each
-- other — presumably intentional (only real request ids are deduplicated);
-- confirm against the insert path.
CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_logs_request_id_api_key_unique
    ON usage_logs (request_id, api_key_id);

-- -----------------------------------------------------------------------------
-- 3) Reconciliation infrastructure: billing ledger for usage charges
-- -----------------------------------------------------------------------------
CREATE TABLE IF NOT EXISTS billing_usage_entries (
    id              BIGSERIAL PRIMARY KEY,
    -- Source usage row this ledger entry was derived from; removed with it.
    usage_log_id    BIGINT NOT NULL REFERENCES usage_logs (id) ON DELETE CASCADE,
    user_id         BIGINT NOT NULL REFERENCES users (id) ON DELETE CASCADE,
    api_key_id      BIGINT NOT NULL REFERENCES api_keys (id) ON DELETE CASCADE,
    -- Optional link to the subscription involved; kept (as NULL) if the
    -- subscription row is deleted.
    subscription_id BIGINT REFERENCES user_subscriptions (id) ON DELETE SET NULL,
    -- Discriminator for how the usage was billed; values defined by the
    -- application layer (not visible in this migration — verify in service code).
    billing_type    SMALLINT NOT NULL,
    applied         BOOLEAN NOT NULL DEFAULT TRUE,
    -- USD amount for this entry; name suggests a signed balance delta —
    -- confirm sign convention against the billing service.
    delta_usd       DECIMAL(20, 10) NOT NULL DEFAULT 0,
    created_at      TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- At most one ledger entry per usage log, making reconciliation idempotent.
CREATE UNIQUE INDEX IF NOT EXISTS billing_usage_entries_usage_log_id_unique
    ON billing_usage_entries (usage_log_id);

-- Per-user time-range queries (e.g. statements, reconciliation windows).
CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_user_time
    ON billing_usage_entries (user_id, created_at);

-- Global time-range scans (e.g. batch reconciliation jobs).
CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_created_at
    ON billing_usage_entries (created_at);
Prev
1
…
4
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment