Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
b36f3db9
Unverified
Commit
b36f3db9
authored
Jan 15, 2026
by
Wesley Liddick
Committed by
GitHub
Jan 15, 2026
Browse files
Merge pull request #300 from mt21625457/main
feat(网关): 引入 OpenAI/Claude OAuth token 缓存
parents
5f890e85
f862ddc9
Changes
18
Expand all
Hide whitespace changes
Inline
Side-by-side
.gitignore
View file @
b36f3db9
...
@@ -129,3 +129,4 @@ deploy/docker-compose.override.yml
...
@@ -129,3 +129,4 @@ deploy/docker-compose.override.yml
.gocache/
.gocache/
vite.config.js
vite.config.js
docs/*
docs/*
.serena/
\ No newline at end of file
backend/cmd/server/wire_gen.go
View file @
b36f3db9
...
@@ -100,8 +100,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -100,8 +100,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tempUnschedCache
:=
repository
.
NewTempUnschedCache
(
redisClient
)
tempUnschedCache
:=
repository
.
NewTempUnschedCache
(
redisClient
)
timeoutCounterCache
:=
repository
.
NewTimeoutCounterCache
(
redisClient
)
timeoutCounterCache
:=
repository
.
NewTimeoutCounterCache
(
redisClient
)
geminiTokenCache
:=
repository
.
NewGeminiTokenCache
(
redisClient
)
geminiTokenCache
:=
repository
.
NewGeminiTokenCache
(
redisClient
)
t
okenCacheInvalidator
:=
service
.
NewCompositeTokenCacheInvalidator
(
geminiTokenCache
)
compositeT
okenCacheInvalidator
:=
service
.
NewCompositeTokenCacheInvalidator
(
geminiTokenCache
)
rateLimitService
:=
service
.
ProvideRateLimitService
(
accountRepository
,
usageLogRepository
,
configConfig
,
geminiQuotaService
,
tempUnschedCache
,
timeoutCounterCache
,
settingService
,
t
okenCacheInvalidator
)
rateLimitService
:=
service
.
ProvideRateLimitService
(
accountRepository
,
usageLogRepository
,
configConfig
,
geminiQuotaService
,
tempUnschedCache
,
timeoutCounterCache
,
settingService
,
compositeT
okenCacheInvalidator
)
claudeUsageFetcher
:=
repository
.
NewClaudeUsageFetcher
()
claudeUsageFetcher
:=
repository
.
NewClaudeUsageFetcher
()
antigravityQuotaFetcher
:=
service
.
NewAntigravityQuotaFetcher
(
proxyRepository
)
antigravityQuotaFetcher
:=
service
.
NewAntigravityQuotaFetcher
(
proxyRepository
)
usageCache
:=
service
.
NewUsageCache
()
usageCache
:=
service
.
NewUsageCache
()
...
@@ -136,8 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -136,8 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityCache
:=
repository
.
NewIdentityCache
(
redisClient
)
identityCache
:=
repository
.
NewIdentityCache
(
redisClient
)
identityService
:=
service
.
NewIdentityService
(
identityCache
)
identityService
:=
service
.
NewIdentityService
(
identityCache
)
deferredService
:=
service
.
ProvideDeferredService
(
accountRepository
,
timingWheelService
)
deferredService
:=
service
.
ProvideDeferredService
(
accountRepository
,
timingWheelService
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
)
claudeTokenProvider
:=
service
.
NewClaudeTokenProvider
(
accountRepository
,
geminiTokenCache
,
oAuthService
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
,
claudeTokenProvider
)
openAITokenProvider
:=
service
.
NewOpenAITokenProvider
(
accountRepository
,
geminiTokenCache
,
openAIOAuthService
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
,
openAITokenProvider
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
groupRepository
,
gatewayCache
,
schedulerSnapshotService
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
,
configConfig
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
groupRepository
,
gatewayCache
,
schedulerSnapshotService
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
,
configConfig
)
opsService
:=
service
.
NewOpsService
(
opsRepository
,
settingRepository
,
configConfig
,
accountRepository
,
concurrencyService
,
gatewayService
,
openAIGatewayService
,
geminiMessagesCompatService
,
antigravityGatewayService
)
opsService
:=
service
.
NewOpsService
(
opsRepository
,
settingRepository
,
configConfig
,
accountRepository
,
concurrencyService
,
gatewayService
,
openAIGatewayService
,
geminiMessagesCompatService
,
antigravityGatewayService
)
settingHandler
:=
admin
.
NewSettingHandler
(
settingService
,
emailService
,
turnstileService
,
opsService
)
settingHandler
:=
admin
.
NewSettingHandler
(
settingService
,
emailService
,
turnstileService
,
opsService
)
...
@@ -168,7 +170,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -168,7 +170,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService
:=
service
.
ProvideOpsAlertEvaluatorService
(
opsService
,
opsRepository
,
emailService
,
redisClient
,
configConfig
)
opsAlertEvaluatorService
:=
service
.
ProvideOpsAlertEvaluatorService
(
opsService
,
opsRepository
,
emailService
,
redisClient
,
configConfig
)
opsCleanupService
:=
service
.
ProvideOpsCleanupService
(
opsRepository
,
db
,
redisClient
,
configConfig
)
opsCleanupService
:=
service
.
ProvideOpsCleanupService
(
opsRepository
,
db
,
redisClient
,
configConfig
)
opsScheduledReportService
:=
service
.
ProvideOpsScheduledReportService
(
opsService
,
userService
,
emailService
,
redisClient
,
configConfig
)
opsScheduledReportService
:=
service
.
ProvideOpsScheduledReportService
(
opsService
,
userService
,
emailService
,
redisClient
,
configConfig
)
tokenRefreshService
:=
service
.
ProvideTokenRefreshService
(
accountRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
t
okenCacheInvalidator
,
configConfig
)
tokenRefreshService
:=
service
.
ProvideTokenRefreshService
(
accountRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
compositeT
okenCacheInvalidator
,
configConfig
)
accountExpiryService
:=
service
.
ProvideAccountExpiryService
(
accountRepository
)
accountExpiryService
:=
service
.
ProvideAccountExpiryService
(
accountRepository
)
v
:=
provideCleanup
(
client
,
redisClient
,
opsMetricsCollector
,
opsAggregationService
,
opsAlertEvaluatorService
,
opsCleanupService
,
opsScheduledReportService
,
schedulerSnapshotService
,
tokenRefreshService
,
accountExpiryService
,
pricingService
,
emailQueueService
,
billingCacheService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
)
v
:=
provideCleanup
(
client
,
redisClient
,
opsMetricsCollector
,
opsAggregationService
,
opsAlertEvaluatorService
,
opsCleanupService
,
opsScheduledReportService
,
schedulerSnapshotService
,
tokenRefreshService
,
accountExpiryService
,
pricingService
,
emailQueueService
,
billingCacheService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
)
application
:=
&
Application
{
application
:=
&
Application
{
...
...
backend/internal/repository/gemini_token_cache.go
View file @
b36f3db9
...
@@ -11,8 +11,8 @@ import (
...
@@ -11,8 +11,8 @@ import (
)
)
const
(
const
(
gemini
TokenKeyPrefix
=
"
gemini
:token:"
oauth
TokenKeyPrefix
=
"
oauth
:token:"
gemini
RefreshLockKeyPrefix
=
"
gemini
:refresh_lock:"
oauth
RefreshLockKeyPrefix
=
"
oauth
:refresh_lock:"
)
)
type
geminiTokenCache
struct
{
type
geminiTokenCache
struct
{
...
@@ -24,26 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
...
@@ -24,26 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
}
}
func
(
c
*
geminiTokenCache
)
GetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
(
string
,
error
)
{
func
(
c
*
geminiTokenCache
)
GetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
(
string
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
gemini
TokenKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauth
TokenKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
Get
(
ctx
,
key
)
.
Result
()
return
c
.
rdb
.
Get
(
ctx
,
key
)
.
Result
()
}
}
func
(
c
*
geminiTokenCache
)
SetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
,
token
string
,
ttl
time
.
Duration
)
error
{
func
(
c
*
geminiTokenCache
)
SetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
,
token
string
,
ttl
time
.
Duration
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
gemini
TokenKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauth
TokenKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
Set
(
ctx
,
key
,
token
,
ttl
)
.
Err
()
return
c
.
rdb
.
Set
(
ctx
,
key
,
token
,
ttl
)
.
Err
()
}
}
func
(
c
*
geminiTokenCache
)
DeleteAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
func
(
c
*
geminiTokenCache
)
DeleteAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
gemini
TokenKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauth
TokenKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
}
func
(
c
*
geminiTokenCache
)
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
{
func
(
c
*
geminiTokenCache
)
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
gemini
RefreshLockKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauth
RefreshLockKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
SetNX
(
ctx
,
key
,
1
,
ttl
)
.
Result
()
return
c
.
rdb
.
SetNX
(
ctx
,
key
,
1
,
ttl
)
.
Result
()
}
}
func
(
c
*
geminiTokenCache
)
ReleaseRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
func
(
c
*
geminiTokenCache
)
ReleaseRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
gemini
RefreshLockKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauth
RefreshLockKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
}
backend/internal/service/claude_token_provider.go
0 → 100644
View file @
b36f3db9
package
service
import
(
"context"
"errors"
"log/slog"
"strconv"
"strings"
"time"
)
const
(
claudeTokenRefreshSkew
=
3
*
time
.
Minute
claudeTokenCacheSkew
=
5
*
time
.
Minute
claudeLockWaitTime
=
200
*
time
.
Millisecond
)
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type
ClaudeTokenCache
=
GeminiTokenCache
// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
type
ClaudeTokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
ClaudeTokenCache
oauthService
*
OAuthService
}
func
NewClaudeTokenProvider
(
accountRepo
AccountRepository
,
tokenCache
ClaudeTokenCache
,
oauthService
*
OAuthService
,
)
*
ClaudeTokenProvider
{
return
&
ClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
tokenCache
,
oauthService
:
oauthService
,
}
}
// GetAccessToken 获取有效的 access_token
func
(
p
*
ClaudeTokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
account
.
Platform
!=
PlatformAnthropic
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an anthropic oauth account"
)
}
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
// 1. 先尝试缓存
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"claude_token_cache_hit"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
else
if
err
!=
nil
{
slog
.
Warn
(
"claude_token_cache_get_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
slog
.
Debug
(
"claude_token_cache_miss"
,
"account_id"
,
account
.
ID
)
// 2. 如果即将过期则刷新
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
refreshFailed
:=
false
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
lockErr
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
lockErr
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
// 从数据库获取最新账户信息
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
{
if
p
.
oauthService
==
nil
{
slog
.
Warn
(
"claude_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
// 无法刷新,标记失败
}
else
{
tokenInfo
,
err
:=
p
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog
.
Warn
(
"claude_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
// 刷新失败,标记以使用短 TTL
}
else
{
// 构建新 credentials,保留原有字段
newCredentials
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
account
.
Credentials
{
newCredentials
[
k
]
=
v
}
newCredentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
newCredentials
[
"token_type"
]
=
tokenInfo
.
TokenType
newCredentials
[
"expires_in"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresIn
,
10
)
newCredentials
[
"expires_at"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
)
if
tokenInfo
.
RefreshToken
!=
""
{
newCredentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
tokenInfo
.
Scope
!=
""
{
newCredentials
[
"scope"
]
=
tokenInfo
.
Scope
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"claude_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
if
lockErr
!=
nil
{
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
slog
.
Warn
(
"claude_token_lock_failed_degraded_refresh"
,
"account_id"
,
account
.
ID
,
"error"
,
lockErr
)
// 检查 ctx 是否已取消
if
ctx
.
Err
()
!=
nil
{
return
""
,
ctx
.
Err
()
}
// 从数据库获取最新账户信息
if
p
.
accountRepo
!=
nil
{
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
{
if
p
.
oauthService
==
nil
{
slog
.
Warn
(
"claude_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
}
else
{
tokenInfo
,
err
:=
p
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
slog
.
Warn
(
"claude_token_refresh_failed_degraded"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
}
else
{
// 构建新 credentials,保留原有字段
newCredentials
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
account
.
Credentials
{
newCredentials
[
k
]
=
v
}
newCredentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
newCredentials
[
"token_type"
]
=
tokenInfo
.
TokenType
newCredentials
[
"expires_in"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresIn
,
10
)
newCredentials
[
"expires_at"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
)
if
tokenInfo
.
RefreshToken
!=
""
{
newCredentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
tokenInfo
.
Scope
!=
""
{
newCredentials
[
"scope"
]
=
tokenInfo
.
Scope
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"claude_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
{
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
time
.
Sleep
(
claudeLockWaitTime
)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"claude_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
}
}
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. 存入缓存
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
if
refreshFailed
{
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl
=
time
.
Minute
slog
.
Debug
(
"claude_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed"
)
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
claudeTokenCacheSkew
:
ttl
=
until
-
claudeTokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
if
err
:=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
);
err
!=
nil
{
slog
.
Warn
(
"claude_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
return
accessToken
,
nil
}
backend/internal/service/claude_token_provider_test.go
0 → 100644
View file @
b36f3db9
This diff is collapsed.
Click to expand it.
backend/internal/service/gateway_service.go
View file @
b36f3db9
...
@@ -159,6 +159,7 @@ type GatewayService struct {
...
@@ -159,6 +159,7 @@ type GatewayService struct {
httpUpstream
HTTPUpstream
httpUpstream
HTTPUpstream
deferredService
*
DeferredService
deferredService
*
DeferredService
concurrencyService
*
ConcurrencyService
concurrencyService
*
ConcurrencyService
claudeTokenProvider
*
ClaudeTokenProvider
}
}
// NewGatewayService creates a new GatewayService
// NewGatewayService creates a new GatewayService
...
@@ -178,6 +179,7 @@ func NewGatewayService(
...
@@ -178,6 +179,7 @@ func NewGatewayService(
identityService
*
IdentityService
,
identityService
*
IdentityService
,
httpUpstream
HTTPUpstream
,
httpUpstream
HTTPUpstream
,
deferredService
*
DeferredService
,
deferredService
*
DeferredService
,
claudeTokenProvider
*
ClaudeTokenProvider
,
)
*
GatewayService
{
)
*
GatewayService
{
return
&
GatewayService
{
return
&
GatewayService
{
accountRepo
:
accountRepo
,
accountRepo
:
accountRepo
,
...
@@ -195,6 +197,7 @@ func NewGatewayService(
...
@@ -195,6 +197,7 @@ func NewGatewayService(
identityService
:
identityService
,
identityService
:
identityService
,
httpUpstream
:
httpUpstream
,
httpUpstream
:
httpUpstream
,
deferredService
:
deferredService
,
deferredService
:
deferredService
,
claudeTokenProvider
:
claudeTokenProvider
,
}
}
}
}
...
@@ -1079,6 +1082,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
...
@@ -1079,6 +1082,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
}
}
func
(
s
*
GatewayService
)
getOAuthToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
string
,
error
)
{
func
(
s
*
GatewayService
)
getOAuthToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
string
,
error
)
{
// 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token
if
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
==
AccountTypeOAuth
&&
s
.
claudeTokenProvider
!=
nil
{
accessToken
,
err
:=
s
.
claudeTokenProvider
.
GetAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
""
,
err
}
return
accessToken
,
"oauth"
,
nil
}
// 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取
accessToken
:=
account
.
GetCredential
(
"access_token"
)
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
accessToken
==
""
{
if
accessToken
==
""
{
return
""
,
""
,
errors
.
New
(
"access_token not found in credentials"
)
return
""
,
""
,
errors
.
New
(
"access_token not found in credentials"
)
...
...
backend/internal/service/gemini_token_provider.go
View file @
b36f3db9
...
@@ -154,7 +154,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
...
@@ -154,7 +154,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
func
GeminiTokenCacheKey
(
account
*
Account
)
string
{
func
GeminiTokenCacheKey
(
account
*
Account
)
string
{
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
if
projectID
!=
""
{
if
projectID
!=
""
{
return
projectID
return
"gemini:"
+
projectID
}
}
return
"account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
return
"
gemini:
account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
}
}
backend/internal/service/openai_gateway_service.go
View file @
b36f3db9
...
@@ -93,6 +93,7 @@ type OpenAIGatewayService struct {
...
@@ -93,6 +93,7 @@ type OpenAIGatewayService struct {
billingCacheService
*
BillingCacheService
billingCacheService
*
BillingCacheService
httpUpstream
HTTPUpstream
httpUpstream
HTTPUpstream
deferredService
*
DeferredService
deferredService
*
DeferredService
openAITokenProvider
*
OpenAITokenProvider
}
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
// NewOpenAIGatewayService creates a new OpenAIGatewayService
...
@@ -110,6 +111,7 @@ func NewOpenAIGatewayService(
...
@@ -110,6 +111,7 @@ func NewOpenAIGatewayService(
billingCacheService
*
BillingCacheService
,
billingCacheService
*
BillingCacheService
,
httpUpstream
HTTPUpstream
,
httpUpstream
HTTPUpstream
,
deferredService
*
DeferredService
,
deferredService
*
DeferredService
,
openAITokenProvider
*
OpenAITokenProvider
,
)
*
OpenAIGatewayService
{
)
*
OpenAIGatewayService
{
return
&
OpenAIGatewayService
{
return
&
OpenAIGatewayService
{
accountRepo
:
accountRepo
,
accountRepo
:
accountRepo
,
...
@@ -125,6 +127,7 @@ func NewOpenAIGatewayService(
...
@@ -125,6 +127,7 @@ func NewOpenAIGatewayService(
billingCacheService
:
billingCacheService
,
billingCacheService
:
billingCacheService
,
httpUpstream
:
httpUpstream
,
httpUpstream
:
httpUpstream
,
deferredService
:
deferredService
,
deferredService
:
deferredService
,
openAITokenProvider
:
openAITokenProvider
,
}
}
}
}
...
@@ -503,6 +506,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
...
@@ -503,6 +506,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
func
(
s
*
OpenAIGatewayService
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
string
,
error
)
{
func
(
s
*
OpenAIGatewayService
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
string
,
error
)
{
switch
account
.
Type
{
switch
account
.
Type
{
case
AccountTypeOAuth
:
case
AccountTypeOAuth
:
// 使用 TokenProvider 获取缓存的 token
if
s
.
openAITokenProvider
!=
nil
{
accessToken
,
err
:=
s
.
openAITokenProvider
.
GetAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
""
,
err
}
return
accessToken
,
"oauth"
,
nil
}
// 降级:TokenProvider 未配置时直接从账号读取
accessToken
:=
account
.
GetOpenAIAccessToken
()
accessToken
:=
account
.
GetOpenAIAccessToken
()
if
accessToken
==
""
{
if
accessToken
==
""
{
return
""
,
""
,
errors
.
New
(
"access_token not found in credentials"
)
return
""
,
""
,
errors
.
New
(
"access_token not found in credentials"
)
...
...
backend/internal/service/openai_token_provider.go
0 → 100644
View file @
b36f3db9
package
service
import
(
"context"
"errors"
"log/slog"
"strings"
"time"
)
const
(
openAITokenRefreshSkew
=
3
*
time
.
Minute
openAITokenCacheSkew
=
5
*
time
.
Minute
openAILockWaitTime
=
200
*
time
.
Millisecond
)
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type
OpenAITokenCache
=
GeminiTokenCache
// OpenAITokenProvider 管理 OpenAI OAuth 账户的 access_token
type
OpenAITokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
OpenAITokenCache
openAIOAuthService
*
OpenAIOAuthService
}
func
NewOpenAITokenProvider
(
accountRepo
AccountRepository
,
tokenCache
OpenAITokenCache
,
openAIOAuthService
*
OpenAIOAuthService
,
)
*
OpenAITokenProvider
{
return
&
OpenAITokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
tokenCache
,
openAIOAuthService
:
openAIOAuthService
,
}
}
// GetAccessToken 获取有效的 access_token
func
(
p
*
OpenAITokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
account
.
Platform
!=
PlatformOpenAI
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an openai oauth account"
)
}
cacheKey
:=
OpenAITokenCacheKey
(
account
)
// 1. 先尝试缓存
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"openai_token_cache_hit"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
else
if
err
!=
nil
{
slog
.
Warn
(
"openai_token_cache_get_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
slog
.
Debug
(
"openai_token_cache_miss"
,
"account_id"
,
account
.
ID
)
// 2. 如果即将过期则刷新
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
openAITokenRefreshSkew
refreshFailed
:=
false
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
lockErr
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
lockErr
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
// 从数据库获取最新账户信息
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
openAITokenRefreshSkew
{
if
p
.
openAIOAuthService
==
nil
{
slog
.
Warn
(
"openai_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
// 无法刷新,标记失败
}
else
{
tokenInfo
,
err
:=
p
.
openAIOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog
.
Warn
(
"openai_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
// 刷新失败,标记以使用短 TTL
}
else
{
newCredentials
:=
p
.
openAIOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"openai_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
if
lockErr
!=
nil
{
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
slog
.
Warn
(
"openai_token_lock_failed_degraded_refresh"
,
"account_id"
,
account
.
ID
,
"error"
,
lockErr
)
// 检查 ctx 是否已取消
if
ctx
.
Err
()
!=
nil
{
return
""
,
ctx
.
Err
()
}
// 从数据库获取最新账户信息
if
p
.
accountRepo
!=
nil
{
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
openAITokenRefreshSkew
{
if
p
.
openAIOAuthService
==
nil
{
slog
.
Warn
(
"openai_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
}
else
{
tokenInfo
,
err
:=
p
.
openAIOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
slog
.
Warn
(
"openai_token_refresh_failed_degraded"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
}
else
{
newCredentials
:=
p
.
openAIOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"openai_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
{
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
time
.
Sleep
(
openAILockWaitTime
)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"openai_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
}
}
accessToken
:=
account
.
GetOpenAIAccessToken
()
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. 存入缓存
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
if
refreshFailed
{
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl
=
time
.
Minute
slog
.
Debug
(
"openai_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed"
)
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
openAITokenCacheSkew
:
ttl
=
until
-
openAITokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
if
err
:=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
);
err
!=
nil
{
slog
.
Warn
(
"openai_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
return
accessToken
,
nil
}
backend/internal/service/openai_token_provider_test.go
0 → 100644
View file @
b36f3db9
This diff is collapsed.
Click to expand it.
backend/internal/service/ratelimit_service.go
View file @
b36f3db9
...
@@ -85,13 +85,24 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
...
@@ -85,13 +85,24 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
switch
statusCode
{
switch
statusCode
{
case
401
:
case
401
:
if
account
.
Type
==
AccountTypeOAuth
&&
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
(
account
.
Platform
==
PlatformAntigravity
||
account
.
Platform
==
PlatformGemini
)
{
if
account
.
Type
==
AccountTypeOAuth
{
// 1. 失效缓存
if
s
.
tokenCacheInvalidator
!=
nil
{
if
s
.
tokenCacheInvalidator
!=
nil
{
if
err
:=
s
.
tokenCacheInvalidator
.
InvalidateToken
(
ctx
,
account
);
err
!=
nil
{
if
err
:=
s
.
tokenCacheInvalidator
.
InvalidateToken
(
ctx
,
account
);
err
!=
nil
{
slog
.
Warn
(
"oauth_401_invalidate_cache_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
slog
.
Warn
(
"oauth_401_invalidate_cache_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
}
}
// 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
if
account
.
Credentials
==
nil
{
account
.
Credentials
=
make
(
map
[
string
]
any
)
}
account
.
Credentials
[
"expires_at"
]
=
time
.
Now
()
.
Format
(
time
.
RFC3339
)
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
slog
.
Warn
(
"oauth_401_force_refresh_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
else
{
slog
.
Info
(
"oauth_401_force_refresh_set"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
)
}
}
}
msg
:=
"Authentication failed (401): invalid or expired credentials"
msg
:=
"Authentication failed (401): invalid or expired credentials"
if
upstreamMsg
!=
""
{
if
upstreamMsg
!=
""
{
...
...
backend/internal/service/token_cache_invalidator.go
View file @
b36f3db9
...
@@ -7,29 +7,35 @@ type TokenCacheInvalidator interface {
...
@@ -7,29 +7,35 @@ type TokenCacheInvalidator interface {
}
}
type
CompositeTokenCacheInvalidator
struct
{
type
CompositeTokenCacheInvalidator
struct
{
geminiC
ache
GeminiTokenCache
c
ache
GeminiTokenCache
// 统一使用一个缓存接口,通过缓存键前缀区分平台
}
}
func
NewCompositeTokenCacheInvalidator
(
geminiC
ache
GeminiTokenCache
)
*
CompositeTokenCacheInvalidator
{
func
NewCompositeTokenCacheInvalidator
(
c
ache
GeminiTokenCache
)
*
CompositeTokenCacheInvalidator
{
return
&
CompositeTokenCacheInvalidator
{
return
&
CompositeTokenCacheInvalidator
{
geminiCache
:
geminiC
ache
,
cache
:
c
ache
,
}
}
}
}
func
(
c
*
CompositeTokenCacheInvalidator
)
InvalidateToken
(
ctx
context
.
Context
,
account
*
Account
)
error
{
func
(
c
*
CompositeTokenCacheInvalidator
)
InvalidateToken
(
ctx
context
.
Context
,
account
*
Account
)
error
{
if
c
==
nil
||
c
.
geminiC
ache
==
nil
||
account
==
nil
{
if
c
==
nil
||
c
.
c
ache
==
nil
||
account
==
nil
{
return
nil
return
nil
}
}
if
account
.
Type
!=
AccountTypeOAuth
{
if
account
.
Type
!=
AccountTypeOAuth
{
return
nil
return
nil
}
}
var
cacheKey
string
switch
account
.
Platform
{
switch
account
.
Platform
{
case
PlatformGemini
:
case
PlatformGemini
:
return
c
.
geminiCache
.
DeleteAccessToken
(
ctx
,
GeminiTokenCacheKey
(
account
)
)
cacheKey
=
GeminiTokenCacheKey
(
account
)
case
PlatformAntigravity
:
case
PlatformAntigravity
:
return
c
.
geminiCache
.
DeleteAccessToken
(
ctx
,
AntigravityTokenCacheKey
(
account
))
cacheKey
=
AntigravityTokenCacheKey
(
account
)
case
PlatformOpenAI
:
cacheKey
=
OpenAITokenCacheKey
(
account
)
case
PlatformAnthropic
:
cacheKey
=
ClaudeTokenCacheKey
(
account
)
default
:
default
:
return
nil
return
nil
}
}
return
c
.
cache
.
DeleteAccessToken
(
ctx
,
cacheKey
)
}
}
backend/internal/service/token_cache_invalidator_test.go
View file @
b36f3db9
...
@@ -4,6 +4,7 @@ package service
...
@@ -4,6 +4,7 @@ package service
import
(
import
(
"context"
"context"
"errors"
"testing"
"testing"
"time"
"time"
...
@@ -50,7 +51,7 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
...
@@ -50,7 +51,7 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"project-x"
},
cache
.
deletedKeys
)
require
.
Equal
(
t
,
[]
string
{
"
gemini:
project-x"
},
cache
.
deletedKeys
)
}
}
func
TestCompositeTokenCacheInvalidator_Antigravity
(
t
*
testing
.
T
)
{
func
TestCompositeTokenCacheInvalidator_Antigravity
(
t
*
testing
.
T
)
{
...
@@ -70,13 +71,99 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
...
@@ -70,13 +71,99 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
require
.
Equal
(
t
,
[]
string
{
"ag:ag-project"
},
cache
.
deletedKeys
)
require
.
Equal
(
t
,
[]
string
{
"ag:ag-project"
},
cache
.
deletedKeys
)
}
}
func
TestCompositeTokenCacheInvalidator_OpenAI
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
ID
:
500
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"openai-token"
,
},
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"openai:account:500"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_Claude
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
ID
:
600
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"claude-token"
,
},
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"claude:account:600"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_SkipNonOAuth
(
t
*
testing
.
T
)
{
func
TestCompositeTokenCacheInvalidator_SkipNonOAuth
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
tests
:=
[]
struct
{
name
string
account
*
Account
}{
{
name
:
"gemini_api_key"
,
account
:
&
Account
{
ID
:
1
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeAPIKey
,
},
},
{
name
:
"openai_api_key"
,
account
:
&
Account
{
ID
:
2
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
},
},
{
name
:
"claude_api_key"
,
account
:
&
Account
{
ID
:
3
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
},
},
{
name
:
"claude_setup_token"
,
account
:
&
Account
{
ID
:
4
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeSetupToken
,
},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
cache
.
deletedKeys
=
nil
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
tt
.
account
)
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
cache
.
deletedKeys
)
})
}
}
func
TestCompositeTokenCacheInvalidator_SkipUnsupportedPlatform
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
account
:=
&
Account
{
ID
:
1
,
ID
:
1
00
,
Platform
:
P
latform
Gemini
,
Platform
:
"unknown-p
latform
"
,
Type
:
AccountType
APIKey
,
Type
:
AccountType
OAuth
,
}
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
...
@@ -95,3 +182,87 @@ func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) {
...
@@ -95,3 +182,87 @@ func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) {
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
}
}
func
TestCompositeTokenCacheInvalidator_NilAccount
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
nil
)
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_NilInvalidator
(
t
*
testing
.
T
)
{
var
invalidator
*
CompositeTokenCacheInvalidator
account
:=
&
Account
{
ID
:
5
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
}
func
TestCompositeTokenCacheInvalidator_DeleteError
(
t
*
testing
.
T
)
{
expectedErr
:=
errors
.
New
(
"redis connection failed"
)
cache
:=
&
geminiTokenCacheStub
{
deleteErr
:
expectedErr
}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
tests
:=
[]
struct
{
name
string
account
*
Account
}{
{
name
:
"openai_delete_error"
,
account
:
&
Account
{
ID
:
700
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
},
},
{
name
:
"claude_delete_error"
,
account
:
&
Account
{
ID
:
800
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
tt
.
account
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
expectedErr
,
err
)
})
}
}
func
TestCompositeTokenCacheInvalidator_AllPlatformsIntegration
(
t
*
testing
.
T
)
{
// 测试所有平台的缓存键生成和删除
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
accounts
:=
[]
*
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
"gemini-proj"
}},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
"ag-proj"
}},
{
ID
:
3
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
},
{
ID
:
4
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
},
}
expectedKeys
:=
[]
string
{
"gemini:gemini-proj"
,
"ag:ag-proj"
,
"openai:account:3"
,
"claude:account:4"
,
}
for
_
,
acc
:=
range
accounts
{
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
acc
)
require
.
NoError
(
t
,
err
)
}
require
.
Equal
(
t
,
expectedKeys
,
cache
.
deletedKeys
)
}
backend/internal/service/token_cache_key.go
0 → 100644
View file @
b36f3db9
package
service
import
"strconv"
// OpenAITokenCacheKey 生成 OpenAI OAuth 账号的缓存键
// 格式: "openai:account:{account_id}"
func
OpenAITokenCacheKey
(
account
*
Account
)
string
{
return
"openai:account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
}
// ClaudeTokenCacheKey 生成 Claude (Anthropic) OAuth 账号的缓存键
// 格式: "claude:account:{account_id}"
func
ClaudeTokenCacheKey
(
account
*
Account
)
string
{
return
"claude:account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
}
backend/internal/service/token_cache_key_test.go
View file @
b36f3db9
...
@@ -22,7 +22,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
...
@@ -22,7 +22,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
"project_id"
:
"my-project-123"
,
"project_id"
:
"my-project-123"
,
},
},
},
},
expected
:
"my-project-123"
,
expected
:
"
gemini:
my-project-123"
,
},
},
{
{
name
:
"project_id_with_whitespace"
,
name
:
"project_id_with_whitespace"
,
...
@@ -32,7 +32,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
...
@@ -32,7 +32,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
"project_id"
:
" project-with-spaces "
,
"project_id"
:
" project-with-spaces "
,
},
},
},
},
expected
:
"project-with-spaces"
,
expected
:
"
gemini:
project-with-spaces"
,
},
},
{
{
name
:
"empty_project_id_fallback_to_account_id"
,
name
:
"empty_project_id_fallback_to_account_id"
,
...
@@ -42,7 +42,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
...
@@ -42,7 +42,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
"project_id"
:
""
,
"project_id"
:
""
,
},
},
},
},
expected
:
"account:102"
,
expected
:
"
gemini:
account:102"
,
},
},
{
{
name
:
"whitespace_only_project_id_fallback_to_account_id"
,
name
:
"whitespace_only_project_id_fallback_to_account_id"
,
...
@@ -52,7 +52,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
...
@@ -52,7 +52,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
"project_id"
:
" "
,
"project_id"
:
" "
,
},
},
},
},
expected
:
"account:103"
,
expected
:
"
gemini:
account:103"
,
},
},
{
{
name
:
"no_project_id_key_fallback_to_account_id"
,
name
:
"no_project_id_key_fallback_to_account_id"
,
...
@@ -60,7 +60,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
...
@@ -60,7 +60,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
ID
:
104
,
ID
:
104
,
Credentials
:
map
[
string
]
any
{},
Credentials
:
map
[
string
]
any
{},
},
},
expected
:
"account:104"
,
expected
:
"
gemini:
account:104"
,
},
},
{
{
name
:
"nil_credentials_fallback_to_account_id"
,
name
:
"nil_credentials_fallback_to_account_id"
,
...
@@ -68,7 +68,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
...
@@ -68,7 +68,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
ID
:
105
,
ID
:
105
,
Credentials
:
nil
,
Credentials
:
nil
,
},
},
expected
:
"account:105"
,
expected
:
"
gemini:
account:105"
,
},
},
}
}
...
@@ -151,3 +151,109 @@ func TestAntigravityTokenCacheKey(t *testing.T) {
...
@@ -151,3 +151,109 @@ func TestAntigravityTokenCacheKey(t *testing.T) {
})
})
}
}
}
}
func
TestOpenAITokenCacheKey
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
expected
string
}{
{
name
:
"basic_account"
,
account
:
&
Account
{
ID
:
300
,
},
expected
:
"openai:account:300"
,
},
{
name
:
"account_with_credentials"
,
account
:
&
Account
{
ID
:
301
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"test-token"
,
},
},
expected
:
"openai:account:301"
,
},
{
name
:
"account_id_zero"
,
account
:
&
Account
{
ID
:
0
,
},
expected
:
"openai:account:0"
,
},
{
name
:
"large_account_id"
,
account
:
&
Account
{
ID
:
9999999999
,
},
expected
:
"openai:account:9999999999"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
OpenAITokenCacheKey
(
tt
.
account
)
require
.
Equal
(
t
,
tt
.
expected
,
result
)
})
}
}
func
TestClaudeTokenCacheKey
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
expected
string
}{
{
name
:
"basic_account"
,
account
:
&
Account
{
ID
:
400
,
},
expected
:
"claude:account:400"
,
},
{
name
:
"account_with_credentials"
,
account
:
&
Account
{
ID
:
401
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"claude-token"
,
},
},
expected
:
"claude:account:401"
,
},
{
name
:
"account_id_zero"
,
account
:
&
Account
{
ID
:
0
,
},
expected
:
"claude:account:0"
,
},
{
name
:
"large_account_id"
,
account
:
&
Account
{
ID
:
9999999999
,
},
expected
:
"claude:account:9999999999"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
ClaudeTokenCacheKey
(
tt
.
account
)
require
.
Equal
(
t
,
tt
.
expected
,
result
)
})
}
}
func
TestCacheKeyUniqueness
(
t
*
testing
.
T
)
{
// 确保不同平台的缓存键不会冲突
account
:=
&
Account
{
ID
:
123
}
openaiKey
:=
OpenAITokenCacheKey
(
account
)
claudeKey
:=
ClaudeTokenCacheKey
(
account
)
require
.
NotEqual
(
t
,
openaiKey
,
claudeKey
,
"OpenAI and Claude cache keys should be different"
)
require
.
Contains
(
t
,
openaiKey
,
"openai:"
)
require
.
Contains
(
t
,
claudeKey
,
"claude:"
)
}
backend/internal/service/token_refresh_service.go
View file @
b36f3db9
...
@@ -172,8 +172,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
...
@@ -172,8 +172,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
err
)
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
err
)
}
}
if
s
.
cacheInvalidator
!=
nil
&&
account
.
Type
==
AccountTypeOAuth
&&
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
(
account
.
Pl
at
f
or
m
==
PlatformGemini
||
account
.
Platform
==
PlatformAntigravity
)
{
if
s
.
cacheInvalid
ator
!=
nil
&&
account
.
Type
==
AccountTypeOAuth
{
if
err
:=
s
.
cacheInvalidator
.
InvalidateToken
(
ctx
,
account
);
err
!=
nil
{
if
err
:=
s
.
cacheInvalidator
.
InvalidateToken
(
ctx
,
account
);
err
!=
nil
{
log
.
Printf
(
"[TokenRefresh] Failed to invalidate token cache for account %d: %v"
,
account
.
ID
,
err
)
log
.
Printf
(
"[TokenRefresh] Failed to invalidate token cache for account %d: %v"
,
account
.
ID
,
err
)
}
else
{
}
else
{
...
...
backend/internal/service/token_refresh_service_test.go
View file @
b36f3db9
...
@@ -197,7 +197,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
...
@@ -197,7 +197,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
// 非 OAuth 不触发缓存失效
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
// 非 OAuth 不触发缓存失效
}
}
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试
其他平台的
OAuth
账号不
触发缓存失效
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试
所有
OAuth
平台都
触发缓存失效
func
TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth
(
t
*
testing
.
T
)
{
func
TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
repo
:=
&
tokenRefreshAccountRepo
{}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
...
@@ -210,7 +210,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
...
@@ -210,7 +210,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
account
:=
&
Account
{
account
:=
&
Account
{
ID
:
10
,
ID
:
10
,
Platform
:
PlatformOpenAI
,
//
其他平台
Platform
:
PlatformOpenAI
,
//
OpenAI OAuth 账户
Type
:
AccountTypeOAuth
,
Type
:
AccountTypeOAuth
,
}
}
refresher
:=
&
tokenRefresherStub
{
refresher
:=
&
tokenRefresherStub
{
...
@@ -222,7 +222,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
...
@@ -222,7 +222,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
//
其他平台不
触发缓存失效
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
//
所有 OAuth 账户刷新后
触发缓存失效
}
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
...
...
backend/internal/service/wire.go
View file @
b36f3db9
...
@@ -214,10 +214,13 @@ var ProviderSet = wire.NewSet(
...
@@ -214,10 +214,13 @@ var ProviderSet = wire.NewSet(
NewGeminiOAuthService
,
NewGeminiOAuthService
,
NewGeminiQuotaService
,
NewGeminiQuotaService
,
NewCompositeTokenCacheInvalidator
,
NewCompositeTokenCacheInvalidator
,
wire
.
Bind
(
new
(
TokenCacheInvalidator
),
new
(
*
CompositeTokenCacheInvalidator
)),
NewAntigravityOAuthService
,
NewAntigravityOAuthService
,
NewGeminiTokenProvider
,
NewGeminiTokenProvider
,
NewGeminiMessagesCompatService
,
NewGeminiMessagesCompatService
,
NewAntigravityTokenProvider
,
NewAntigravityTokenProvider
,
NewOpenAITokenProvider
,
NewClaudeTokenProvider
,
NewAntigravityGatewayService
,
NewAntigravityGatewayService
,
ProvideRateLimitService
,
ProvideRateLimitService
,
NewAccountUsageService
,
NewAccountUsageService
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment