Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
daf10907
Commit
daf10907
authored
Jan 14, 2026
by
yangjianbo
Browse files
fix(认证): 修复 OAuth token 缓存失效与 401 处理
新增 token 缓存失效接口并在刷新后清理 401 限流支持自定义规则与可配置冷却时间 补齐缓存失效与 401 处理测试 测试: make test
parent
9c567fad
Changes
19
Show whitespace changes
Inline
Side-by-side
backend/cmd/server/wire_gen.go
View file @
daf10907
...
@@ -98,12 +98,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -98,12 +98,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiQuotaService
:=
service
.
NewGeminiQuotaService
(
configConfig
,
settingRepository
)
geminiQuotaService
:=
service
.
NewGeminiQuotaService
(
configConfig
,
settingRepository
)
tempUnschedCache
:=
repository
.
NewTempUnschedCache
(
redisClient
)
tempUnschedCache
:=
repository
.
NewTempUnschedCache
(
redisClient
)
timeoutCounterCache
:=
repository
.
NewTimeoutCounterCache
(
redisClient
)
timeoutCounterCache
:=
repository
.
NewTimeoutCounterCache
(
redisClient
)
rateLimitService
:=
service
.
ProvideRateLimitService
(
accountRepository
,
usageLogRepository
,
configConfig
,
geminiQuotaService
,
tempUnschedCache
,
timeoutCounterCache
,
settingService
)
geminiTokenCache
:=
repository
.
NewGeminiTokenCache
(
redisClient
)
tokenCacheInvalidator
:=
service
.
NewCompositeTokenCacheInvalidator
(
geminiTokenCache
)
rateLimitService
:=
service
.
ProvideRateLimitService
(
accountRepository
,
usageLogRepository
,
configConfig
,
geminiQuotaService
,
tempUnschedCache
,
timeoutCounterCache
,
settingService
,
tokenCacheInvalidator
)
claudeUsageFetcher
:=
repository
.
NewClaudeUsageFetcher
()
claudeUsageFetcher
:=
repository
.
NewClaudeUsageFetcher
()
antigravityQuotaFetcher
:=
service
.
NewAntigravityQuotaFetcher
(
proxyRepository
)
antigravityQuotaFetcher
:=
service
.
NewAntigravityQuotaFetcher
(
proxyRepository
)
usageCache
:=
service
.
NewUsageCache
()
usageCache
:=
service
.
NewUsageCache
()
accountUsageService
:=
service
.
NewAccountUsageService
(
accountRepository
,
usageLogRepository
,
claudeUsageFetcher
,
geminiQuotaService
,
antigravityQuotaFetcher
,
usageCache
)
accountUsageService
:=
service
.
NewAccountUsageService
(
accountRepository
,
usageLogRepository
,
claudeUsageFetcher
,
geminiQuotaService
,
antigravityQuotaFetcher
,
usageCache
)
geminiTokenCache
:=
repository
.
NewGeminiTokenCache
(
redisClient
)
geminiTokenProvider
:=
service
.
NewGeminiTokenProvider
(
accountRepository
,
geminiTokenCache
,
geminiOAuthService
)
geminiTokenProvider
:=
service
.
NewGeminiTokenProvider
(
accountRepository
,
geminiTokenCache
,
geminiOAuthService
)
gatewayCache
:=
repository
.
NewGatewayCache
(
redisClient
)
gatewayCache
:=
repository
.
NewGatewayCache
(
redisClient
)
antigravityTokenProvider
:=
service
.
NewAntigravityTokenProvider
(
accountRepository
,
geminiTokenCache
,
antigravityOAuthService
)
antigravityTokenProvider
:=
service
.
NewAntigravityTokenProvider
(
accountRepository
,
geminiTokenCache
,
antigravityOAuthService
)
...
@@ -166,7 +167,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -166,7 +167,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService
:=
service
.
ProvideOpsAlertEvaluatorService
(
opsService
,
opsRepository
,
emailService
,
redisClient
,
configConfig
)
opsAlertEvaluatorService
:=
service
.
ProvideOpsAlertEvaluatorService
(
opsService
,
opsRepository
,
emailService
,
redisClient
,
configConfig
)
opsCleanupService
:=
service
.
ProvideOpsCleanupService
(
opsRepository
,
db
,
redisClient
,
configConfig
)
opsCleanupService
:=
service
.
ProvideOpsCleanupService
(
opsRepository
,
db
,
redisClient
,
configConfig
)
opsScheduledReportService
:=
service
.
ProvideOpsScheduledReportService
(
opsService
,
userService
,
emailService
,
redisClient
,
configConfig
)
opsScheduledReportService
:=
service
.
ProvideOpsScheduledReportService
(
opsService
,
userService
,
emailService
,
redisClient
,
configConfig
)
tokenRefreshService
:=
service
.
ProvideTokenRefreshService
(
accountRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
configConfig
)
tokenRefreshService
:=
service
.
ProvideTokenRefreshService
(
accountRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
tokenCacheInvalidator
,
configConfig
)
accountExpiryService
:=
service
.
ProvideAccountExpiryService
(
accountRepository
)
accountExpiryService
:=
service
.
ProvideAccountExpiryService
(
accountRepository
)
v
:=
provideCleanup
(
client
,
redisClient
,
opsMetricsCollector
,
opsAggregationService
,
opsAlertEvaluatorService
,
opsCleanupService
,
opsScheduledReportService
,
schedulerSnapshotService
,
tokenRefreshService
,
accountExpiryService
,
pricingService
,
emailQueueService
,
billingCacheService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
)
v
:=
provideCleanup
(
client
,
redisClient
,
opsMetricsCollector
,
opsAggregationService
,
opsAlertEvaluatorService
,
opsCleanupService
,
opsScheduledReportService
,
schedulerSnapshotService
,
tokenRefreshService
,
accountExpiryService
,
pricingService
,
emailQueueService
,
billingCacheService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
)
application
:=
&
Application
{
application
:=
&
Application
{
...
...
backend/internal/config/config.go
View file @
daf10907
...
@@ -436,6 +436,7 @@ type DefaultConfig struct {
...
@@ -436,6 +436,7 @@ type DefaultConfig struct {
type
RateLimitConfig
struct
{
type
RateLimitConfig
struct
{
OverloadCooldownMinutes
int
`mapstructure:"overload_cooldown_minutes"`
// 529过载冷却时间(分钟)
OverloadCooldownMinutes
int
`mapstructure:"overload_cooldown_minutes"`
// 529过载冷却时间(分钟)
OAuth401CooldownMinutes
int
`mapstructure:"oauth_401_cooldown_minutes"`
// OAuth 401 临时不可调度冷却时间(分钟)
}
}
// APIKeyAuthCacheConfig API Key 认证缓存配置
// APIKeyAuthCacheConfig API Key 认证缓存配置
...
@@ -709,6 +710,7 @@ func setDefaults() {
...
@@ -709,6 +710,7 @@ func setDefaults() {
// RateLimit
// RateLimit
viper
.
SetDefault
(
"rate_limit.overload_cooldown_minutes"
,
10
)
viper
.
SetDefault
(
"rate_limit.overload_cooldown_minutes"
,
10
)
viper
.
SetDefault
(
"rate_limit.oauth_401_cooldown_minutes"
,
5
)
// Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查
// Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查
viper
.
SetDefault
(
"pricing.remote_url"
,
"https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json"
)
viper
.
SetDefault
(
"pricing.remote_url"
,
"https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json"
)
...
...
backend/internal/repository/gemini_token_cache.go
View file @
daf10907
...
@@ -33,6 +33,11 @@ func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string,
...
@@ -33,6 +33,11 @@ func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string,
return
c
.
rdb
.
Set
(
ctx
,
key
,
token
,
ttl
)
.
Err
()
return
c
.
rdb
.
Set
(
ctx
,
key
,
token
,
ttl
)
.
Err
()
}
}
func
(
c
*
geminiTokenCache
)
DeleteAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
geminiTokenKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
func
(
c
*
geminiTokenCache
)
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
{
func
(
c
*
geminiTokenCache
)
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
geminiRefreshLockKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
geminiRefreshLockKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
SetNX
(
ctx
,
key
,
1
,
ttl
)
.
Result
()
return
c
.
rdb
.
SetNX
(
ctx
,
key
,
1
,
ttl
)
.
Result
()
...
...
backend/internal/repository/gemini_token_cache_integration_test.go
0 → 100644
View file @
daf10907
//go:build integration
package
repository
import
(
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type
GeminiTokenCacheSuite
struct
{
IntegrationRedisSuite
cache
service
.
GeminiTokenCache
}
func
(
s
*
GeminiTokenCacheSuite
)
SetupTest
()
{
s
.
IntegrationRedisSuite
.
SetupTest
()
s
.
cache
=
NewGeminiTokenCache
(
s
.
rdb
)
}
func
(
s
*
GeminiTokenCacheSuite
)
TestDeleteAccessToken
()
{
cacheKey
:=
"project-123"
token
:=
"token-value"
require
.
NoError
(
s
.
T
(),
s
.
cache
.
SetAccessToken
(
s
.
ctx
,
cacheKey
,
token
,
time
.
Minute
))
got
,
err
:=
s
.
cache
.
GetAccessToken
(
s
.
ctx
,
cacheKey
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
token
,
got
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DeleteAccessToken
(
s
.
ctx
,
cacheKey
))
_
,
err
=
s
.
cache
.
GetAccessToken
(
s
.
ctx
,
cacheKey
)
require
.
True
(
s
.
T
(),
errors
.
Is
(
err
,
redis
.
Nil
),
"expected redis.Nil after delete"
)
}
func
(
s
*
GeminiTokenCacheSuite
)
TestDeleteAccessToken_MissingKey
()
{
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DeleteAccessToken
(
s
.
ctx
,
"missing-key"
))
}
func
TestGeminiTokenCacheSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
GeminiTokenCacheSuite
))
}
backend/internal/repository/gemini_token_cache_test.go
0 → 100644
View file @
daf10907
//go:build unit
package
repository
import
(
"context"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func
TestGeminiTokenCache_DeleteAccessToken_RedisError
(
t
*
testing
.
T
)
{
rdb
:=
redis
.
NewClient
(
&
redis
.
Options
{
Addr
:
"127.0.0.1:1"
,
DialTimeout
:
50
*
time
.
Millisecond
,
ReadTimeout
:
50
*
time
.
Millisecond
,
WriteTimeout
:
50
*
time
.
Millisecond
,
})
t
.
Cleanup
(
func
()
{
_
=
rdb
.
Close
()
})
cache
:=
NewGeminiTokenCache
(
rdb
)
err
:=
cache
.
DeleteAccessToken
(
context
.
Background
(),
"broken"
)
require
.
Error
(
t
,
err
)
}
backend/internal/service/antigravity_token_provider.go
View file @
daf10907
...
@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
...
@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return
""
,
errors
.
New
(
"not an antigravity oauth account"
)
return
""
,
errors
.
New
(
"not an antigravity oauth account"
)
}
}
cacheKey
:=
a
ntigravityTokenCacheKey
(
account
)
cacheKey
:=
A
ntigravityTokenCacheKey
(
account
)
// 1. 先尝试缓存
// 1. 先尝试缓存
if
p
.
tokenCache
!=
nil
{
if
p
.
tokenCache
!=
nil
{
...
@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
...
@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return
accessToken
,
nil
return
accessToken
,
nil
}
}
func
a
ntigravityTokenCacheKey
(
account
*
Account
)
string
{
func
A
ntigravityTokenCacheKey
(
account
*
Account
)
string
{
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
if
projectID
!=
""
{
if
projectID
!=
""
{
return
"ag:"
+
projectID
return
"ag:"
+
projectID
...
...
backend/internal/service/gemini_token_cache.go
View file @
daf10907
...
@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
...
@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
GetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
(
string
,
error
)
GetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
(
string
,
error
)
SetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
,
token
string
,
ttl
time
.
Duration
)
error
SetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
,
token
string
,
ttl
time
.
Duration
)
error
DeleteAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
error
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
ReleaseRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
)
error
ReleaseRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
)
error
...
...
backend/internal/service/gemini_token_provider.go
View file @
daf10907
...
@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
...
@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return
""
,
errors
.
New
(
"not a gemini oauth account"
)
return
""
,
errors
.
New
(
"not a gemini oauth account"
)
}
}
cacheKey
:=
g
eminiTokenCacheKey
(
account
)
cacheKey
:=
G
eminiTokenCacheKey
(
account
)
// 1) Try cache first.
// 1) Try cache first.
if
p
.
tokenCache
!=
nil
{
if
p
.
tokenCache
!=
nil
{
...
@@ -151,7 +151,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
...
@@ -151,7 +151,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return
accessToken
,
nil
return
accessToken
,
nil
}
}
func
g
eminiTokenCacheKey
(
account
*
Account
)
string
{
func
G
eminiTokenCacheKey
(
account
*
Account
)
string
{
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
if
projectID
!=
""
{
if
projectID
!=
""
{
return
projectID
return
projectID
...
...
backend/internal/service/ratelimit_service.go
View file @
daf10907
...
@@ -3,7 +3,7 @@ package service
...
@@ -3,7 +3,7 @@ package service
import
(
import
(
"context"
"context"
"encoding/json"
"encoding/json"
"log"
"log
/slog
"
"net/http"
"net/http"
"strconv"
"strconv"
"strings"
"strings"
...
@@ -22,6 +22,7 @@ type RateLimitService struct {
...
@@ -22,6 +22,7 @@ type RateLimitService struct {
tempUnschedCache
TempUnschedCache
tempUnschedCache
TempUnschedCache
timeoutCounterCache
TimeoutCounterCache
timeoutCounterCache
TimeoutCounterCache
settingService
*
SettingService
settingService
*
SettingService
tokenCacheInvalidator
TokenCacheInvalidator
usageCacheMu
sync
.
RWMutex
usageCacheMu
sync
.
RWMutex
usageCache
map
[
int64
]
*
geminiUsageCacheEntry
usageCache
map
[
int64
]
*
geminiUsageCacheEntry
}
}
...
@@ -56,6 +57,11 @@ func (s *RateLimitService) SetSettingService(settingService *SettingService) {
...
@@ -56,6 +57,11 @@ func (s *RateLimitService) SetSettingService(settingService *SettingService) {
s
.
settingService
=
settingService
s
.
settingService
=
settingService
}
}
// SetTokenCacheInvalidator 设置 token 缓存清理器(可选依赖)
func
(
s
*
RateLimitService
)
SetTokenCacheInvalidator
(
invalidator
TokenCacheInvalidator
)
{
s
.
tokenCacheInvalidator
=
invalidator
}
// HandleUpstreamError 处理上游错误响应,标记账号状态
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
// 返回是否应该停止该账号的调度
func
(
s
*
RateLimitService
)
HandleUpstreamError
(
ctx
context
.
Context
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
responseBody
[]
byte
)
(
shouldDisable
bool
)
{
func
(
s
*
RateLimitService
)
HandleUpstreamError
(
ctx
context
.
Context
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
responseBody
[]
byte
)
(
shouldDisable
bool
)
{
...
@@ -63,11 +69,16 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
...
@@ -63,11 +69,16 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
customErrorCodesEnabled
:=
account
.
IsCustomErrorCodesEnabled
()
customErrorCodesEnabled
:=
account
.
IsCustomErrorCodesEnabled
()
if
!
account
.
ShouldHandleErrorCode
(
statusCode
)
{
if
!
account
.
ShouldHandleErrorCode
(
statusCode
)
{
log
.
Printf
(
"
A
ccount
%d: error %d skipped (not in custom error codes)"
,
account
.
ID
,
statusCode
)
s
log
.
Info
(
"
a
ccount
_error_code_skipped"
,
"account_id"
,
account
.
ID
,
"status_code"
,
statusCode
)
return
false
return
false
}
}
tempMatched
:=
s
.
tryTempUnschedulable
(
ctx
,
account
,
statusCode
,
responseBody
)
isOAuth401
:=
statusCode
==
401
&&
account
.
Type
==
AccountTypeOAuth
&&
(
account
.
Platform
==
PlatformAntigravity
||
account
.
Platform
==
PlatformGemini
)
tempMatched
:=
false
if
!
isOAuth401
||
account
.
IsTempUnschedulableEnabled
()
{
tempMatched
=
s
.
tryTempUnschedulable
(
ctx
,
account
,
statusCode
,
responseBody
)
}
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
responseBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
responseBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
if
upstreamMsg
!=
""
{
if
upstreamMsg
!=
""
{
...
@@ -76,7 +87,19 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
...
@@ -76,7 +87,19 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
switch
statusCode
{
switch
statusCode
{
case
401
:
case
401
:
// 认证失败:停止调度,记录错误
if
isOAuth401
{
if
tempMatched
{
if
s
.
tokenCacheInvalidator
!=
nil
{
if
err
:=
s
.
tokenCacheInvalidator
.
InvalidateToken
(
ctx
,
account
);
err
!=
nil
{
slog
.
Warn
(
"oauth_401_invalidate_cache_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
shouldDisable
=
true
}
else
{
shouldDisable
=
s
.
handleOAuth401TempUnschedulable
(
ctx
,
account
,
upstreamMsg
)
}
break
}
msg
:=
"Authentication failed (401): invalid or expired credentials"
msg
:=
"Authentication failed (401): invalid or expired credentials"
if
upstreamMsg
!=
""
{
if
upstreamMsg
!=
""
{
msg
=
"Authentication failed (401): "
+
upstreamMsg
msg
=
"Authentication failed (401): "
+
upstreamMsg
...
@@ -116,7 +139,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
...
@@ -116,7 +139,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
shouldDisable
=
true
shouldDisable
=
true
}
else
if
statusCode
>=
500
{
}
else
if
statusCode
>=
500
{
// 未启用自定义错误码时:仅记录5xx错误
// 未启用自定义错误码时:仅记录5xx错误
log
.
Printf
(
"
A
ccount
%d received
upstream
error
%
d"
,
account
.
ID
,
statusCode
)
s
log
.
Warn
(
"
a
ccount
_
upstream
_
error
"
,
"account_i
d"
,
account
.
ID
,
"status_code"
,
statusCode
)
shouldDisable
=
false
shouldDisable
=
false
}
}
}
}
...
@@ -127,6 +150,63 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
...
@@ -127,6 +150,63 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return
shouldDisable
return
shouldDisable
}
}
func
(
s
*
RateLimitService
)
handleOAuth401TempUnschedulable
(
ctx
context
.
Context
,
account
*
Account
,
upstreamMsg
string
)
bool
{
if
account
==
nil
{
return
false
}
if
s
.
tokenCacheInvalidator
!=
nil
{
if
err
:=
s
.
tokenCacheInvalidator
.
InvalidateToken
(
ctx
,
account
);
err
!=
nil
{
slog
.
Warn
(
"oauth_401_invalidate_cache_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
now
:=
time
.
Now
()
until
:=
now
.
Add
(
s
.
oauth401Cooldown
())
msg
:=
"Authentication failed (401): invalid or expired credentials"
if
upstreamMsg
!=
""
{
msg
=
"Authentication failed (401): "
+
upstreamMsg
}
state
:=
&
TempUnschedState
{
UntilUnix
:
until
.
Unix
(),
TriggeredAtUnix
:
now
.
Unix
(),
StatusCode
:
401
,
MatchedKeyword
:
"oauth_401"
,
RuleIndex
:
-
1
,
// -1 表示非规则触发,而是 OAuth 401 特殊处理
ErrorMessage
:
msg
,
}
reason
:=
""
if
raw
,
err
:=
json
.
Marshal
(
state
);
err
==
nil
{
reason
=
string
(
raw
)
}
if
reason
==
""
{
reason
=
msg
}
if
err
:=
s
.
accountRepo
.
SetTempUnschedulable
(
ctx
,
account
.
ID
,
until
,
reason
);
err
!=
nil
{
slog
.
Warn
(
"oauth_401_set_temp_unschedulable_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
return
false
}
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
account
.
ID
,
state
);
err
!=
nil
{
slog
.
Warn
(
"oauth_401_set_temp_unsched_cache_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
slog
.
Info
(
"oauth_401_temp_unschedulable"
,
"account_id"
,
account
.
ID
,
"until"
,
until
)
return
true
}
func
(
s
*
RateLimitService
)
oauth401Cooldown
()
time
.
Duration
{
if
s
!=
nil
&&
s
.
cfg
!=
nil
&&
s
.
cfg
.
RateLimit
.
OAuth401CooldownMinutes
>
0
{
return
time
.
Duration
(
s
.
cfg
.
RateLimit
.
OAuth401CooldownMinutes
)
*
time
.
Minute
}
return
5
*
time
.
Minute
}
// PreCheckUsage proactively checks local quota before dispatching a request.
// PreCheckUsage proactively checks local quota before dispatching a request.
// Returns false when the account should be skipped.
// Returns false when the account should be skipped.
func
(
s
*
RateLimitService
)
PreCheckUsage
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
string
)
(
bool
,
error
)
{
func
(
s
*
RateLimitService
)
PreCheckUsage
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
string
)
(
bool
,
error
)
{
...
@@ -188,7 +268,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
...
@@ -188,7 +268,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
// NOTE:
// NOTE:
// - This is a local precheck to reduce upstream 429s.
// - This is a local precheck to reduce upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
log
.
Printf
(
"[G
emini
P
re
C
heck
] Account %d reached
daily
quota
(%d/%d), skip until %v
"
,
account
.
ID
,
used
,
limit
,
resetAt
)
s
log
.
Info
(
"g
emini
_p
re
c
heck
_
daily
_
quota
_reached"
,
"account_id
"
,
account
.
ID
,
"
used
"
,
used
,
"limit"
,
limit
,
"reset_at"
,
resetAt
)
return
false
,
nil
return
false
,
nil
}
}
}
}
...
@@ -231,7 +311,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
...
@@ -231,7 +311,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if
used
>=
limit
{
if
used
>=
limit
{
resetAt
:=
start
.
Add
(
time
.
Minute
)
resetAt
:=
start
.
Add
(
time
.
Minute
)
// Do not persist "rate limited" status from local precheck. See note above.
// Do not persist "rate limited" status from local precheck. See note above.
log
.
Printf
(
"[G
emini
P
re
C
heck
] Account %d reached
minute
quota
(%d/%d), skip until %v
"
,
account
.
ID
,
used
,
limit
,
resetAt
)
s
log
.
Info
(
"g
emini
_p
re
c
heck
_
minute
_
quota
_reached"
,
"account_id
"
,
account
.
ID
,
"
used
"
,
used
,
"limit"
,
limit
,
"reset_at"
,
resetAt
)
return
false
,
nil
return
false
,
nil
}
}
}
}
...
@@ -288,20 +368,20 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
...
@@ -288,20 +368,20 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
// handleAuthError 处理认证类错误(401/403),停止账号调度
// handleAuthError 处理认证类错误(401/403),停止账号调度
func
(
s
*
RateLimitService
)
handleAuthError
(
ctx
context
.
Context
,
account
*
Account
,
errorMsg
string
)
{
func
(
s
*
RateLimitService
)
handleAuthError
(
ctx
context
.
Context
,
account
*
Account
,
errorMsg
string
)
{
if
err
:=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
errorMsg
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
errorMsg
);
err
!=
nil
{
log
.
Printf
(
"SetE
rror
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"account_set_e
rror
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
return
return
}
}
log
.
Printf
(
"
A
ccount
%d
disabled
due to
auth
error
: %s
"
,
account
.
ID
,
errorMsg
)
s
log
.
Warn
(
"
a
ccount
_
disabled
_
auth
_
error
"
,
"account_id
"
,
account
.
ID
,
"error"
,
errorMsg
)
}
}
// handleCustomErrorCode 处理自定义错误码,停止账号调度
// handleCustomErrorCode 处理自定义错误码,停止账号调度
func
(
s
*
RateLimitService
)
handleCustomErrorCode
(
ctx
context
.
Context
,
account
*
Account
,
statusCode
int
,
errorMsg
string
)
{
func
(
s
*
RateLimitService
)
handleCustomErrorCode
(
ctx
context
.
Context
,
account
*
Account
,
statusCode
int
,
errorMsg
string
)
{
msg
:=
"Custom error code "
+
strconv
.
Itoa
(
statusCode
)
+
": "
+
errorMsg
msg
:=
"Custom error code "
+
strconv
.
Itoa
(
statusCode
)
+
": "
+
errorMsg
if
err
:=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
msg
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
msg
);
err
!=
nil
{
log
.
Printf
(
"SetE
rror
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"account_set_e
rror
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"status_code"
,
statusCode
,
"error"
,
err
)
return
return
}
}
log
.
Printf
(
"
A
ccount
%d
disabled
due to
custom
error
code %d: %s
"
,
account
.
ID
,
statusCode
,
errorMsg
)
s
log
.
Warn
(
"
a
ccount
_
disabled
_
custom
_
error
"
,
"account_id
"
,
account
.
ID
,
"status_code"
,
statusCode
,
"error"
,
errorMsg
)
}
}
// handle429 处理429限流错误
// handle429 处理429限流错误
...
@@ -313,7 +393,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
...
@@ -313,7 +393,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 没有重置时间,使用默认5分钟
// 没有重置时间,使用默认5分钟
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
log
.
Printf
(
"SetR
ate
L
imit
ed
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"r
ate
_l
imit
_set_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
return
return
}
}
...
@@ -321,10 +401,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
...
@@ -321,10 +401,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 解析Unix时间戳
// 解析Unix时间戳
ts
,
err
:=
strconv
.
ParseInt
(
resetTimestamp
,
10
,
64
)
ts
,
err
:=
strconv
.
ParseInt
(
resetTimestamp
,
10
,
64
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"Parse
reset
timestamp
failed: %v
"
,
err
)
s
log
.
Warn
(
"rate_limit_reset_parse_failed"
,
"
reset
_
timestamp
"
,
resetTimestamp
,
"error
"
,
err
)
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
log
.
Printf
(
"SetR
ate
L
imit
ed
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"r
ate
_l
imit
_set_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
return
return
}
}
...
@@ -333,7 +413,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
...
@@ -333,7 +413,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 标记限流状态
// 标记限流状态
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
log
.
Printf
(
"SetR
ate
L
imit
ed
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"r
ate
_l
imit
_set_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
return
return
}
}
...
@@ -341,10 +421,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
...
@@ -341,10 +421,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
windowEnd
:=
resetAt
windowEnd
:=
resetAt
windowStart
:=
resetAt
.
Add
(
-
5
*
time
.
Hour
)
windowStart
:=
resetAt
.
Add
(
-
5
*
time
.
Hour
)
if
err
:=
s
.
accountRepo
.
UpdateSessionWindow
(
ctx
,
account
.
ID
,
&
windowStart
,
&
windowEnd
,
"rejected"
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
UpdateSessionWindow
(
ctx
,
account
.
ID
,
&
windowStart
,
&
windowEnd
,
"rejected"
);
err
!=
nil
{
log
.
Printf
(
"U
pdate
S
ession
W
indow
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"rate_limit_u
pdate
_s
ession
_w
indow
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
log
.
Printf
(
"
A
ccount
%d
rate
limited
until %v
"
,
account
.
ID
,
resetAt
)
s
log
.
Info
(
"
a
ccount
_
rate
_
limited
"
,
"account_id
"
,
account
.
ID
,
"reset_at"
,
resetAt
)
}
}
// handle529 处理529过载错误
// handle529 处理529过载错误
...
@@ -357,11 +437,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
...
@@ -357,11 +437,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
until
:=
time
.
Now
()
.
Add
(
time
.
Duration
(
cooldownMinutes
)
*
time
.
Minute
)
until
:=
time
.
Now
()
.
Add
(
time
.
Duration
(
cooldownMinutes
)
*
time
.
Minute
)
if
err
:=
s
.
accountRepo
.
SetOverloaded
(
ctx
,
account
.
ID
,
until
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetOverloaded
(
ctx
,
account
.
ID
,
until
);
err
!=
nil
{
log
.
Printf
(
"SetO
verload
ed
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"o
verload
_set_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
return
return
}
}
log
.
Printf
(
"
A
ccount
%d
overloaded
until %v
"
,
account
.
ID
,
until
)
s
log
.
Info
(
"
a
ccount
_
overloaded
"
,
"account_id
"
,
account
.
ID
,
"until"
,
until
)
}
}
// UpdateSessionWindow 从成功响应更新5h窗口状态
// UpdateSessionWindow 从成功响应更新5h窗口状态
...
@@ -384,17 +464,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
...
@@ -384,17 +464,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
end
:=
start
.
Add
(
5
*
time
.
Hour
)
end
:=
start
.
Add
(
5
*
time
.
Hour
)
windowStart
=
&
start
windowStart
=
&
start
windowEnd
=
&
end
windowEnd
=
&
end
log
.
Printf
(
"
A
ccount
%d: initializing 5h window from %v to %v (status: %s)
"
,
account
.
ID
,
start
,
end
,
status
)
s
log
.
Info
(
"
a
ccount
_session_window_initialized"
,
"account_id
"
,
account
.
ID
,
"window_
start
"
,
start
,
"window_end"
,
end
,
"status"
,
status
)
}
}
if
err
:=
s
.
accountRepo
.
UpdateSessionWindow
(
ctx
,
account
.
ID
,
windowStart
,
windowEnd
,
status
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
UpdateSessionWindow
(
ctx
,
account
.
ID
,
windowStart
,
windowEnd
,
status
);
err
!=
nil
{
log
.
Printf
(
"UpdateS
ession
W
indow
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"s
ession
_w
indow
_update_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
if
status
==
"allowed"
&&
account
.
IsRateLimited
()
{
if
status
==
"allowed"
&&
account
.
IsRateLimited
()
{
if
err
:=
s
.
ClearRateLimit
(
ctx
,
account
.
ID
);
err
!=
nil
{
if
err
:=
s
.
ClearRateLimit
(
ctx
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"ClearR
ate
L
imit
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"r
ate
_l
imit
_clear_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
}
}
}
}
...
@@ -413,7 +493,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
...
@@ -413,7 +493,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
}
}
if
s
.
tempUnschedCache
!=
nil
{
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
DeleteTempUnsched
(
ctx
,
accountID
);
err
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
DeleteTempUnsched
(
ctx
,
accountID
);
err
!=
nil
{
log
.
Printf
(
"DeleteTempUnsched
failed
for
account
%d: %v
"
,
accountID
,
err
)
s
log
.
Warn
(
"temp_unsched_cache_delete_
failed
"
,
"
account
_id
"
,
accountID
,
"error"
,
err
)
}
}
}
}
return
nil
return
nil
...
@@ -460,7 +540,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i
...
@@ -460,7 +540,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i
if
s
.
tempUnschedCache
!=
nil
{
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
accountID
,
state
);
err
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
accountID
,
state
);
err
!=
nil
{
log
.
Printf
(
"SetT
emp
U
nsched
failed
for
account
%d: %v
"
,
accountID
,
err
)
s
log
.
Warn
(
"t
emp
_u
nsched
_cache_set_
failed
"
,
"
account
_id
"
,
accountID
,
"error"
,
err
)
}
}
}
}
...
@@ -563,17 +643,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
...
@@ -563,17 +643,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
}
}
if
err
:=
s
.
accountRepo
.
SetTempUnschedulable
(
ctx
,
account
.
ID
,
until
,
reason
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetTempUnschedulable
(
ctx
,
account
.
ID
,
until
,
reason
);
err
!=
nil
{
log
.
Printf
(
"SetT
emp
U
nsched
ulable
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"t
emp
_u
nsched
_set_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
return
false
return
false
}
}
if
s
.
tempUnschedCache
!=
nil
{
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
account
.
ID
,
state
);
err
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
account
.
ID
,
state
);
err
!=
nil
{
log
.
Printf
(
"SetT
emp
U
nsched
cache
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"t
emp
_u
nsched
_
cache
_set_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
}
}
log
.
Printf
(
"
A
ccount
%d
temp
unschedulable
until %v (rule %d, code %d)
"
,
account
.
ID
,
until
,
ruleIndex
,
statusCode
)
s
log
.
Info
(
"
a
ccount
_
temp
_
unschedulable
"
,
"account_id
"
,
account
.
ID
,
"
until
"
,
until
,
"rule_index"
,
ruleIndex
,
"status_code"
,
statusCode
)
return
true
return
true
}
}
...
@@ -597,13 +677,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
...
@@ -597,13 +677,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
// 获取系统设置
// 获取系统设置
if
s
.
settingService
==
nil
{
if
s
.
settingService
==
nil
{
log
.
Printf
(
"[S
tream
T
imeout
]
setting
S
ervice
not configured, skipping timeout handling for
account
%
d"
,
account
.
ID
)
s
log
.
Warn
(
"s
tream
_t
imeout
_
setting
_s
ervice
_missing"
,
"
account
_i
d"
,
account
.
ID
)
return
false
return
false
}
}
settings
,
err
:=
s
.
settingService
.
GetStreamTimeoutSettings
(
ctx
)
settings
,
err
:=
s
.
settingService
.
GetStreamTimeoutSettings
(
ctx
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[S
tream
T
imeout
] Failed to
get
settings
: %v
"
,
err
)
s
log
.
Warn
(
"s
tream
_t
imeout
_
get
_
settings
_failed"
,
"account_id"
,
account
.
ID
,
"error
"
,
err
)
return
false
return
false
}
}
...
@@ -620,14 +700,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
...
@@ -620,14 +700,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
if
s
.
timeoutCounterCache
!=
nil
{
if
s
.
timeoutCounterCache
!=
nil
{
count
,
err
=
s
.
timeoutCounterCache
.
IncrementTimeoutCount
(
ctx
,
account
.
ID
,
settings
.
ThresholdWindowMinutes
)
count
,
err
=
s
.
timeoutCounterCache
.
IncrementTimeoutCount
(
ctx
,
account
.
ID
,
settings
.
ThresholdWindowMinutes
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[S
tream
T
imeout
] Failed to increment timeout count for account %d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"s
tream
_t
imeout
_increment_count_failed"
,
"account_id
"
,
account
.
ID
,
"error"
,
err
)
// 继续处理,使用 count=1
// 继续处理,使用 count=1
count
=
1
count
=
1
}
}
}
}
log
.
Printf
(
"[StreamTimeout] Account %d timeout count: %d/%d (window: %d min, model: %s)"
,
slog
.
Info
(
"stream_timeout_count"
,
"account_id"
,
account
.
ID
,
"count"
,
count
,
"threshold"
,
settings
.
ThresholdCount
,
"window_minutes"
,
settings
.
ThresholdWindowMinutes
,
"model"
,
model
)
account
.
ID
,
count
,
settings
.
ThresholdCount
,
settings
.
ThresholdWindowMinutes
,
model
)
// 检查是否达到阈值
// 检查是否达到阈值
if
count
<
int64
(
settings
.
ThresholdCount
)
{
if
count
<
int64
(
settings
.
ThresholdCount
)
{
...
@@ -668,24 +747,24 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
...
@@ -668,24 +747,24 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
}
}
if
err
:=
s
.
accountRepo
.
SetTempUnschedulable
(
ctx
,
account
.
ID
,
until
,
reason
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetTempUnschedulable
(
ctx
,
account
.
ID
,
until
,
reason
);
err
!=
nil
{
log
.
Printf
(
"[S
tream
T
imeout
] SetT
emp
U
nsched
ulable
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"s
tream
_t
imeout
_set_t
emp
_u
nsched
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
return
false
return
false
}
}
if
s
.
tempUnschedCache
!=
nil
{
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
account
.
ID
,
state
);
err
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
account
.
ID
,
state
);
err
!=
nil
{
log
.
Printf
(
"[S
tream
T
imeout
] SetT
emp
U
nsched
cache
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"s
tream
_t
imeout
_set_t
emp
_u
nsched
_
cache
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
}
}
// 重置超时计数
// 重置超时计数
if
s
.
timeoutCounterCache
!=
nil
{
if
s
.
timeoutCounterCache
!=
nil
{
if
err
:=
s
.
timeoutCounterCache
.
ResetTimeoutCount
(
ctx
,
account
.
ID
);
err
!=
nil
{
if
err
:=
s
.
timeoutCounterCache
.
ResetTimeoutCount
(
ctx
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"[S
tream
T
imeout
] R
eset
TimeoutC
ount
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"s
tream
_t
imeout
_r
eset
_c
ount
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
}
}
log
.
Printf
(
"[S
tream
T
imeout
] Account %d marked as
temp
unschedulable
until %v (model: %s)
"
,
account
.
ID
,
until
,
model
)
s
log
.
Info
(
"s
tream
_t
imeout
_
temp
_
unschedulable
"
,
"account_id
"
,
account
.
ID
,
"
until
"
,
until
,
"model"
,
model
)
return
true
return
true
}
}
...
@@ -694,17 +773,17 @@ func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, accoun
...
@@ -694,17 +773,17 @@ func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, accoun
errorMsg
:=
"Stream data interval timeout (repeated failures) for model: "
+
model
errorMsg
:=
"Stream data interval timeout (repeated failures) for model: "
+
model
if
err
:=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
errorMsg
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
errorMsg
);
err
!=
nil
{
log
.
Printf
(
"[S
tream
T
imeout
] SetE
rror
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"s
tream
_t
imeout
_set_e
rror
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
return
false
return
false
}
}
// 重置超时计数
// 重置超时计数
if
s
.
timeoutCounterCache
!=
nil
{
if
s
.
timeoutCounterCache
!=
nil
{
if
err
:=
s
.
timeoutCounterCache
.
ResetTimeoutCount
(
ctx
,
account
.
ID
);
err
!=
nil
{
if
err
:=
s
.
timeoutCounterCache
.
ResetTimeoutCount
(
ctx
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"[S
tream
T
imeout
] R
eset
TimeoutC
ount
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"s
tream
_t
imeout
_r
eset
_c
ount
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
}
}
log
.
Printf
(
"[S
tream
T
imeout
] A
ccount
%d marked as error (model: %s)
"
,
account
.
ID
,
model
)
s
log
.
Warn
(
"s
tream
_t
imeout
_a
ccount
_error"
,
"account_id
"
,
account
.
ID
,
"model"
,
model
)
return
true
return
true
}
}
backend/internal/service/ratelimit_service_401_test.go
0 → 100644
View file @
daf10907
//go:build unit
package
service
import
(
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type
rateLimitAccountRepoStub
struct
{
mockAccountRepoForGemini
tempCalls
int
tempUntil
time
.
Time
tempReason
string
setErrorCalls
int
}
func
(
r
*
rateLimitAccountRepoStub
)
SetTempUnschedulable
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
,
reason
string
)
error
{
r
.
tempCalls
++
r
.
tempUntil
=
until
r
.
tempReason
=
reason
return
nil
}
func
(
r
*
rateLimitAccountRepoStub
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
r
.
setErrorCalls
++
return
nil
}
type
tokenCacheInvalidatorRecorder
struct
{
accounts
[]
*
Account
err
error
}
func
(
r
*
tokenCacheInvalidatorRecorder
)
InvalidateToken
(
ctx
context
.
Context
,
account
*
Account
)
error
{
r
.
accounts
=
append
(
r
.
accounts
,
account
)
return
r
.
err
}
func
TestRateLimitService_HandleUpstreamError_OAuth401TempUnschedulable
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
platform
string
}{
{
name
:
"gemini"
,
platform
:
PlatformGemini
},
{
name
:
"antigravity"
,
platform
:
PlatformAntigravity
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
rateLimitAccountRepoStub
{}
invalidator
:=
&
tokenCacheInvalidatorRecorder
{}
service
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
service
.
SetTokenCacheInvalidator
(
invalidator
)
account
:=
&
Account
{
ID
:
100
,
Platform
:
tt
.
platform
,
Type
:
AccountTypeOAuth
,
}
start
:=
time
.
Now
()
shouldDisable
:=
service
.
HandleUpstreamError
(
context
.
Background
(),
account
,
401
,
http
.
Header
{},
[]
byte
(
"unauthorized"
))
require
.
True
(
t
,
shouldDisable
)
require
.
Equal
(
t
,
1
,
repo
.
tempCalls
)
require
.
Equal
(
t
,
0
,
repo
.
setErrorCalls
)
require
.
Len
(
t
,
invalidator
.
accounts
,
1
)
require
.
WithinDuration
(
t
,
start
.
Add
(
5
*
time
.
Minute
),
repo
.
tempUntil
,
10
*
time
.
Second
)
require
.
NotEmpty
(
t
,
repo
.
tempReason
)
})
}
}
func
TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError
(
t
*
testing
.
T
)
{
repo
:=
&
rateLimitAccountRepoStub
{}
invalidator
:=
&
tokenCacheInvalidatorRecorder
{
err
:
errors
.
New
(
"boom"
)}
service
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
service
.
SetTokenCacheInvalidator
(
invalidator
)
account
:=
&
Account
{
ID
:
101
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
shouldDisable
:=
service
.
HandleUpstreamError
(
context
.
Background
(),
account
,
401
,
http
.
Header
{},
[]
byte
(
"unauthorized"
))
require
.
True
(
t
,
shouldDisable
)
require
.
Equal
(
t
,
1
,
repo
.
tempCalls
)
require
.
Equal
(
t
,
0
,
repo
.
setErrorCalls
)
require
.
Len
(
t
,
invalidator
.
accounts
,
1
)
}
func
TestRateLimitService_HandleUpstreamError_OAuth401CustomRule
(
t
*
testing
.
T
)
{
repo
:=
&
rateLimitAccountRepoStub
{}
invalidator
:=
&
tokenCacheInvalidatorRecorder
{}
service
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
service
.
SetTokenCacheInvalidator
(
invalidator
)
account
:=
&
Account
{
ID
:
103
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"temp_unschedulable_enabled"
:
true
,
"temp_unschedulable_rules"
:
[]
any
{
map
[
string
]
any
{
"error_code"
:
401
,
"keywords"
:
[]
any
{
"unauthorized"
},
"duration_minutes"
:
30
,
"description"
:
"custom rule"
,
},
},
},
}
start
:=
time
.
Now
()
shouldDisable
:=
service
.
HandleUpstreamError
(
context
.
Background
(),
account
,
401
,
http
.
Header
{},
[]
byte
(
"unauthorized"
))
require
.
True
(
t
,
shouldDisable
)
require
.
Equal
(
t
,
1
,
repo
.
tempCalls
)
require
.
Equal
(
t
,
0
,
repo
.
setErrorCalls
)
require
.
Len
(
t
,
invalidator
.
accounts
,
1
)
require
.
WithinDuration
(
t
,
start
.
Add
(
30
*
time
.
Minute
),
repo
.
tempUntil
,
10
*
time
.
Second
)
}
func
TestRateLimitService_HandleUpstreamError_NonOAuth401
(
t
*
testing
.
T
)
{
repo
:=
&
rateLimitAccountRepoStub
{}
invalidator
:=
&
tokenCacheInvalidatorRecorder
{}
service
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
service
.
SetTokenCacheInvalidator
(
invalidator
)
account
:=
&
Account
{
ID
:
102
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
}
shouldDisable
:=
service
.
HandleUpstreamError
(
context
.
Background
(),
account
,
401
,
http
.
Header
{},
[]
byte
(
"unauthorized"
))
require
.
True
(
t
,
shouldDisable
)
require
.
Equal
(
t
,
0
,
repo
.
tempCalls
)
require
.
Equal
(
t
,
1
,
repo
.
setErrorCalls
)
require
.
Empty
(
t
,
invalidator
.
accounts
)
}
// TestRateLimitService_HandleOAuth401_NilAccount 测试 account 为 nil 的情况
func
TestRateLimitService_HandleOAuth401_NilAccount
(
t
*
testing
.
T
)
{
repo
:=
&
rateLimitAccountRepoStub
{}
service
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
result
:=
service
.
handleOAuth401TempUnschedulable
(
context
.
Background
(),
nil
,
"error"
)
require
.
False
(
t
,
result
)
require
.
Equal
(
t
,
0
,
repo
.
tempCalls
)
}
// TestRateLimitService_HandleOAuth401_NilInvalidator 测试 tokenCacheInvalidator 为 nil 的情况
func
TestRateLimitService_HandleOAuth401_NilInvalidator
(
t
*
testing
.
T
)
{
repo
:=
&
rateLimitAccountRepoStub
{}
service
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
// 不设置 tokenCacheInvalidator
account
:=
&
Account
{
ID
:
200
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
result
:=
service
.
handleOAuth401TempUnschedulable
(
context
.
Background
(),
account
,
"error"
)
require
.
True
(
t
,
result
)
require
.
Equal
(
t
,
1
,
repo
.
tempCalls
)
}
// TestRateLimitService_HandleOAuth401_SetTempUnschedulableFailed 测试 SetTempUnschedulable 失败的情况
func
TestRateLimitService_HandleOAuth401_SetTempUnschedulableFailed
(
t
*
testing
.
T
)
{
repo
:=
&
rateLimitAccountRepoStubWithError
{
setTempErr
:
errors
.
New
(
"db error"
),
}
invalidator
:=
&
tokenCacheInvalidatorRecorder
{}
service
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
service
.
SetTokenCacheInvalidator
(
invalidator
)
account
:=
&
Account
{
ID
:
201
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
result
:=
service
.
handleOAuth401TempUnschedulable
(
context
.
Background
(),
account
,
"error"
)
require
.
False
(
t
,
result
)
// 失败应返回 false
require
.
Len
(
t
,
invalidator
.
accounts
,
1
)
// 但 invalidator 仍然被调用
}
// rateLimitAccountRepoStubWithError 支持返回错误的 stub
type
rateLimitAccountRepoStubWithError
struct
{
mockAccountRepoForGemini
setTempErr
error
setErrorCalls
int
}
func
(
r
*
rateLimitAccountRepoStubWithError
)
SetTempUnschedulable
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
,
reason
string
)
error
{
return
r
.
setTempErr
}
func
(
r
*
rateLimitAccountRepoStubWithError
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
r
.
setErrorCalls
++
return
nil
}
// TestRateLimitService_HandleOAuth401_WithTempUnschedCache 测试 tempUnschedCache 存在的情况
func
TestRateLimitService_HandleOAuth401_WithTempUnschedCache
(
t
*
testing
.
T
)
{
repo
:=
&
rateLimitAccountRepoStub
{}
invalidator
:=
&
tokenCacheInvalidatorRecorder
{}
tempCache
:=
&
tempUnschedCacheStub
{}
service
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
tempCache
)
service
.
SetTokenCacheInvalidator
(
invalidator
)
account
:=
&
Account
{
ID
:
202
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
result
:=
service
.
handleOAuth401TempUnschedulable
(
context
.
Background
(),
account
,
"error"
)
require
.
True
(
t
,
result
)
require
.
Equal
(
t
,
1
,
repo
.
tempCalls
)
require
.
Equal
(
t
,
1
,
tempCache
.
setCalls
)
}
// TestRateLimitService_HandleOAuth401_TempUnschedCacheError 测试 tempUnschedCache 设置失败的情况
func
TestRateLimitService_HandleOAuth401_TempUnschedCacheError
(
t
*
testing
.
T
)
{
repo
:=
&
rateLimitAccountRepoStub
{}
invalidator
:=
&
tokenCacheInvalidatorRecorder
{}
tempCache
:=
&
tempUnschedCacheStub
{
setErr
:
errors
.
New
(
"cache error"
)}
service
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
tempCache
)
service
.
SetTokenCacheInvalidator
(
invalidator
)
account
:=
&
Account
{
ID
:
203
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
result
:=
service
.
handleOAuth401TempUnschedulable
(
context
.
Background
(),
account
,
"error"
)
require
.
True
(
t
,
result
)
// 缓存错误不影响主流程
require
.
Equal
(
t
,
1
,
repo
.
tempCalls
)
}
// tempUnschedCacheStub 用于测试的 TempUnschedCache stub
type
tempUnschedCacheStub
struct
{
setCalls
int
setErr
error
}
func
(
c
*
tempUnschedCacheStub
)
GetTempUnsched
(
ctx
context
.
Context
,
accountID
int64
)
(
*
TempUnschedState
,
error
)
{
return
nil
,
nil
}
func
(
c
*
tempUnschedCacheStub
)
SetTempUnsched
(
ctx
context
.
Context
,
accountID
int64
,
state
*
TempUnschedState
)
error
{
c
.
setCalls
++
return
c
.
setErr
}
func
(
c
*
tempUnschedCacheStub
)
DeleteTempUnsched
(
ctx
context
.
Context
,
accountID
int64
)
error
{
return
nil
}
// TestRateLimitService_OAuth401Cooldown 测试 oauth401Cooldown 函数
func
TestRateLimitService_OAuth401Cooldown
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
cfg
*
config
.
Config
expected
time
.
Duration
}{
{
name
:
"default_when_config_zero"
,
cfg
:
&
config
.
Config
{
RateLimit
:
config
.
RateLimitConfig
{
OAuth401CooldownMinutes
:
0
}},
expected
:
5
*
time
.
Minute
,
},
{
name
:
"custom_cooldown_10_minutes"
,
cfg
:
&
config
.
Config
{
RateLimit
:
config
.
RateLimitConfig
{
OAuth401CooldownMinutes
:
10
}},
expected
:
10
*
time
.
Minute
,
},
{
name
:
"custom_cooldown_1_minute"
,
cfg
:
&
config
.
Config
{
RateLimit
:
config
.
RateLimitConfig
{
OAuth401CooldownMinutes
:
1
}},
expected
:
1
*
time
.
Minute
,
},
{
name
:
"negative_value_uses_default"
,
cfg
:
&
config
.
Config
{
RateLimit
:
config
.
RateLimitConfig
{
OAuth401CooldownMinutes
:
-
5
}},
expected
:
5
*
time
.
Minute
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
service
:=
NewRateLimitService
(
nil
,
nil
,
tt
.
cfg
,
nil
,
nil
)
result
:=
service
.
oauth401Cooldown
()
require
.
Equal
(
t
,
tt
.
expected
,
result
)
})
}
}
// TestRateLimitService_OAuth401Cooldown_NilConfig 测试 cfg 为 nil 的情况
func
TestRateLimitService_OAuth401Cooldown_NilConfig
(
t
*
testing
.
T
)
{
service
:=
&
RateLimitService
{
cfg
:
nil
}
result
:=
service
.
oauth401Cooldown
()
require
.
Equal
(
t
,
5
*
time
.
Minute
,
result
)
}
// TestRateLimitService_HandleOAuth401_WithCustomCooldown 测试自定义 cooldown 配置
func
TestRateLimitService_HandleOAuth401_WithCustomCooldown
(
t
*
testing
.
T
)
{
repo
:=
&
rateLimitAccountRepoStub
{}
cfg
:=
&
config
.
Config
{
RateLimit
:
config
.
RateLimitConfig
{
OAuth401CooldownMinutes
:
15
,
},
}
service
:=
NewRateLimitService
(
repo
,
nil
,
cfg
,
nil
,
nil
)
account
:=
&
Account
{
ID
:
204
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
}
start
:=
time
.
Now
()
result
:=
service
.
handleOAuth401TempUnschedulable
(
context
.
Background
(),
account
,
"error"
)
require
.
True
(
t
,
result
)
require
.
WithinDuration
(
t
,
start
.
Add
(
15
*
time
.
Minute
),
repo
.
tempUntil
,
10
*
time
.
Second
)
}
// TestRateLimitService_HandleOAuth401_EmptyUpstreamMsg 测试 upstreamMsg 为空的情况
func
TestRateLimitService_HandleOAuth401_EmptyUpstreamMsg
(
t
*
testing
.
T
)
{
repo
:=
&
rateLimitAccountRepoStub
{}
service
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
account
:=
&
Account
{
ID
:
205
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
result
:=
service
.
handleOAuth401TempUnschedulable
(
context
.
Background
(),
account
,
""
)
require
.
True
(
t
,
result
)
require
.
Contains
(
t
,
repo
.
tempReason
,
"Authentication failed (401)"
)
}
backend/internal/service/token_cache_invalidator.go
0 → 100644
View file @
daf10907
package
service
import
"context"
type
TokenCacheInvalidator
interface
{
InvalidateToken
(
ctx
context
.
Context
,
account
*
Account
)
error
}
type
CompositeTokenCacheInvalidator
struct
{
geminiCache
GeminiTokenCache
}
func
NewCompositeTokenCacheInvalidator
(
geminiCache
GeminiTokenCache
)
*
CompositeTokenCacheInvalidator
{
return
&
CompositeTokenCacheInvalidator
{
geminiCache
:
geminiCache
,
}
}
func
(
c
*
CompositeTokenCacheInvalidator
)
InvalidateToken
(
ctx
context
.
Context
,
account
*
Account
)
error
{
if
c
==
nil
||
c
.
geminiCache
==
nil
||
account
==
nil
{
return
nil
}
if
account
.
Type
!=
AccountTypeOAuth
{
return
nil
}
switch
account
.
Platform
{
case
PlatformGemini
:
return
c
.
geminiCache
.
DeleteAccessToken
(
ctx
,
GeminiTokenCacheKey
(
account
))
case
PlatformAntigravity
:
return
c
.
geminiCache
.
DeleteAccessToken
(
ctx
,
AntigravityTokenCacheKey
(
account
))
default
:
return
nil
}
}
backend/internal/service/token_cache_invalidator_test.go
0 → 100644
View file @
daf10907
//go:build unit
package
service
import
(
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type
geminiTokenCacheStub
struct
{
deletedKeys
[]
string
deleteErr
error
}
func
(
s
*
geminiTokenCacheStub
)
GetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
s
*
geminiTokenCacheStub
)
SetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
,
token
string
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
geminiTokenCacheStub
)
DeleteAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
s
.
deletedKeys
=
append
(
s
.
deletedKeys
,
cacheKey
)
return
s
.
deleteErr
}
func
(
s
*
geminiTokenCacheStub
)
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
s
*
geminiTokenCacheStub
)
ReleaseRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
return
nil
}
func
TestCompositeTokenCacheInvalidator_Gemini
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
ID
:
10
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
"project-x"
,
},
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"project-x"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_Antigravity
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
ID
:
99
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
"ag-project"
,
},
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"ag:ag-project"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_SkipNonOAuth
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeAPIKey
,
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_NilCache
(
t
*
testing
.
T
)
{
invalidator
:=
NewCompositeTokenCacheInvalidator
(
nil
)
account
:=
&
Account
{
ID
:
2
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
}
backend/internal/service/token_cache_key_test.go
0 → 100644
View file @
daf10907
//go:build unit
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestGeminiTokenCacheKey
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
expected
string
}{
{
name
:
"with_project_id"
,
account
:
&
Account
{
ID
:
100
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
"my-project-123"
,
},
},
expected
:
"my-project-123"
,
},
{
name
:
"project_id_with_whitespace"
,
account
:
&
Account
{
ID
:
101
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
" project-with-spaces "
,
},
},
expected
:
"project-with-spaces"
,
},
{
name
:
"empty_project_id_fallback_to_account_id"
,
account
:
&
Account
{
ID
:
102
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
""
,
},
},
expected
:
"account:102"
,
},
{
name
:
"whitespace_only_project_id_fallback_to_account_id"
,
account
:
&
Account
{
ID
:
103
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
" "
,
},
},
expected
:
"account:103"
,
},
{
name
:
"no_project_id_key_fallback_to_account_id"
,
account
:
&
Account
{
ID
:
104
,
Credentials
:
map
[
string
]
any
{},
},
expected
:
"account:104"
,
},
{
name
:
"nil_credentials_fallback_to_account_id"
,
account
:
&
Account
{
ID
:
105
,
Credentials
:
nil
,
},
expected
:
"account:105"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
GeminiTokenCacheKey
(
tt
.
account
)
require
.
Equal
(
t
,
tt
.
expected
,
result
)
})
}
}
func
TestAntigravityTokenCacheKey
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
expected
string
}{
{
name
:
"with_project_id"
,
account
:
&
Account
{
ID
:
200
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
"ag-project-456"
,
},
},
expected
:
"ag:ag-project-456"
,
},
{
name
:
"project_id_with_whitespace"
,
account
:
&
Account
{
ID
:
201
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
" ag-project-spaces "
,
},
},
expected
:
"ag:ag-project-spaces"
,
},
{
name
:
"empty_project_id_fallback_to_account_id"
,
account
:
&
Account
{
ID
:
202
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
""
,
},
},
expected
:
"ag:account:202"
,
},
{
name
:
"whitespace_only_project_id_fallback_to_account_id"
,
account
:
&
Account
{
ID
:
203
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
" "
,
},
},
expected
:
"ag:account:203"
,
},
{
name
:
"no_project_id_key_fallback_to_account_id"
,
account
:
&
Account
{
ID
:
204
,
Credentials
:
map
[
string
]
any
{},
},
expected
:
"ag:account:204"
,
},
{
name
:
"nil_credentials_fallback_to_account_id"
,
account
:
&
Account
{
ID
:
205
,
Credentials
:
nil
,
},
expected
:
"ag:account:205"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
AntigravityTokenCacheKey
(
tt
.
account
)
require
.
Equal
(
t
,
tt
.
expected
,
result
)
})
}
}
backend/internal/service/token_refresh_service.go
View file @
daf10907
...
@@ -17,6 +17,7 @@ type TokenRefreshService struct {
...
@@ -17,6 +17,7 @@ type TokenRefreshService struct {
accountRepo
AccountRepository
accountRepo
AccountRepository
refreshers
[]
TokenRefresher
refreshers
[]
TokenRefresher
cfg
*
config
.
TokenRefreshConfig
cfg
*
config
.
TokenRefreshConfig
cacheInvalidator
TokenCacheInvalidator
stopCh
chan
struct
{}
stopCh
chan
struct
{}
wg
sync
.
WaitGroup
wg
sync
.
WaitGroup
...
@@ -29,11 +30,13 @@ func NewTokenRefreshService(
...
@@ -29,11 +30,13 @@ func NewTokenRefreshService(
openaiOAuthService
*
OpenAIOAuthService
,
openaiOAuthService
*
OpenAIOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
antigravityOAuthService
*
AntigravityOAuthService
,
antigravityOAuthService
*
AntigravityOAuthService
,
cacheInvalidator
TokenCacheInvalidator
,
cfg
*
config
.
Config
,
cfg
*
config
.
Config
,
)
*
TokenRefreshService
{
)
*
TokenRefreshService
{
s
:=
&
TokenRefreshService
{
s
:=
&
TokenRefreshService
{
accountRepo
:
accountRepo
,
accountRepo
:
accountRepo
,
cfg
:
&
cfg
.
TokenRefresh
,
cfg
:
&
cfg
.
TokenRefresh
,
cacheInvalidator
:
cacheInvalidator
,
stopCh
:
make
(
chan
struct
{}),
stopCh
:
make
(
chan
struct
{}),
}
}
...
@@ -169,6 +172,14 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
...
@@ -169,6 +172,14 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
err
)
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
err
)
}
}
if
s
.
cacheInvalidator
!=
nil
&&
account
.
Type
==
AccountTypeOAuth
&&
(
account
.
Platform
==
PlatformGemini
||
account
.
Platform
==
PlatformAntigravity
)
{
if
err
:=
s
.
cacheInvalidator
.
InvalidateToken
(
ctx
,
account
);
err
!=
nil
{
log
.
Printf
(
"[TokenRefresh] Failed to invalidate token cache for account %d: %v"
,
account
.
ID
,
err
)
}
else
{
log
.
Printf
(
"[TokenRefresh] Token cache invalidated for account %d"
,
account
.
ID
)
}
}
return
nil
return
nil
}
}
...
...
backend/internal/service/token_refresh_service_test.go
0 → 100644
View file @
daf10907
//go:build unit
package
service
import
(
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type
tokenRefreshAccountRepo
struct
{
mockAccountRepoForGemini
updateCalls
int
setErrorCalls
int
lastAccount
*
Account
updateErr
error
}
func
(
r
*
tokenRefreshAccountRepo
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
r
.
updateCalls
++
r
.
lastAccount
=
account
return
r
.
updateErr
}
func
(
r
*
tokenRefreshAccountRepo
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
r
.
setErrorCalls
++
return
nil
}
type
tokenCacheInvalidatorStub
struct
{
calls
int
err
error
}
func
(
s
*
tokenCacheInvalidatorStub
)
InvalidateToken
(
ctx
context
.
Context
,
account
*
Account
)
error
{
s
.
calls
++
return
s
.
err
}
type
tokenRefresherStub
struct
{
credentials
map
[
string
]
any
err
error
}
func
(
r
*
tokenRefresherStub
)
CanRefresh
(
account
*
Account
)
bool
{
return
true
}
func
(
r
*
tokenRefresherStub
)
NeedsRefresh
(
account
*
Account
,
refreshWindowDuration
time
.
Duration
)
bool
{
return
true
}
func
(
r
*
tokenRefresherStub
)
Refresh
(
ctx
context
.
Context
,
account
*
Account
)
(
map
[
string
]
any
,
error
)
{
if
r
.
err
!=
nil
{
return
nil
,
r
.
err
}
return
r
.
credentials
,
nil
}
func
TestTokenRefreshService_RefreshWithRetry_InvalidatesCache
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
1
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
account
:=
&
Account
{
ID
:
5
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
refresher
:=
&
tokenRefresherStub
{
credentials
:
map
[
string
]
any
{
"access_token"
:
"new-token"
,
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
require
.
Equal
(
t
,
"new-token"
,
account
.
GetCredential
(
"access_token"
))
}
func
TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
invalidator
:=
&
tokenCacheInvalidatorStub
{
err
:
errors
.
New
(
"invalidate failed"
)}
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
1
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
account
:=
&
Account
{
ID
:
6
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
refresher
:=
&
tokenRefresherStub
{
credentials
:
map
[
string
]
any
{
"access_token"
:
"token"
,
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
}
func
TestTokenRefreshService_RefreshWithRetry_NilInvalidator
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
1
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
nil
,
cfg
)
account
:=
&
Account
{
ID
:
7
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
refresher
:=
&
tokenRefresherStub
{
credentials
:
map
[
string
]
any
{
"access_token"
:
"token"
,
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
}
// TestTokenRefreshService_RefreshWithRetry_Antigravity 测试 Antigravity 平台的缓存失效
func
TestTokenRefreshService_RefreshWithRetry_Antigravity
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
1
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
account
:=
&
Account
{
ID
:
8
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
}
refresher
:=
&
tokenRefresherStub
{
credentials
:
map
[
string
]
any
{
"access_token"
:
"ag-token"
,
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
// Antigravity 也应触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount 测试非 OAuth 账号不触发缓存失效
func
TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
1
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
account
:=
&
Account
{
ID
:
9
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeAPIKey
,
// 非 OAuth
}
refresher
:=
&
tokenRefresherStub
{
credentials
:
map
[
string
]
any
{
"access_token"
:
"token"
,
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
// 非 OAuth 不触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试其他平台的 OAuth 账号不触发缓存失效
func
TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
1
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
account
:=
&
Account
{
ID
:
10
,
Platform
:
PlatformOpenAI
,
// 其他平台
Type
:
AccountTypeOAuth
,
}
refresher
:=
&
tokenRefresherStub
{
credentials
:
map
[
string
]
any
{
"access_token"
:
"token"
,
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
// 其他平台不触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
func
TestTokenRefreshService_RefreshWithRetry_UpdateFailed
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{
updateErr
:
errors
.
New
(
"update failed"
)}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
1
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
account
:=
&
Account
{
ID
:
11
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
refresher
:=
&
tokenRefresherStub
{
credentials
:
map
[
string
]
any
{
"access_token"
:
"token"
,
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"failed to save credentials"
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
// 更新失败时不应触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试刷新失败的情况
func
TestTokenRefreshService_RefreshWithRetry_RefreshFailed
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
2
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
account
:=
&
Account
{
ID
:
12
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
refresher
:=
&
tokenRefresherStub
{
err
:
errors
.
New
(
"refresh failed"
),
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
// 刷新失败不应更新
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
// 刷新失败不应触发缓存失效
require
.
Equal
(
t
,
1
,
repo
.
setErrorCalls
)
// 应设置错误状态
}
// TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed 测试 Antigravity 刷新失败不设置错误状态
func
TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
1
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
account
:=
&
Account
{
ID
:
13
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
}
refresher
:=
&
tokenRefresherStub
{
err
:
errors
.
New
(
"network error"
),
// 可重试错误
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
require
.
Equal
(
t
,
0
,
repo
.
setErrorCalls
)
// Antigravity 可重试错误不设置错误状态
}
// TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError 测试 Antigravity 不可重试错误
func
TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
3
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
account
:=
&
Account
{
ID
:
14
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
}
refresher
:=
&
tokenRefresherStub
{
err
:
errors
.
New
(
"invalid_grant: token revoked"
),
// 不可重试错误
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
require
.
Equal
(
t
,
1
,
repo
.
setErrorCalls
)
// 不可重试错误应设置错误状态
}
// TestIsNonRetryableRefreshError 测试不可重试错误判断
func
TestIsNonRetryableRefreshError
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
err
error
expected
bool
}{
{
name
:
"nil_error"
,
err
:
nil
,
expected
:
false
},
{
name
:
"network_error"
,
err
:
errors
.
New
(
"network timeout"
),
expected
:
false
},
{
name
:
"invalid_grant"
,
err
:
errors
.
New
(
"invalid_grant"
),
expected
:
true
},
{
name
:
"invalid_client"
,
err
:
errors
.
New
(
"invalid_client"
),
expected
:
true
},
{
name
:
"unauthorized_client"
,
err
:
errors
.
New
(
"unauthorized_client"
),
expected
:
true
},
{
name
:
"access_denied"
,
err
:
errors
.
New
(
"access_denied"
),
expected
:
true
},
{
name
:
"invalid_grant_with_desc"
,
err
:
errors
.
New
(
"Error: invalid_grant - token revoked"
),
expected
:
true
},
{
name
:
"case_insensitive"
,
err
:
errors
.
New
(
"INVALID_GRANT"
),
expected
:
true
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
isNonRetryableRefreshError
(
tt
.
err
)
require
.
Equal
(
t
,
tt
.
expected
,
result
)
})
}
}
backend/internal/service/wire.go
View file @
daf10907
...
@@ -42,9 +42,10 @@ func ProvideTokenRefreshService(
...
@@ -42,9 +42,10 @@ func ProvideTokenRefreshService(
openaiOAuthService
*
OpenAIOAuthService
,
openaiOAuthService
*
OpenAIOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
antigravityOAuthService
*
AntigravityOAuthService
,
antigravityOAuthService
*
AntigravityOAuthService
,
cacheInvalidator
TokenCacheInvalidator
,
cfg
*
config
.
Config
,
cfg
*
config
.
Config
,
)
*
TokenRefreshService
{
)
*
TokenRefreshService
{
svc
:=
NewTokenRefreshService
(
accountRepo
,
oauthService
,
openaiOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
cfg
)
svc
:=
NewTokenRefreshService
(
accountRepo
,
oauthService
,
openaiOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
cacheInvalidator
,
cfg
)
svc
.
Start
()
svc
.
Start
()
return
svc
return
svc
}
}
...
@@ -108,10 +109,12 @@ func ProvideRateLimitService(
...
@@ -108,10 +109,12 @@ func ProvideRateLimitService(
tempUnschedCache
TempUnschedCache
,
tempUnschedCache
TempUnschedCache
,
timeoutCounterCache
TimeoutCounterCache
,
timeoutCounterCache
TimeoutCounterCache
,
settingService
*
SettingService
,
settingService
*
SettingService
,
tokenCacheInvalidator
TokenCacheInvalidator
,
)
*
RateLimitService
{
)
*
RateLimitService
{
svc
:=
NewRateLimitService
(
accountRepo
,
usageRepo
,
cfg
,
geminiQuotaService
,
tempUnschedCache
)
svc
:=
NewRateLimitService
(
accountRepo
,
usageRepo
,
cfg
,
geminiQuotaService
,
tempUnschedCache
)
svc
.
SetTimeoutCounterCache
(
timeoutCounterCache
)
svc
.
SetTimeoutCounterCache
(
timeoutCounterCache
)
svc
.
SetSettingService
(
settingService
)
svc
.
SetSettingService
(
settingService
)
svc
.
SetTokenCacheInvalidator
(
tokenCacheInvalidator
)
return
svc
return
svc
}
}
...
@@ -210,6 +213,7 @@ var ProviderSet = wire.NewSet(
...
@@ -210,6 +213,7 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthService
,
NewOpenAIOAuthService
,
NewGeminiOAuthService
,
NewGeminiOAuthService
,
NewGeminiQuotaService
,
NewGeminiQuotaService
,
NewCompositeTokenCacheInvalidator
,
NewAntigravityOAuthService
,
NewAntigravityOAuthService
,
NewGeminiTokenProvider
,
NewGeminiTokenProvider
,
NewGeminiMessagesCompatService
,
NewGeminiMessagesCompatService
,
...
...
config.yaml
View file @
daf10907
...
@@ -387,6 +387,9 @@ rate_limit:
...
@@ -387,6 +387,9 @@ rate_limit:
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
# 上游返回 529(过载)时的冷却时间(分钟)
# 上游返回 529(过载)时的冷却时间(分钟)
overload_cooldown_minutes
:
10
overload_cooldown_minutes
:
10
# Cooldown time (in minutes) for OAuth 401 temporary unschedulable
# OAuth 401 临时不可调度冷却时间(分钟)
oauth_401_cooldown_minutes
:
5
# =============================================================================
# =============================================================================
# Pricing Data Source (Optional)
# Pricing Data Source (Optional)
...
...
deploy/.env.example
View file @
daf10907
...
@@ -69,6 +69,17 @@ JWT_EXPIRE_HOUR=24
...
@@ -69,6 +69,17 @@ JWT_EXPIRE_HOUR=24
# Leave unset to use default ./config.yaml
# Leave unset to use default ./config.yaml
#CONFIG_FILE=./config.yaml
#CONFIG_FILE=./config.yaml
# -----------------------------------------------------------------------------
# Rate Limiting (Optional)
# 速率限制(可选)
# -----------------------------------------------------------------------------
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
# 上游返回 529(过载)时的冷却时间(分钟)
RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10
# Cooldown time (in minutes) for OAuth 401 temporary unschedulable
# OAuth 401 临时不可调度冷却时间(分钟)
RATE_LIMIT_OAUTH_401_COOLDOWN_MINUTES=5
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# Gateway Scheduling (Optional)
# Gateway Scheduling (Optional)
# 调度缓存与受控回源配置(缓存就绪且命中时不读 DB)
# 调度缓存与受控回源配置(缓存就绪且命中时不读 DB)
...
...
deploy/config.example.yaml
View file @
daf10907
...
@@ -429,6 +429,9 @@ rate_limit:
...
@@ -429,6 +429,9 @@ rate_limit:
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
# 上游返回 529(过载)时的冷却时间(分钟)
# 上游返回 529(过载)时的冷却时间(分钟)
overload_cooldown_minutes
:
10
overload_cooldown_minutes
:
10
# Cooldown time (in minutes) for OAuth 401 temporary unschedulable
# OAuth 401 临时不可调度冷却时间(分钟)
oauth_401_cooldown_minutes
:
5
# =============================================================================
# =============================================================================
# Pricing Data Source (Optional)
# Pricing Data Source (Optional)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment