Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
ec82c37d
Unverified
Commit
ec82c37d
authored
Mar 16, 2026
by
Wesley Liddick
Committed by
GitHub
Mar 16, 2026
Browse files
Merge pull request #1042 from touwaeriol/feat/unified-oauth-refresh-api
feat: unified OAuth token refresh API with distributed locking
parents
d3a9f5bb
044d3a01
Changes
14
Show whitespace changes
Inline
Side-by-side
backend/cmd/server/wire_gen.go
View file @
ec82c37d
...
...
@@ -124,6 +124,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tempUnschedCache
:=
repository
.
NewTempUnschedCache
(
redisClient
)
timeoutCounterCache
:=
repository
.
NewTimeoutCounterCache
(
redisClient
)
geminiTokenCache
:=
repository
.
NewGeminiTokenCache
(
redisClient
)
oauthRefreshAPI
:=
service
.
NewOAuthRefreshAPI
(
accountRepository
,
geminiTokenCache
)
compositeTokenCacheInvalidator
:=
service
.
NewCompositeTokenCacheInvalidator
(
geminiTokenCache
)
rateLimitService
:=
service
.
ProvideRateLimitService
(
accountRepository
,
usageLogRepository
,
configConfig
,
geminiQuotaService
,
tempUnschedCache
,
timeoutCounterCache
,
settingService
,
compositeTokenCacheInvalidator
)
httpUpstream
:=
repository
.
NewHTTPUpstream
(
configConfig
)
...
...
@@ -132,11 +133,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
usageCache
:=
service
.
NewUsageCache
()
identityCache
:=
repository
.
NewIdentityCache
(
redisClient
)
accountUsageService
:=
service
.
NewAccountUsageService
(
accountRepository
,
usageLogRepository
,
claudeUsageFetcher
,
geminiQuotaService
,
antigravityQuotaFetcher
,
usageCache
,
identityCache
)
geminiTokenProvider
:=
service
.
New
GeminiTokenProvider
(
accountRepository
,
geminiTokenCache
,
geminiOAuthService
)
geminiTokenProvider
:=
service
.
Provide
GeminiTokenProvider
(
accountRepository
,
geminiTokenCache
,
geminiOAuthService
,
oauthRefreshAPI
)
gatewayCache
:=
repository
.
NewGatewayCache
(
redisClient
)
schedulerOutboxRepository
:=
repository
.
NewSchedulerOutboxRepository
(
db
)
schedulerSnapshotService
:=
service
.
ProvideSchedulerSnapshotService
(
schedulerCache
,
schedulerOutboxRepository
,
accountRepository
,
groupRepository
,
configConfig
)
antigravityTokenProvider
:=
service
.
New
AntigravityTokenProvider
(
accountRepository
,
geminiTokenCache
,
antigravityOAuthService
)
antigravityTokenProvider
:=
service
.
Provide
AntigravityTokenProvider
(
accountRepository
,
geminiTokenCache
,
antigravityOAuthService
,
oauthRefreshAPI
)
antigravityGatewayService
:=
service
.
NewAntigravityGatewayService
(
accountRepository
,
gatewayCache
,
schedulerSnapshotService
,
antigravityTokenProvider
,
rateLimitService
,
httpUpstream
,
settingService
)
accountTestService
:=
service
.
NewAccountTestService
(
accountRepository
,
geminiTokenProvider
,
antigravityGatewayService
,
httpUpstream
,
configConfig
)
crsSyncService
:=
service
.
NewCRSSyncService
(
accountRepository
,
proxyRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
configConfig
)
...
...
@@ -166,10 +167,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService
:=
service
.
NewBillingService
(
configConfig
,
pricingService
)
identityService
:=
service
.
NewIdentityService
(
identityCache
)
deferredService
:=
service
.
ProvideDeferredService
(
accountRepository
,
timingWheelService
)
claudeTokenProvider
:=
service
.
New
ClaudeTokenProvider
(
accountRepository
,
geminiTokenCache
,
oAuthService
)
claudeTokenProvider
:=
service
.
Provide
ClaudeTokenProvider
(
accountRepository
,
geminiTokenCache
,
oAuthService
,
oauthRefreshAPI
)
digestSessionStore
:=
service
.
NewDigestSessionStore
()
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
usageBillingRepository
,
userRepository
,
userSubscriptionRepository
,
userGroupRateRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
,
claudeTokenProvider
,
sessionLimitCache
,
rpmCache
,
digestSessionStore
,
settingService
)
openAITokenProvider
:=
service
.
New
OpenAITokenProvider
(
accountRepository
,
geminiTokenCache
,
openAIOAuthService
)
openAITokenProvider
:=
service
.
Provide
OpenAITokenProvider
(
accountRepository
,
geminiTokenCache
,
openAIOAuthService
,
oauthRefreshAPI
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
usageBillingRepository
,
userRepository
,
userSubscriptionRepository
,
userGroupRateRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
,
openAITokenProvider
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
groupRepository
,
gatewayCache
,
schedulerSnapshotService
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
,
configConfig
)
opsSystemLogSink
:=
service
.
ProvideOpsSystemLogSink
(
opsRepository
)
...
...
@@ -232,7 +233,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsCleanupService
:=
service
.
ProvideOpsCleanupService
(
opsRepository
,
db
,
redisClient
,
configConfig
)
opsScheduledReportService
:=
service
.
ProvideOpsScheduledReportService
(
opsService
,
userService
,
emailService
,
redisClient
,
configConfig
)
soraMediaCleanupService
:=
service
.
ProvideSoraMediaCleanupService
(
soraMediaStorage
,
configConfig
)
tokenRefreshService
:=
service
.
ProvideTokenRefreshService
(
accountRepository
,
soraAccountRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
compositeTokenCacheInvalidator
,
schedulerCache
,
configConfig
,
tempUnschedCache
,
privacyClientFactory
,
proxyRepository
)
tokenRefreshService
:=
service
.
ProvideTokenRefreshService
(
accountRepository
,
soraAccountRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
compositeTokenCacheInvalidator
,
schedulerCache
,
configConfig
,
tempUnschedCache
,
privacyClientFactory
,
proxyRepository
,
oauthRefreshAPI
)
accountExpiryService
:=
service
.
ProvideAccountExpiryService
(
accountRepository
)
subscriptionExpiryService
:=
service
.
ProvideSubscriptionExpiryService
(
userSubscriptionRepository
)
scheduledTestRunnerService
:=
service
.
ProvideScheduledTestRunnerService
(
scheduledTestPlanRepository
,
scheduledTestService
,
accountTestService
,
rateLimitService
,
configConfig
)
...
...
backend/internal/service/antigravity_token_provider.go
View file @
ec82c37d
...
...
@@ -3,7 +3,6 @@ package service
import
(
"context"
"errors"
"log"
"log/slog"
"strconv"
"strings"
...
...
@@ -17,15 +16,18 @@ const (
antigravityBackfillCooldown
=
5
*
time
.
Minute
)
// AntigravityTokenCache
T
oken
缓存接口(复用 GeminiTokenCache 接口定义)
// AntigravityTokenCache
t
oken
cache interface.
type
AntigravityTokenCache
=
GeminiTokenCache
// AntigravityTokenProvider
管理 Antigravity 账户的 access_token
// AntigravityTokenProvider
manages access_token for antigravity accounts.
type
AntigravityTokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
AntigravityTokenCache
antigravityOAuthService
*
AntigravityOAuthService
backfillCooldown
sync
.
Map
// key: int64 (account.ID) → value: time.Time
backfillCooldown
sync
.
Map
// key: accountID -> last attempt time
refreshAPI
*
OAuthRefreshAPI
executor
OAuthRefreshExecutor
refreshPolicy
ProviderRefreshPolicy
}
func
NewAntigravityTokenProvider
(
...
...
@@ -37,10 +39,22 @@ func NewAntigravityTokenProvider(
accountRepo
:
accountRepo
,
tokenCache
:
tokenCache
,
antigravityOAuthService
:
antigravityOAuthService
,
refreshPolicy
:
AntigravityProviderRefreshPolicy
(),
}
}
// GetAccessToken 获取有效的 access_token
// SetRefreshAPI injects unified OAuth refresh API and executor.
func
(
p
*
AntigravityTokenProvider
)
SetRefreshAPI
(
api
*
OAuthRefreshAPI
,
executor
OAuthRefreshExecutor
)
{
p
.
refreshAPI
=
api
p
.
executor
=
executor
}
// SetRefreshPolicy injects caller-side refresh policy.
func
(
p
*
AntigravityTokenProvider
)
SetRefreshPolicy
(
policy
ProviderRefreshPolicy
)
{
p
.
refreshPolicy
=
policy
}
// GetAccessToken returns a valid access_token.
func
(
p
*
AntigravityTokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
...
...
@@ -48,7 +62,8 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if
account
.
Platform
!=
PlatformAntigravity
{
return
""
,
errors
.
New
(
"not an antigravity account"
)
}
// upstream 类型:直接从 credentials 读取 api_key,不走 OAuth 刷新流程
// upstream accounts use static api_key and never refresh oauth token.
if
account
.
Type
==
AccountTypeUpstream
{
apiKey
:=
account
.
GetCredential
(
"api_key"
)
if
apiKey
==
""
{
...
...
@@ -62,46 +77,38 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
cacheKey
:=
AntigravityTokenCacheKey
(
account
)
// 1
. 先尝试缓存
// 1
) Try cache first.
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
}
// 2
. 如果即将过期则刷新
// 2
) Refresh if needed (pre-expiry skew).
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
antigravityTokenRefreshSkew
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
err
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
err
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
// 从数据库获取最新账户信息
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
antigravityTokenRefreshSkew
{
if
p
.
antigravityOAuthService
==
nil
{
return
""
,
errors
.
New
(
"antigravity oauth service not configured"
)
}
tokenInfo
,
err
:=
p
.
antigravityOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
needsRefresh
&&
p
.
refreshAPI
!=
nil
&&
p
.
executor
!=
nil
{
result
,
err
:=
p
.
refreshAPI
.
RefreshIfNeeded
(
ctx
,
account
,
p
.
executor
,
antigravityTokenRefreshSkew
)
if
err
!=
nil
{
if
p
.
refreshPolicy
.
OnRefreshError
==
ProviderRefreshErrorReturn
{
return
""
,
err
}
p
.
mergeCredentials
(
account
,
tokenInfo
)
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
log
.
Printf
(
"[AntigravityTokenProvider] Failed to update account credentials: %v"
,
updateErr
)
}
else
if
result
.
LockHeld
{
if
p
.
refreshPolicy
.
OnLockHeld
==
ProviderLockHeldWaitForCache
&&
p
.
tokenCache
!=
nil
{
if
token
,
cacheErr
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
cacheErr
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
}
// default policy: continue with existing token.
}
else
{
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
else
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
// Backward-compatible test path when refreshAPI is not injected.
locked
,
err
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
err
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
}
}
...
...
@@ -110,32 +117,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 如果账号还没有 project_id,尝试在线补齐,避免请求 daily/sandbox 时出现
// "Invalid project resource name projects/"。
// 仅调用 loadProjectIDWithRetry,不刷新 OAuth token;带冷却机制防止频繁重试。
// Backfill project_id online when missing, with cooldown to avoid hammering.
if
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
==
""
&&
p
.
antigravityOAuthService
!=
nil
{
if
p
.
shouldAttemptBackfill
(
account
.
ID
)
{
p
.
markBackfillAttempted
(
account
.
ID
)
if
projectID
,
err
:=
p
.
antigravityOAuthService
.
FillProjectID
(
ctx
,
account
,
accessToken
);
err
==
nil
&&
projectID
!=
""
{
account
.
Credentials
[
"project_id"
]
=
projectID
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
log
.
Printf
(
"[AntigravityTokenProvider] project_id 补齐持久化失败: %v"
,
updateErr
)
slog
.
Warn
(
"antigravity_project_id_backfill_persist_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
,
)
}
}
}
}
// 3
. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
// 3
) Populate cache with TTL.
if
p
.
tokenCache
!=
nil
{
latestAccount
,
isStale
:=
CheckTokenVersion
(
ctx
,
account
,
p
.
accountRepo
)
if
isStale
&&
latestAccount
!=
nil
{
// 版本过时,使用 DB 中的最新 token
slog
.
Debug
(
"antigravity_token_version_stale_use_latest"
,
"account_id"
,
account
.
ID
)
accessToken
=
latestAccount
.
GetCredential
(
"access_token"
)
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found after version check"
)
}
// 不写入缓存,让下次请求重新处理
}
else
{
ttl
:=
30
*
time
.
Minute
if
expiresAt
!=
nil
{
...
...
@@ -156,18 +162,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return
accessToken
,
nil
}
// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段
func
(
p
*
AntigravityTokenProvider
)
mergeCredentials
(
account
*
Account
,
tokenInfo
*
AntigravityTokenInfo
)
{
newCredentials
:=
p
.
antigravityOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
account
.
Credentials
=
newCredentials
}
// shouldAttemptBackfill 检查是否应该尝试补齐 project_id(冷却期内不重复尝试)
// shouldAttemptBackfill checks backfill cooldown.
func
(
p
*
AntigravityTokenProvider
)
shouldAttemptBackfill
(
accountID
int64
)
bool
{
if
v
,
ok
:=
p
.
backfillCooldown
.
Load
(
accountID
);
ok
{
if
lastAttempt
,
ok
:=
v
.
(
time
.
Time
);
ok
{
...
...
backend/internal/service/antigravity_token_refresher.go
View file @
ec82c37d
...
...
@@ -25,6 +25,11 @@ func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthServi
}
}
// CacheKey 返回用于分布式锁的缓存键
func
(
r
*
AntigravityTokenRefresher
)
CacheKey
(
account
*
Account
)
string
{
return
AntigravityTokenCacheKey
(
account
)
}
// CanRefresh 检查是否可以刷新此账户
func
(
r
*
AntigravityTokenRefresher
)
CanRefresh
(
account
*
Account
)
bool
{
return
account
.
Platform
==
PlatformAntigravity
&&
account
.
Type
==
AccountTypeOAuth
...
...
@@ -58,11 +63,7 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
newCredentials
:=
r
.
antigravityOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
// 合并旧的 credentials,保留新 credentials 中不存在的字段
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
newCredentials
=
MergeCredentials
(
account
.
Credentials
,
newCredentials
)
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
...
...
backend/internal/service/claude_token_provider.go
View file @
ec82c37d
...
...
@@ -4,7 +4,6 @@ import (
"context"
"errors"
"log/slog"
"strconv"
"strings"
"time"
)
...
...
@@ -15,14 +14,17 @@ const (
claudeLockWaitTime
=
200
*
time
.
Millisecond
)
// ClaudeTokenCache
T
oken
缓存接口(复用 GeminiTokenCache 接口定义)
// ClaudeTokenCache
t
oken
cache interface.
type
ClaudeTokenCache
=
GeminiTokenCache
// ClaudeTokenProvider
管理 Claude (Anthropic) OAuth 账户的 access_token
// ClaudeTokenProvider
manages access_token for Claude OAuth accounts.
type
ClaudeTokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
ClaudeTokenCache
oauthService
*
OAuthService
refreshAPI
*
OAuthRefreshAPI
executor
OAuthRefreshExecutor
refreshPolicy
ProviderRefreshPolicy
}
func
NewClaudeTokenProvider
(
...
...
@@ -34,10 +36,22 @@ func NewClaudeTokenProvider(
accountRepo
:
accountRepo
,
tokenCache
:
tokenCache
,
oauthService
:
oauthService
,
refreshPolicy
:
ClaudeProviderRefreshPolicy
(),
}
}
// GetAccessToken 获取有效的 access_token
// SetRefreshAPI injects unified OAuth refresh API and executor.
func
(
p
*
ClaudeTokenProvider
)
SetRefreshAPI
(
api
*
OAuthRefreshAPI
,
executor
OAuthRefreshExecutor
)
{
p
.
refreshAPI
=
api
p
.
executor
=
executor
}
// SetRefreshPolicy injects caller-side refresh policy.
func
(
p
*
ClaudeTokenProvider
)
SetRefreshPolicy
(
policy
ProviderRefreshPolicy
)
{
p
.
refreshPolicy
=
policy
}
// GetAccessToken returns a valid access_token.
func
(
p
*
ClaudeTokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
...
...
@@ -48,7 +62,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
// 1
. 先尝试缓存
// 1
) Try cache first.
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"claude_token_cache_hit"
,
"account_id"
,
account
.
ID
)
...
...
@@ -60,114 +74,39 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
slog
.
Debug
(
"claude_token_cache_miss"
,
"account_id"
,
account
.
ID
)
// 2
. 如果即将过期则刷新
// 2
) Refresh if needed (pre-expiry skew).
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
refreshFailed
:=
false
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
lockErr
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
lockErr
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
// 从数据库获取最新账户信息
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
{
if
p
.
oauthService
==
nil
{
slog
.
Warn
(
"claude_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
// 无法刷新,标记失败
}
else
{
tokenInfo
,
err
:=
p
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
needsRefresh
&&
p
.
refreshAPI
!=
nil
&&
p
.
executor
!=
nil
{
result
,
err
:=
p
.
refreshAPI
.
RefreshIfNeeded
(
ctx
,
account
,
p
.
executor
,
claudeTokenRefreshSkew
)
if
err
!=
nil
{
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog
.
Warn
(
"claude_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
// 刷新失败,标记以使用短 TTL
}
else
{
// 构建新 credentials,保留原有字段
newCredentials
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
account
.
Credentials
{
newCredentials
[
k
]
=
v
}
newCredentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
newCredentials
[
"token_type"
]
=
tokenInfo
.
TokenType
newCredentials
[
"expires_in"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresIn
,
10
)
newCredentials
[
"expires_at"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
)
if
tokenInfo
.
RefreshToken
!=
""
{
newCredentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
tokenInfo
.
Scope
!=
""
{
newCredentials
[
"scope"
]
=
tokenInfo
.
Scope
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"claude_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
if
lockErr
!=
nil
{
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
slog
.
Warn
(
"claude_token_lock_failed_degraded_refresh"
,
"account_id"
,
account
.
ID
,
"error"
,
lockErr
)
// 检查 ctx 是否已取消
if
ctx
.
Err
()
!=
nil
{
return
""
,
ctx
.
Err
()
}
// 从数据库获取最新账户信息
if
p
.
accountRepo
!=
nil
{
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
if
p
.
refreshPolicy
.
OnRefreshError
==
ProviderRefreshErrorReturn
{
return
""
,
err
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
{
if
p
.
oauthService
==
nil
{
slog
.
Warn
(
"claude_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
}
else
{
tokenInfo
,
err
:=
p
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
slog
.
Warn
(
"claude_token_refresh_failed_degraded"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
slog
.
Warn
(
"claude_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
}
else
{
// 构建新 credentials,保留原有字段
newCredentials
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
account
.
Credentials
{
newCredentials
[
k
]
=
v
}
newCredentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
newCredentials
[
"token_type"
]
=
tokenInfo
.
TokenType
newCredentials
[
"expires_in"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresIn
,
10
)
newCredentials
[
"expires_at"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
)
if
tokenInfo
.
RefreshToken
!=
""
{
newCredentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
tokenInfo
.
Scope
!=
""
{
newCredentials
[
"scope"
]
=
tokenInfo
.
Scope
}
else
if
result
.
LockHeld
{
if
p
.
refreshPolicy
.
OnLockHeld
==
ProviderLockHeldWaitForCache
&&
p
.
tokenCache
!=
nil
{
time
.
Sleep
(
claudeLockWaitTime
)
if
token
,
cacheErr
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
cacheErr
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"claude_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"claude_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
}
else
{
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
// Backward-compatible test path when refreshAPI is not injected.
locked
,
lockErr
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
lockErr
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
}
else
if
lockErr
!=
nil
{
slog
.
Warn
(
"claude_token_lock_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
lockErr
)
}
else
{
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
time
.
Sleep
(
claudeLockWaitTime
)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"claude_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
...
...
@@ -181,22 +120,23 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3
. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
// 3
) Populate cache with TTL.
if
p
.
tokenCache
!=
nil
{
latestAccount
,
isStale
:=
CheckTokenVersion
(
ctx
,
account
,
p
.
accountRepo
)
if
isStale
&&
latestAccount
!=
nil
{
// 版本过时,使用 DB 中的最新 token
slog
.
Debug
(
"claude_token_version_stale_use_latest"
,
"account_id"
,
account
.
ID
)
accessToken
=
latestAccount
.
GetCredential
(
"access_token"
)
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found after version check"
)
}
// 不写入缓存,让下次请求重新处理
}
else
{
ttl
:=
30
*
time
.
Minute
if
refreshFailed
{
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
if
p
.
refreshPolicy
.
FailureTTL
>
0
{
ttl
=
p
.
refreshPolicy
.
FailureTTL
}
else
{
ttl
=
time
.
Minute
}
slog
.
Debug
(
"claude_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed"
)
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
...
...
backend/internal/service/gemini_token_provider.go
View file @
ec82c37d
...
...
@@ -15,10 +15,14 @@ const (
geminiTokenCacheSkew
=
5
*
time
.
Minute
)
// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
type
GeminiTokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
GeminiTokenCache
geminiOAuthService
*
GeminiOAuthService
refreshAPI
*
OAuthRefreshAPI
executor
OAuthRefreshExecutor
refreshPolicy
ProviderRefreshPolicy
}
func
NewGeminiTokenProvider
(
...
...
@@ -30,9 +34,21 @@ func NewGeminiTokenProvider(
accountRepo
:
accountRepo
,
tokenCache
:
tokenCache
,
geminiOAuthService
:
geminiOAuthService
,
refreshPolicy
:
GeminiProviderRefreshPolicy
(),
}
}
// SetRefreshAPI injects unified OAuth refresh API and executor.
func
(
p
*
GeminiTokenProvider
)
SetRefreshAPI
(
api
*
OAuthRefreshAPI
,
executor
OAuthRefreshExecutor
)
{
p
.
refreshAPI
=
api
p
.
executor
=
executor
}
// SetRefreshPolicy injects caller-side refresh policy.
func
(
p
*
GeminiTokenProvider
)
SetRefreshPolicy
(
policy
ProviderRefreshPolicy
)
{
p
.
refreshPolicy
=
policy
}
func
(
p
*
GeminiTokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
...
...
@@ -53,39 +69,31 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
// 2) Refresh if needed (pre-expiry skew).
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
geminiTokenRefreshSkew
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
err
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
err
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// Re-check after lock (another worker may have refreshed).
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
geminiTokenRefreshSkew
{
if
p
.
geminiOAuthService
==
nil
{
return
""
,
errors
.
New
(
"gemini oauth service not configured"
)
}
tokenInfo
,
err
:=
p
.
geminiOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
needsRefresh
&&
p
.
refreshAPI
!=
nil
&&
p
.
executor
!=
nil
{
result
,
err
:=
p
.
refreshAPI
.
RefreshIfNeeded
(
ctx
,
account
,
p
.
executor
,
geminiTokenRefreshSkew
)
if
err
!=
nil
{
if
p
.
refreshPolicy
.
OnRefreshError
==
ProviderRefreshErrorReturn
{
return
""
,
err
}
newCredentials
:=
p
.
geminiOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
else
if
result
.
LockHeld
{
if
p
.
refreshPolicy
.
OnLockHeld
==
ProviderLockHeldWaitForCache
&&
p
.
tokenCache
!=
nil
{
if
token
,
cacheErr
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
cacheErr
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
}
account
.
Credentials
=
newCredentials
_
=
p
.
accountRepo
.
Update
(
ctx
,
account
)
slog
.
Debug
(
"gemini_token_lock_held_use_old"
,
"account_id"
,
account
.
ID
)
}
else
{
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
else
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
// Backward-compatible test path when refreshAPI is not injected.
locked
,
lockErr
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
lockErr
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
}
else
if
lockErr
!=
nil
{
slog
.
Warn
(
"gemini_token_lock_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
lockErr
)
}
}
...
...
@@ -95,15 +103,14 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
// project_id is optional now:
// - If present: will use Code Assist API (requires project_id)
// - If absent: will use AI Studio API with OAuth token (like regular API key mode)
// Auto-detect project_id only if explicitly enabled via a credential flag
// - If present: use Code Assist API (requires project_id)
// - If absent: use AI Studio API with OAuth token.
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
autoDetectProjectID
:=
account
.
GetCredential
(
"auto_detect_project_id"
)
==
"true"
if
projectID
==
""
&&
autoDetectProjectID
{
if
p
.
geminiOAuthService
==
nil
{
return
accessToken
,
nil
// Fallback to AI Studio API mode
return
accessToken
,
nil
}
var
proxyURL
string
...
...
@@ -132,17 +139,15 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
}
// 3) Populate cache with TTL
(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
// 3) Populate cache with TTL
.
if
p
.
tokenCache
!=
nil
{
latestAccount
,
isStale
:=
CheckTokenVersion
(
ctx
,
account
,
p
.
accountRepo
)
if
isStale
&&
latestAccount
!=
nil
{
// 版本过时,使用 DB 中的最新 token
slog
.
Debug
(
"gemini_token_version_stale_use_latest"
,
"account_id"
,
account
.
ID
)
accessToken
=
latestAccount
.
GetCredential
(
"access_token"
)
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found after version check"
)
}
// 不写入缓存,让下次请求重新处理
}
else
{
ttl
:=
30
*
time
.
Minute
if
expiresAt
!=
nil
{
...
...
backend/internal/service/gemini_token_refresher.go
View file @
ec82c37d
...
...
@@ -13,6 +13,11 @@ func NewGeminiTokenRefresher(geminiOAuthService *GeminiOAuthService) *GeminiToke
return
&
GeminiTokenRefresher
{
geminiOAuthService
:
geminiOAuthService
}
}
// CacheKey 返回用于分布式锁的缓存键
func
(
r
*
GeminiTokenRefresher
)
CacheKey
(
account
*
Account
)
string
{
return
GeminiTokenCacheKey
(
account
)
}
func
(
r
*
GeminiTokenRefresher
)
CanRefresh
(
account
*
Account
)
bool
{
return
account
.
Platform
==
PlatformGemini
&&
account
.
Type
==
AccountTypeOAuth
}
...
...
@@ -35,11 +40,7 @@ func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (m
}
newCredentials
:=
r
.
geminiOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
newCredentials
=
MergeCredentials
(
account
.
Credentials
,
newCredentials
)
return
newCredentials
,
nil
}
backend/internal/service/oauth_refresh_api.go
0 → 100644
View file @
ec82c37d
package
service
import
(
"context"
"fmt"
"log/slog"
"strconv"
"time"
)
// OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器
// TokenRefresher 接口的超集:增加了 CacheKey 方法用于分布式锁
type
OAuthRefreshExecutor
interface
{
TokenRefresher
// CacheKey 返回用于分布式锁的缓存键(与 TokenProvider 使用的一致)
CacheKey
(
account
*
Account
)
string
}
const
refreshLockTTL
=
30
*
time
.
Second
// OAuthRefreshResult 统一刷新结果
type
OAuthRefreshResult
struct
{
Refreshed
bool
// 实际执行了刷新
NewCredentials
map
[
string
]
any
// 刷新后的 credentials(nil 表示未刷新)
Account
*
Account
// 从 DB 重新读取的最新 account
LockHeld
bool
// 锁被其他 worker 持有(未执行刷新)
}
// OAuthRefreshAPI 统一的 OAuth Token 刷新入口
// 封装分布式锁、DB 重读、已刷新检查等通用逻辑
type
OAuthRefreshAPI
struct
{
accountRepo
AccountRepository
tokenCache
GeminiTokenCache
// 可选,nil = 无锁
}
// NewOAuthRefreshAPI 创建统一刷新 API
func
NewOAuthRefreshAPI
(
accountRepo
AccountRepository
,
tokenCache
GeminiTokenCache
)
*
OAuthRefreshAPI
{
return
&
OAuthRefreshAPI
{
accountRepo
:
accountRepo
,
tokenCache
:
tokenCache
,
}
}
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token
//
// 流程:
// 1. 获取分布式锁
// 2. 从 DB 重读最新 account(防止使用过时的 refresh_token)
// 3. 二次检查是否仍需刷新
// 4. 调用 executor.Refresh() 执行平台特定刷新逻辑
// 5. 设置 _token_version + 更新 DB
// 6. 释放锁
func
(
api
*
OAuthRefreshAPI
)
RefreshIfNeeded
(
ctx
context
.
Context
,
account
*
Account
,
executor
OAuthRefreshExecutor
,
refreshWindow
time
.
Duration
,
)
(
*
OAuthRefreshResult
,
error
)
{
cacheKey
:=
executor
.
CacheKey
(
account
)
// 1. 获取分布式锁
lockAcquired
:=
false
if
api
.
tokenCache
!=
nil
{
acquired
,
lockErr
:=
api
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
refreshLockTTL
)
if
lockErr
!=
nil
{
// Redis 错误,降级为无锁刷新
slog
.
Warn
(
"oauth_refresh_lock_failed_degraded"
,
"account_id"
,
account
.
ID
,
"cache_key"
,
cacheKey
,
"error"
,
lockErr
,
)
}
else
if
!
acquired
{
// 锁被其他 worker 持有
return
&
OAuthRefreshResult
{
LockHeld
:
true
},
nil
}
else
{
lockAcquired
=
true
defer
func
()
{
_
=
api
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
}
}
// 2. 从 DB 重读最新 account(锁保护下,确保使用最新的 refresh_token)
freshAccount
,
err
:=
api
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
!=
nil
{
slog
.
Warn
(
"oauth_refresh_db_reread_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
,
)
// 降级使用传入的 account
freshAccount
=
account
}
else
if
freshAccount
==
nil
{
freshAccount
=
account
}
// 3. 二次检查是否仍需刷新(另一条路径可能已刷新)
if
!
executor
.
NeedsRefresh
(
freshAccount
,
refreshWindow
)
{
return
&
OAuthRefreshResult
{
Account
:
freshAccount
,
},
nil
}
// 4. 执行平台特定刷新逻辑
newCredentials
,
refreshErr
:=
executor
.
Refresh
(
ctx
,
freshAccount
)
if
refreshErr
!=
nil
{
return
nil
,
refreshErr
}
// 5. 设置版本号 + 更新 DB
if
newCredentials
!=
nil
{
newCredentials
[
"_token_version"
]
=
time
.
Now
()
.
UnixMilli
()
freshAccount
.
Credentials
=
newCredentials
if
updateErr
:=
api
.
accountRepo
.
Update
(
ctx
,
freshAccount
);
updateErr
!=
nil
{
slog
.
Error
(
"oauth_refresh_update_failed"
,
"account_id"
,
freshAccount
.
ID
,
"error"
,
updateErr
,
)
return
nil
,
fmt
.
Errorf
(
"oauth refresh succeeded but DB update failed: %w"
,
updateErr
)
}
}
_
=
lockAcquired
// suppress unused warning when tokenCache is nil
return
&
OAuthRefreshResult
{
Refreshed
:
true
,
NewCredentials
:
newCredentials
,
Account
:
freshAccount
,
},
nil
}
// MergeCredentials 将旧 credentials 中不存在于新 map 的字段保留到新 map 中
func
MergeCredentials
(
oldCreds
,
newCreds
map
[
string
]
any
)
map
[
string
]
any
{
if
newCreds
==
nil
{
newCreds
=
make
(
map
[
string
]
any
)
}
for
k
,
v
:=
range
oldCreds
{
if
_
,
exists
:=
newCreds
[
k
];
!
exists
{
newCreds
[
k
]
=
v
}
}
return
newCreds
}
// BuildClaudeAccountCredentials 为 Claude 平台构建 OAuth credentials map
// 消除 Claude 平台没有 BuildAccountCredentials 方法的问题
func
BuildClaudeAccountCredentials
(
tokenInfo
*
TokenInfo
)
map
[
string
]
any
{
creds
:=
map
[
string
]
any
{
"access_token"
:
tokenInfo
.
AccessToken
,
"token_type"
:
tokenInfo
.
TokenType
,
"expires_in"
:
strconv
.
FormatInt
(
tokenInfo
.
ExpiresIn
,
10
),
"expires_at"
:
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
),
}
if
tokenInfo
.
RefreshToken
!=
""
{
creds
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
tokenInfo
.
Scope
!=
""
{
creds
[
"scope"
]
=
tokenInfo
.
Scope
}
return
creds
}
backend/internal/service/oauth_refresh_api_test.go
0 → 100644
View file @
ec82c37d
//go:build unit
package
service
import
(
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ---------- mock helpers ----------
// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests.
type
refreshAPIAccountRepo
struct
{
mockAccountRepoForGemini
account
*
Account
// returned by GetByID
getByIDErr
error
updateErr
error
updateCalls
int
}
func
(
r
*
refreshAPIAccountRepo
)
GetByID
(
_
context
.
Context
,
_
int64
)
(
*
Account
,
error
)
{
if
r
.
getByIDErr
!=
nil
{
return
nil
,
r
.
getByIDErr
}
return
r
.
account
,
nil
}
func
(
r
*
refreshAPIAccountRepo
)
Update
(
_
context
.
Context
,
_
*
Account
)
error
{
r
.
updateCalls
++
return
r
.
updateErr
}
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
type
refreshAPIExecutorStub
struct
{
needsRefresh
bool
credentials
map
[
string
]
any
err
error
refreshCalls
int
}
func
(
e
*
refreshAPIExecutorStub
)
CanRefresh
(
_
*
Account
)
bool
{
return
true
}
func
(
e
*
refreshAPIExecutorStub
)
NeedsRefresh
(
_
*
Account
,
_
time
.
Duration
)
bool
{
return
e
.
needsRefresh
}
func
(
e
*
refreshAPIExecutorStub
)
Refresh
(
_
context
.
Context
,
_
*
Account
)
(
map
[
string
]
any
,
error
)
{
e
.
refreshCalls
++
if
e
.
err
!=
nil
{
return
nil
,
e
.
err
}
return
e
.
credentials
,
nil
}
func
(
e
*
refreshAPIExecutorStub
)
CacheKey
(
account
*
Account
)
string
{
return
"test:api:"
+
account
.
Platform
}
// refreshAPICacheStub implements GeminiTokenCache for OAuthRefreshAPI tests.
type
refreshAPICacheStub
struct
{
lockResult
bool
lockErr
error
releaseCalls
int
}
func
(
c
*
refreshAPICacheStub
)
GetAccessToken
(
context
.
Context
,
string
)
(
string
,
error
)
{
return
""
,
nil
}
func
(
c
*
refreshAPICacheStub
)
SetAccessToken
(
context
.
Context
,
string
,
string
,
time
.
Duration
)
error
{
return
nil
}
func
(
c
*
refreshAPICacheStub
)
DeleteAccessToken
(
context
.
Context
,
string
)
error
{
return
nil
}
func
(
c
*
refreshAPICacheStub
)
AcquireRefreshLock
(
context
.
Context
,
string
,
time
.
Duration
)
(
bool
,
error
)
{
return
c
.
lockResult
,
c
.
lockErr
}
func
(
c
*
refreshAPICacheStub
)
ReleaseRefreshLock
(
context
.
Context
,
string
)
error
{
c
.
releaseCalls
++
return
nil
}
// ========== RefreshIfNeeded tests ==========
func
TestRefreshIfNeeded_Success
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
}
cache
:=
&
refreshAPICacheStub
{
lockResult
:
true
}
executor
:=
&
refreshAPIExecutorStub
{
needsRefresh
:
true
,
credentials
:
map
[
string
]
any
{
"access_token"
:
"new-token"
},
}
api
:=
NewOAuthRefreshAPI
(
repo
,
cache
)
result
,
err
:=
api
.
RefreshIfNeeded
(
context
.
Background
(),
account
,
executor
,
3
*
time
.
Minute
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
result
.
Refreshed
)
require
.
NotNil
(
t
,
result
.
NewCredentials
)
require
.
Equal
(
t
,
"new-token"
,
result
.
NewCredentials
[
"access_token"
])
require
.
NotNil
(
t
,
result
.
NewCredentials
[
"_token_version"
])
// version stamp set
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
// DB updated
require
.
Equal
(
t
,
1
,
cache
.
releaseCalls
)
// lock released
require
.
Equal
(
t
,
1
,
executor
.
refreshCalls
)
}
func
TestRefreshIfNeeded_LockHeld
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
2
,
Platform
:
PlatformAnthropic
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
}
cache
:=
&
refreshAPICacheStub
{
lockResult
:
false
}
// lock not acquired
executor
:=
&
refreshAPIExecutorStub
{
needsRefresh
:
true
}
api
:=
NewOAuthRefreshAPI
(
repo
,
cache
)
result
,
err
:=
api
.
RefreshIfNeeded
(
context
.
Background
(),
account
,
executor
,
3
*
time
.
Minute
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
result
.
LockHeld
)
require
.
False
(
t
,
result
.
Refreshed
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
0
,
executor
.
refreshCalls
)
}
func
TestRefreshIfNeeded_LockErrorDegrades
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
3
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
}
cache
:=
&
refreshAPICacheStub
{
lockErr
:
errors
.
New
(
"redis down"
)}
// lock error
executor
:=
&
refreshAPIExecutorStub
{
needsRefresh
:
true
,
credentials
:
map
[
string
]
any
{
"access_token"
:
"degraded-token"
},
}
api
:=
NewOAuthRefreshAPI
(
repo
,
cache
)
result
,
err
:=
api
.
RefreshIfNeeded
(
context
.
Background
(),
account
,
executor
,
3
*
time
.
Minute
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
result
.
Refreshed
)
// still refreshed (degraded mode)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
// DB updated
require
.
Equal
(
t
,
0
,
cache
.
releaseCalls
)
// no lock to release
require
.
Equal
(
t
,
1
,
executor
.
refreshCalls
)
}
func
TestRefreshIfNeeded_NoCacheNoLock
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
4
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
}
executor
:=
&
refreshAPIExecutorStub
{
needsRefresh
:
true
,
credentials
:
map
[
string
]
any
{
"access_token"
:
"no-cache-token"
},
}
api
:=
NewOAuthRefreshAPI
(
repo
,
nil
)
// no cache = no lock
result
,
err
:=
api
.
RefreshIfNeeded
(
context
.
Background
(),
account
,
executor
,
3
*
time
.
Minute
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
result
.
Refreshed
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
}
func
TestRefreshIfNeeded_AlreadyRefreshed
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
5
,
Platform
:
PlatformAnthropic
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
}
cache
:=
&
refreshAPICacheStub
{
lockResult
:
true
}
executor
:=
&
refreshAPIExecutorStub
{
needsRefresh
:
false
}
// already refreshed
api
:=
NewOAuthRefreshAPI
(
repo
,
cache
)
result
,
err
:=
api
.
RefreshIfNeeded
(
context
.
Background
(),
account
,
executor
,
3
*
time
.
Minute
)
require
.
NoError
(
t
,
err
)
require
.
False
(
t
,
result
.
Refreshed
)
require
.
False
(
t
,
result
.
LockHeld
)
require
.
NotNil
(
t
,
result
.
Account
)
// returns fresh account
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
0
,
executor
.
refreshCalls
)
}
func
TestRefreshIfNeeded_RefreshError
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
6
,
Platform
:
PlatformAnthropic
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
}
cache
:=
&
refreshAPICacheStub
{
lockResult
:
true
}
executor
:=
&
refreshAPIExecutorStub
{
needsRefresh
:
true
,
err
:
errors
.
New
(
"invalid_grant: token revoked"
),
}
api
:=
NewOAuthRefreshAPI
(
repo
,
cache
)
result
,
err
:=
api
.
RefreshIfNeeded
(
context
.
Background
(),
account
,
executor
,
3
*
time
.
Minute
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"invalid_grant"
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
// no DB update on refresh error
require
.
Equal
(
t
,
1
,
cache
.
releaseCalls
)
// lock still released via defer
}
func
TestRefreshIfNeeded_DBUpdateError
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
7
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
,
updateErr
:
errors
.
New
(
"db connection lost"
),
}
cache
:=
&
refreshAPICacheStub
{
lockResult
:
true
}
executor
:=
&
refreshAPIExecutorStub
{
needsRefresh
:
true
,
credentials
:
map
[
string
]
any
{
"access_token"
:
"token"
},
}
api
:=
NewOAuthRefreshAPI
(
repo
,
cache
)
result
,
err
:=
api
.
RefreshIfNeeded
(
context
.
Background
(),
account
,
executor
,
3
*
time
.
Minute
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"DB update failed"
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
// attempted
}
func
TestRefreshIfNeeded_DBRereadFails
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
8
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
nil
,
// GetByID returns nil
getByIDErr
:
errors
.
New
(
"db timeout"
),
}
cache
:=
&
refreshAPICacheStub
{
lockResult
:
true
}
executor
:=
&
refreshAPIExecutorStub
{
needsRefresh
:
true
,
credentials
:
map
[
string
]
any
{
"access_token"
:
"fallback-token"
},
}
api
:=
NewOAuthRefreshAPI
(
repo
,
cache
)
result
,
err
:=
api
.
RefreshIfNeeded
(
context
.
Background
(),
account
,
executor
,
3
*
time
.
Minute
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
result
.
Refreshed
)
require
.
Equal
(
t
,
1
,
executor
.
refreshCalls
)
// still refreshes using passed-in account
}
func
TestRefreshIfNeeded_NilCredentials
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
9
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
}
repo
:=
&
refreshAPIAccountRepo
{
account
:
account
}
cache
:=
&
refreshAPICacheStub
{
lockResult
:
true
}
executor
:=
&
refreshAPIExecutorStub
{
needsRefresh
:
true
,
credentials
:
nil
,
// Refresh returns nil credentials
}
api
:=
NewOAuthRefreshAPI
(
repo
,
cache
)
result
,
err
:=
api
.
RefreshIfNeeded
(
context
.
Background
(),
account
,
executor
,
3
*
time
.
Minute
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
result
.
Refreshed
)
require
.
Nil
(
t
,
result
.
NewCredentials
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
// no DB update when credentials are nil
}
// ========== MergeCredentials tests ==========
func
TestMergeCredentials_Basic
(
t
*
testing
.
T
)
{
old
:=
map
[
string
]
any
{
"a"
:
"1"
,
"b"
:
"2"
,
"c"
:
"3"
}
new
:=
map
[
string
]
any
{
"a"
:
"new"
,
"d"
:
"4"
}
result
:=
MergeCredentials
(
old
,
new
)
require
.
Equal
(
t
,
"new"
,
result
[
"a"
])
// new value preserved
require
.
Equal
(
t
,
"2"
,
result
[
"b"
])
// old value kept
require
.
Equal
(
t
,
"3"
,
result
[
"c"
])
// old value kept
require
.
Equal
(
t
,
"4"
,
result
[
"d"
])
// new value preserved
}
func
TestMergeCredentials_NilNew
(
t
*
testing
.
T
)
{
old
:=
map
[
string
]
any
{
"a"
:
"1"
}
result
:=
MergeCredentials
(
old
,
nil
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"1"
,
result
[
"a"
])
}
func
TestMergeCredentials_NilOld
(
t
*
testing
.
T
)
{
new
:=
map
[
string
]
any
{
"a"
:
"1"
}
result
:=
MergeCredentials
(
nil
,
new
)
require
.
Equal
(
t
,
"1"
,
result
[
"a"
])
}
func
TestMergeCredentials_BothNil
(
t
*
testing
.
T
)
{
result
:=
MergeCredentials
(
nil
,
nil
)
require
.
NotNil
(
t
,
result
)
require
.
Empty
(
t
,
result
)
}
func
TestMergeCredentials_NewOverridesOld
(
t
*
testing
.
T
)
{
old
:=
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"refresh_token"
:
"old-refresh"
}
new
:=
map
[
string
]
any
{
"access_token"
:
"new-token"
}
result
:=
MergeCredentials
(
old
,
new
)
require
.
Equal
(
t
,
"new-token"
,
result
[
"access_token"
])
// overridden
require
.
Equal
(
t
,
"old-refresh"
,
result
[
"refresh_token"
])
// preserved
}
// ========== BuildClaudeAccountCredentials tests ==========
func
TestBuildClaudeAccountCredentials_Full
(
t
*
testing
.
T
)
{
tokenInfo
:=
&
TokenInfo
{
AccessToken
:
"at-123"
,
TokenType
:
"Bearer"
,
ExpiresIn
:
3600
,
ExpiresAt
:
1700000000
,
RefreshToken
:
"rt-456"
,
Scope
:
"openid"
,
}
creds
:=
BuildClaudeAccountCredentials
(
tokenInfo
)
require
.
Equal
(
t
,
"at-123"
,
creds
[
"access_token"
])
require
.
Equal
(
t
,
"Bearer"
,
creds
[
"token_type"
])
require
.
Equal
(
t
,
"3600"
,
creds
[
"expires_in"
])
require
.
Equal
(
t
,
"1700000000"
,
creds
[
"expires_at"
])
require
.
Equal
(
t
,
"rt-456"
,
creds
[
"refresh_token"
])
require
.
Equal
(
t
,
"openid"
,
creds
[
"scope"
])
}
func
TestBuildClaudeAccountCredentials_Minimal
(
t
*
testing
.
T
)
{
tokenInfo
:=
&
TokenInfo
{
AccessToken
:
"at-789"
,
TokenType
:
"Bearer"
,
ExpiresIn
:
7200
,
ExpiresAt
:
1700003600
,
}
creds
:=
BuildClaudeAccountCredentials
(
tokenInfo
)
require
.
Equal
(
t
,
"at-789"
,
creds
[
"access_token"
])
require
.
Equal
(
t
,
"Bearer"
,
creds
[
"token_type"
])
require
.
Equal
(
t
,
"7200"
,
creds
[
"expires_in"
])
require
.
Equal
(
t
,
"1700003600"
,
creds
[
"expires_at"
])
_
,
hasRefresh
:=
creds
[
"refresh_token"
]
_
,
hasScope
:=
creds
[
"scope"
]
require
.
False
(
t
,
hasRefresh
,
"refresh_token should not be set when empty"
)
require
.
False
(
t
,
hasScope
,
"scope should not be set when empty"
)
}
// ========== BackgroundRefreshPolicy tests ==========
func
TestBackgroundRefreshPolicy_DefaultSkips
(
t
*
testing
.
T
)
{
p
:=
DefaultBackgroundRefreshPolicy
()
require
.
ErrorIs
(
t
,
p
.
handleLockHeld
(),
errRefreshSkipped
)
require
.
ErrorIs
(
t
,
p
.
handleAlreadyRefreshed
(),
errRefreshSkipped
)
}
func
TestBackgroundRefreshPolicy_SuccessOverride
(
t
*
testing
.
T
)
{
p
:=
BackgroundRefreshPolicy
{
OnLockHeld
:
BackgroundSkipAsSuccess
,
OnAlreadyRefresh
:
BackgroundSkipAsSuccess
,
}
require
.
NoError
(
t
,
p
.
handleLockHeld
())
require
.
NoError
(
t
,
p
.
handleAlreadyRefreshed
())
}
// ========== ProviderRefreshPolicy tests ==========
func
TestClaudeProviderRefreshPolicy
(
t
*
testing
.
T
)
{
p
:=
ClaudeProviderRefreshPolicy
()
require
.
Equal
(
t
,
ProviderRefreshErrorUseExistingToken
,
p
.
OnRefreshError
)
require
.
Equal
(
t
,
ProviderLockHeldWaitForCache
,
p
.
OnLockHeld
)
require
.
Equal
(
t
,
time
.
Minute
,
p
.
FailureTTL
)
}
func
TestOpenAIProviderRefreshPolicy
(
t
*
testing
.
T
)
{
p
:=
OpenAIProviderRefreshPolicy
()
require
.
Equal
(
t
,
ProviderRefreshErrorUseExistingToken
,
p
.
OnRefreshError
)
require
.
Equal
(
t
,
ProviderLockHeldWaitForCache
,
p
.
OnLockHeld
)
require
.
Equal
(
t
,
time
.
Minute
,
p
.
FailureTTL
)
}
func
TestGeminiProviderRefreshPolicy
(
t
*
testing
.
T
)
{
p
:=
GeminiProviderRefreshPolicy
()
require
.
Equal
(
t
,
ProviderRefreshErrorReturn
,
p
.
OnRefreshError
)
require
.
Equal
(
t
,
ProviderLockHeldUseExistingToken
,
p
.
OnLockHeld
)
require
.
Equal
(
t
,
time
.
Duration
(
0
),
p
.
FailureTTL
)
}
func
TestAntigravityProviderRefreshPolicy
(
t
*
testing
.
T
)
{
p
:=
AntigravityProviderRefreshPolicy
()
require
.
Equal
(
t
,
ProviderRefreshErrorReturn
,
p
.
OnRefreshError
)
require
.
Equal
(
t
,
ProviderLockHeldUseExistingToken
,
p
.
OnLockHeld
)
require
.
Equal
(
t
,
time
.
Duration
(
0
),
p
.
FailureTTL
)
}
backend/internal/service/openai_token_provider.go
View file @
ec82c37d
...
...
@@ -20,7 +20,7 @@ const (
openAILockWarnThresholdMs
=
250
)
// OpenAITokenRuntimeMetrics
表示 OpenAI token 刷新与锁竞争保护指标快照。
// OpenAITokenRuntimeMetrics
is a snapshot of refresh and lock contention metrics.
type
OpenAITokenRuntimeMetrics
struct
{
RefreshRequests
int64
RefreshSuccess
int64
...
...
@@ -72,15 +72,18 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() {
m
.
lastObservedUnixMs
.
Store
(
time
.
Now
()
.
UnixMilli
())
}
// OpenAITokenCache
T
oken
缓存接口(复用 GeminiTokenCache 接口定义)
// OpenAITokenCache
t
oken
cache interface.
type
OpenAITokenCache
=
GeminiTokenCache
// OpenAITokenProvider
管理
OpenAI OAuth
账户的 access_token
// OpenAITokenProvider
manages access_token for
OpenAI
/Sora
OAuth
accounts.
type
OpenAITokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
OpenAITokenCache
openAIOAuthService
*
OpenAIOAuthService
metrics
*
openAITokenRuntimeMetricsStore
refreshAPI
*
OAuthRefreshAPI
executor
OAuthRefreshExecutor
refreshPolicy
ProviderRefreshPolicy
}
func
NewOpenAITokenProvider
(
...
...
@@ -93,9 +96,21 @@ func NewOpenAITokenProvider(
tokenCache
:
tokenCache
,
openAIOAuthService
:
openAIOAuthService
,
metrics
:
&
openAITokenRuntimeMetricsStore
{},
refreshPolicy
:
OpenAIProviderRefreshPolicy
(),
}
}
// SetRefreshAPI injects unified OAuth refresh API and executor.
func
(
p
*
OpenAITokenProvider
)
SetRefreshAPI
(
api
*
OAuthRefreshAPI
,
executor
OAuthRefreshExecutor
)
{
p
.
refreshAPI
=
api
p
.
executor
=
executor
}
// SetRefreshPolicy injects caller-side refresh policy.
func
(
p
*
OpenAITokenProvider
)
SetRefreshPolicy
(
policy
ProviderRefreshPolicy
)
{
p
.
refreshPolicy
=
policy
}
func
(
p
*
OpenAITokenProvider
)
SnapshotRuntimeMetrics
()
OpenAITokenRuntimeMetrics
{
if
p
==
nil
{
return
OpenAITokenRuntimeMetrics
{}
...
...
@@ -110,7 +125,7 @@ func (p *OpenAITokenProvider) ensureMetrics() {
}
}
// GetAccessToken
获取有效的
access_token
// GetAccessToken
returns a valid
access_token
.
func
(
p
*
OpenAITokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
p
.
ensureMetrics
()
if
account
==
nil
{
...
...
@@ -122,7 +137,7 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
cacheKey
:=
OpenAITokenCacheKey
(
account
)
// 1
. 先尝试缓存
// 1
) Try cache first.
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"openai_token_cache_hit"
,
"account_id"
,
account
.
ID
)
...
...
@@ -134,114 +149,62 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
slog
.
Debug
(
"openai_token_cache_miss"
,
"account_id"
,
account
.
ID
)
// 2
. 如果即将过期则刷新
// 2
) Refresh if needed (pre-expiry skew).
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
openAITokenRefreshSkew
refreshFailed
:=
false
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
if
needsRefresh
&&
p
.
refreshAPI
!=
nil
&&
p
.
executor
!=
nil
{
p
.
metrics
.
refreshRequests
.
Add
(
1
)
p
.
metrics
.
touchNow
()
locked
,
lockErr
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
lockErr
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
// 从数据库获取最新账户信息
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
openAITokenRefreshSkew
{
// Sora accounts skip OpenAI OAuth refresh and keep existing token path.
if
account
.
Platform
==
PlatformSora
{
slog
.
Debug
(
"openai_token_refresh_skipped_for_sora"
,
"account_id"
,
account
.
ID
)
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
refreshFailed
=
true
}
else
if
p
.
openAIOAuthService
==
nil
{
slog
.
Warn
(
"openai_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
p
.
metrics
.
refreshFailure
.
Add
(
1
)
refreshFailed
=
true
// 无法刷新,标记失败
}
else
{
tokenInfo
,
err
:=
p
.
openAIOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
result
,
err
:=
p
.
refreshAPI
.
RefreshIfNeeded
(
ctx
,
account
,
p
.
executor
,
openAITokenRefreshSkew
)
if
err
!=
nil
{
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
if
p
.
refreshPolicy
.
OnRefreshError
==
ProviderRefreshErrorReturn
{
return
""
,
err
}
slog
.
Warn
(
"openai_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
p
.
metrics
.
refreshFailure
.
Add
(
1
)
refreshFailed
=
true
// 刷新失败,标记以使用短 TTL
}
else
{
p
.
metrics
.
refreshSuccess
.
Add
(
1
)
newCredentials
:=
p
.
openAIOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"openai_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
if
lockErr
!=
nil
{
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
p
.
metrics
.
lockAcquireFailure
.
Add
(
1
)
refreshFailed
=
true
}
else
if
result
.
LockHeld
{
if
p
.
refreshPolicy
.
OnLockHeld
==
ProviderLockHeldWaitForCache
{
p
.
metrics
.
lockContention
.
Add
(
1
)
p
.
metrics
.
touchNow
()
slog
.
Warn
(
"openai_token_lock_failed_degraded_refresh"
,
"account_id"
,
account
.
ID
,
"error"
,
lockErr
)
// 检查 ctx 是否已取消
if
ctx
.
Err
()
!=
nil
{
return
""
,
ctx
.
Err
()
token
,
waitErr
:=
p
.
waitForTokenAfterLockRace
(
ctx
,
cacheKey
)
if
waitErr
!=
nil
{
return
""
,
waitErr
}
// 从数据库获取最新账户信息
if
p
.
accountRepo
!=
nil
{
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
if
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"openai_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
}
}
else
if
result
.
Refreshed
{
p
.
metrics
.
refreshSuccess
.
Add
(
1
)
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
openAITokenRefreshSkew
{
if
account
.
Platform
==
PlatformSora
{
slog
.
Debug
(
"openai_token_refresh_skipped_for_sora_degraded"
,
"account_id"
,
account
.
ID
)
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
refreshFailed
=
true
}
else
if
p
.
openAIOAuthService
==
nil
{
slog
.
Warn
(
"openai_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
p
.
metrics
.
refreshFailure
.
Add
(
1
)
refreshFailed
=
true
}
else
{
tokenInfo
,
err
:=
p
.
openAIOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
slog
.
Warn
(
"openai_token_refresh_failed_degraded"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
p
.
metrics
.
refreshFailure
.
Add
(
1
)
refreshFailed
=
true
}
else
{
p
.
metrics
.
refreshSuccess
.
Add
(
1
)
newCredentials
:=
p
.
openAIOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"openai_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
account
=
result
.
Account
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
// Backward-compatible test path when refreshAPI is not injected.
p
.
metrics
.
refreshRequests
.
Add
(
1
)
p
.
metrics
.
touchNow
()
locked
,
lockErr
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
lockErr
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
}
else
if
lockErr
!=
nil
{
p
.
metrics
.
lockAcquireFailure
.
Add
(
1
)
p
.
metrics
.
touchNow
()
slog
.
Warn
(
"openai_token_lock_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
lockErr
)
}
else
{
// 锁被其他 worker 持有:使用短轮询+jitter,降低固定等待导致的尾延迟台阶。
p
.
metrics
.
lockContention
.
Add
(
1
)
p
.
metrics
.
touchNow
()
token
,
waitErr
:=
p
.
waitForTokenAfterLockRace
(
ctx
,
cacheKey
)
...
...
@@ -260,22 +223,23 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3
. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
// 3
) Populate cache with TTL.
if
p
.
tokenCache
!=
nil
{
latestAccount
,
isStale
:=
CheckTokenVersion
(
ctx
,
account
,
p
.
accountRepo
)
if
isStale
&&
latestAccount
!=
nil
{
// 版本过时,使用 DB 中的最新 token
slog
.
Debug
(
"openai_token_version_stale_use_latest"
,
"account_id"
,
account
.
ID
)
accessToken
=
latestAccount
.
GetOpenAIAccessToken
()
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found after version check"
)
}
// 不写入缓存,让下次请求重新处理
}
else
{
ttl
:=
30
*
time
.
Minute
if
refreshFailed
{
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
if
p
.
refreshPolicy
.
FailureTTL
>
0
{
ttl
=
p
.
refreshPolicy
.
FailureTTL
}
else
{
ttl
=
time
.
Minute
}
slog
.
Debug
(
"openai_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed"
)
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
...
...
backend/internal/service/refresh_policy.go
0 → 100644
View file @
ec82c37d
package
service
import
"time"
// ProviderRefreshErrorAction 定义 provider 在刷新失败时的处理动作。
type
ProviderRefreshErrorAction
int
const
(
// ProviderRefreshErrorReturn 失败即返回错误(不降级旧 token)。
ProviderRefreshErrorReturn
ProviderRefreshErrorAction
=
iota
// ProviderRefreshErrorUseExistingToken 失败后继续使用现有 token。
ProviderRefreshErrorUseExistingToken
)
// ProviderLockHeldAction 定义 provider 在刷新锁被占用时的处理动作。
type
ProviderLockHeldAction
int
const
(
// ProviderLockHeldUseExistingToken 直接使用现有 token。
ProviderLockHeldUseExistingToken
ProviderLockHeldAction
=
iota
// ProviderLockHeldWaitForCache 等待后重试缓存读取。
ProviderLockHeldWaitForCache
)
// ProviderRefreshPolicy 描述 provider 的平台差异策略。
type
ProviderRefreshPolicy
struct
{
OnRefreshError
ProviderRefreshErrorAction
OnLockHeld
ProviderLockHeldAction
FailureTTL
time
.
Duration
}
func
ClaudeProviderRefreshPolicy
()
ProviderRefreshPolicy
{
return
ProviderRefreshPolicy
{
OnRefreshError
:
ProviderRefreshErrorUseExistingToken
,
OnLockHeld
:
ProviderLockHeldWaitForCache
,
FailureTTL
:
time
.
Minute
,
}
}
func
OpenAIProviderRefreshPolicy
()
ProviderRefreshPolicy
{
return
ProviderRefreshPolicy
{
OnRefreshError
:
ProviderRefreshErrorUseExistingToken
,
OnLockHeld
:
ProviderLockHeldWaitForCache
,
FailureTTL
:
time
.
Minute
,
}
}
func
GeminiProviderRefreshPolicy
()
ProviderRefreshPolicy
{
return
ProviderRefreshPolicy
{
OnRefreshError
:
ProviderRefreshErrorReturn
,
OnLockHeld
:
ProviderLockHeldUseExistingToken
,
FailureTTL
:
0
,
}
}
func
AntigravityProviderRefreshPolicy
()
ProviderRefreshPolicy
{
return
ProviderRefreshPolicy
{
OnRefreshError
:
ProviderRefreshErrorReturn
,
OnLockHeld
:
ProviderLockHeldUseExistingToken
,
FailureTTL
:
0
,
}
}
// BackgroundSkipAction 定义后台刷新服务在“未实际刷新”场景的计数方式。
type
BackgroundSkipAction
int
const
(
// BackgroundSkipAsSkipped 计入 skipped(保持当前默认行为)。
BackgroundSkipAsSkipped
BackgroundSkipAction
=
iota
// BackgroundSkipAsSuccess 计入 success(仅用于兼容旧统计口径时可选)。
BackgroundSkipAsSuccess
)
// BackgroundRefreshPolicy 描述后台刷新服务的调用侧策略。
type
BackgroundRefreshPolicy
struct
{
OnLockHeld
BackgroundSkipAction
OnAlreadyRefresh
BackgroundSkipAction
}
func
DefaultBackgroundRefreshPolicy
()
BackgroundRefreshPolicy
{
return
BackgroundRefreshPolicy
{
OnLockHeld
:
BackgroundSkipAsSkipped
,
OnAlreadyRefresh
:
BackgroundSkipAsSkipped
,
}
}
func
(
p
BackgroundRefreshPolicy
)
handleLockHeld
()
error
{
if
p
.
OnLockHeld
==
BackgroundSkipAsSuccess
{
return
nil
}
return
errRefreshSkipped
}
func
(
p
BackgroundRefreshPolicy
)
handleAlreadyRefreshed
()
error
{
if
p
.
OnAlreadyRefresh
==
BackgroundSkipAsSuccess
{
return
nil
}
return
errRefreshSkipped
}
backend/internal/service/token_refresh_service.go
View file @
ec82c37d
...
...
@@ -2,6 +2,7 @@ package service
import
(
"context"
"errors"
"fmt"
"log/slog"
"strings"
...
...
@@ -16,10 +17,13 @@ import (
type
TokenRefreshService
struct
{
accountRepo
AccountRepository
refreshers
[]
TokenRefresher
executors
[]
OAuthRefreshExecutor
// 与 refreshers 一一对应的 executor(带 CacheKey)
refreshPolicy
BackgroundRefreshPolicy
cfg
*
config
.
TokenRefreshConfig
cacheInvalidator
TokenCacheInvalidator
schedulerCache
SchedulerCache
// 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
tempUnschedCache
TempUnschedCache
// 用于清除 Redis 中的临时不可调度缓存
refreshAPI
*
OAuthRefreshAPI
// 统一刷新 API
// OpenAI privacy: 刷新成功后检查并设置 training opt-out
privacyClientFactory
PrivacyClientFactory
...
...
@@ -43,6 +47,7 @@ func NewTokenRefreshService(
)
*
TokenRefreshService
{
s
:=
&
TokenRefreshService
{
accountRepo
:
accountRepo
,
refreshPolicy
:
DefaultBackgroundRefreshPolicy
(),
cfg
:
&
cfg
.
TokenRefresh
,
cacheInvalidator
:
cacheInvalidator
,
schedulerCache
:
schedulerCache
,
...
...
@@ -53,12 +58,24 @@ func NewTokenRefreshService(
openAIRefresher
:=
NewOpenAITokenRefresher
(
openaiOAuthService
,
accountRepo
)
openAIRefresher
.
SetSyncLinkedSoraAccounts
(
cfg
.
TokenRefresh
.
SyncLinkedSoraAccounts
)
// 注册平台特定的刷新器
claudeRefresher
:=
NewClaudeTokenRefresher
(
oauthService
)
geminiRefresher
:=
NewGeminiTokenRefresher
(
geminiOAuthService
)
agRefresher
:=
NewAntigravityTokenRefresher
(
antigravityOAuthService
)
// 注册平台特定的刷新器(TokenRefresher 接口)
s
.
refreshers
=
[]
TokenRefresher
{
NewClaudeTokenRefresher
(
oauthService
),
claudeRefresher
,
openAIRefresher
,
geminiRefresher
,
agRefresher
,
}
// 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法)
s
.
executors
=
[]
OAuthRefreshExecutor
{
claudeRefresher
,
openAIRefresher
,
NewG
emini
Token
Refresher
(
geminiOAuthService
)
,
NewAntigravityTokenRefresher
(
antigravityOAuthService
)
,
g
eminiRefresher
,
agRefresher
,
}
return
s
...
...
@@ -82,6 +99,16 @@ func (s *TokenRefreshService) SetPrivacyDeps(factory PrivacyClientFactory, proxy
s
.
proxyRepo
=
proxyRepo
}
// SetRefreshAPI 注入统一的 OAuth 刷新 API
func
(
s
*
TokenRefreshService
)
SetRefreshAPI
(
api
*
OAuthRefreshAPI
)
{
s
.
refreshAPI
=
api
}
// SetRefreshPolicy 注入后台刷新调用侧策略(用于显式化平台/场景差异行为)。
func
(
s
*
TokenRefreshService
)
SetRefreshPolicy
(
policy
BackgroundRefreshPolicy
)
{
s
.
refreshPolicy
=
policy
}
// Start 启动后台刷新服务
func
(
s
*
TokenRefreshService
)
Start
()
{
if
!
s
.
cfg
.
Enabled
{
...
...
@@ -148,13 +175,13 @@ func (s *TokenRefreshService) processRefresh() {
totalAccounts
:=
len
(
accounts
)
oauthAccounts
:=
0
// 可刷新的OAuth账号数
needsRefresh
:=
0
// 需要刷新的账号数
refreshed
,
failed
:=
0
,
0
refreshed
,
failed
,
skipped
:=
0
,
0
,
0
for
i
:=
range
accounts
{
account
:=
&
accounts
[
i
]
// 遍历所有刷新器,找到能处理此账号的
for
_
,
refresher
:=
range
s
.
refreshers
{
for
idx
,
refresher
:=
range
s
.
refreshers
{
if
!
refresher
.
CanRefresh
(
account
)
{
continue
}
...
...
@@ -168,14 +195,24 @@ func (s *TokenRefreshService) processRefresh() {
needsRefresh
++
// 获取对应的 executor
var
executor
OAuthRefreshExecutor
if
idx
<
len
(
s
.
executors
)
{
executor
=
s
.
executors
[
idx
]
}
// 执行刷新
if
err
:=
s
.
refreshWithRetry
(
ctx
,
account
,
refresher
);
err
!=
nil
{
if
err
:=
s
.
refreshWithRetry
(
ctx
,
account
,
refresher
,
executor
,
refreshWindow
);
err
!=
nil
{
if
errors
.
Is
(
err
,
errRefreshSkipped
)
{
skipped
++
}
else
{
slog
.
Warn
(
"token_refresh.account_refresh_failed"
,
"account_id"
,
account
.
ID
,
"account_name"
,
account
.
Name
,
"error"
,
err
,
)
failed
++
}
}
else
{
slog
.
Info
(
"token_refresh.account_refreshed"
,
"account_id"
,
account
.
ID
,
...
...
@@ -193,13 +230,14 @@ func (s *TokenRefreshService) processRefresh() {
if
needsRefresh
==
0
&&
failed
==
0
{
slog
.
Debug
(
"token_refresh.cycle_completed"
,
"total"
,
totalAccounts
,
"oauth"
,
oauthAccounts
,
"needs_refresh"
,
needsRefresh
,
"refreshed"
,
refreshed
,
"failed"
,
failed
)
"needs_refresh"
,
needsRefresh
,
"refreshed"
,
refreshed
,
"skipped"
,
skipped
,
"failed"
,
failed
)
}
else
{
slog
.
Info
(
"token_refresh.cycle_completed"
,
"total"
,
totalAccounts
,
"oauth"
,
oauthAccounts
,
"needs_refresh"
,
needsRefresh
,
"refreshed"
,
refreshed
,
"skipped"
,
skipped
,
"failed"
,
failed
,
)
}
...
...
@@ -212,25 +250,86 @@ func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]Account
}
// refreshWithRetry 带重试的刷新
func
(
s
*
TokenRefreshService
)
refreshWithRetry
(
ctx
context
.
Context
,
account
*
Account
,
refresher
TokenRefresher
)
error
{
func
(
s
*
TokenRefreshService
)
refreshWithRetry
(
ctx
context
.
Context
,
account
*
Account
,
refresher
TokenRefresher
,
executor
OAuthRefreshExecutor
,
refreshWindow
time
.
Duration
)
error
{
var
lastErr
error
for
attempt
:=
1
;
attempt
<=
s
.
cfg
.
MaxRetries
;
attempt
++
{
newCredentials
,
err
:=
refresher
.
Refresh
(
ctx
,
account
)
// 如果有新凭证,先更新(即使有错误也要保存 token)
var
newCredentials
map
[
string
]
any
var
err
error
// 优先使用统一 API(带分布式锁 + DB 重读保护)
if
s
.
refreshAPI
!=
nil
&&
executor
!=
nil
{
result
,
refreshErr
:=
s
.
refreshAPI
.
RefreshIfNeeded
(
ctx
,
account
,
executor
,
refreshWindow
)
if
refreshErr
!=
nil
{
err
=
refreshErr
}
else
if
result
.
LockHeld
{
// 锁被其他 worker 持有,由调用侧策略决定如何计数
return
s
.
refreshPolicy
.
handleLockHeld
()
}
else
if
!
result
.
Refreshed
{
// 已被其他路径刷新,由调用侧策略决定如何计数
return
s
.
refreshPolicy
.
handleAlreadyRefreshed
()
}
else
{
account
=
result
.
Account
_
=
result
.
NewCredentials
// 统一 API 已设置 _token_version 并更新 DB,无需重复操作
}
}
else
{
// 降级:直接调用 refresher(兼容旧路径)
newCredentials
,
err
=
refresher
.
Refresh
(
ctx
,
account
)
if
newCredentials
!=
nil
{
// 记录刷新版本时间戳,用于解决缓存一致性问题
// TokenProvider 写入缓存前会检查此版本,如果版本已更新则跳过写入
newCredentials
[
"_token_version"
]
=
time
.
Now
()
.
UnixMilli
()
account
.
Credentials
=
newCredentials
if
saveErr
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
saveErr
!=
nil
{
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
saveErr
)
}
}
}
if
err
==
nil
{
s
.
postRefreshActions
(
ctx
,
account
)
return
nil
}
// 不可重试错误(invalid_grant/invalid_client 等)直接标记 error 状态并返回
if
isNonRetryableRefreshError
(
err
)
{
errorMsg
:=
fmt
.
Sprintf
(
"Token refresh failed (non-retryable): %v"
,
err
)
if
setErr
:=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
errorMsg
);
setErr
!=
nil
{
slog
.
Error
(
"token_refresh.set_error_status_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
setErr
,
)
}
return
err
}
lastErr
=
err
slog
.
Warn
(
"token_refresh.retry_attempt_failed"
,
"account_id"
,
account
.
ID
,
"attempt"
,
attempt
,
"max_retries"
,
s
.
cfg
.
MaxRetries
,
"error"
,
err
,
)
// 如果还有重试机会,等待后重试
if
attempt
<
s
.
cfg
.
MaxRetries
{
// 指数退避:2^(attempt-1) * baseSeconds
backoff
:=
time
.
Duration
(
s
.
cfg
.
RetryBackoffSeconds
)
*
time
.
Second
*
time
.
Duration
(
1
<<
(
attempt
-
1
))
time
.
Sleep
(
backoff
)
}
}
// 可重试错误耗尽:仅记录日志,不标记 error(可能是临时网络问题,下个周期继续重试)
slog
.
Warn
(
"token_refresh.retry_exhausted"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
,
"max_retries"
,
s
.
cfg
.
MaxRetries
,
"error"
,
lastErr
,
)
return
lastErr
}
// postRefreshActions 刷新成功后的后续动作(清除错误状态、缓存失效、调度器同步等)
func
(
s
*
TokenRefreshService
)
postRefreshActions
(
ctx
context
.
Context
,
account
*
Account
)
{
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
if
account
.
Platform
==
PlatformAntigravity
&&
account
.
Status
==
StatusError
&&
...
...
@@ -276,7 +375,6 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
}
}
// 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials
// 这解决了 token 刷新后调度器缓存数据不一致的问题(#445)
if
s
.
schedulerCache
!=
nil
{
if
err
:=
s
.
schedulerCache
.
SetAccount
(
ctx
,
account
);
err
!=
nil
{
slog
.
Warn
(
"token_refresh.sync_scheduler_cache_failed"
,
...
...
@@ -289,48 +387,11 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
}
// OpenAI OAuth: 刷新成功后,检查是否已设置 privacy_mode,未设置则尝试关闭训练数据共享
s
.
ensureOpenAIPrivacy
(
ctx
,
account
)
return
nil
}
// 不可重试错误(invalid_grant/invalid_client 等)直接标记 error 状态并返回
if
isNonRetryableRefreshError
(
err
)
{
errorMsg
:=
fmt
.
Sprintf
(
"Token refresh failed (non-retryable): %v"
,
err
)
if
setErr
:=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
errorMsg
);
setErr
!=
nil
{
slog
.
Error
(
"token_refresh.set_error_status_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
setErr
,
)
}
return
err
}
lastErr
=
err
slog
.
Warn
(
"token_refresh.retry_attempt_failed"
,
"account_id"
,
account
.
ID
,
"attempt"
,
attempt
,
"max_retries"
,
s
.
cfg
.
MaxRetries
,
"error"
,
err
,
)
// 如果还有重试机会,等待后重试
if
attempt
<
s
.
cfg
.
MaxRetries
{
// 指数退避:2^(attempt-1) * baseSeconds
backoff
:=
time
.
Duration
(
s
.
cfg
.
RetryBackoffSeconds
)
*
time
.
Second
*
time
.
Duration
(
1
<<
(
attempt
-
1
))
time
.
Sleep
(
backoff
)
}
}
// 可重试错误耗尽:仅记录日志,不标记 error(可能是临时网络问题,下个周期继续重试)
slog
.
Warn
(
"token_refresh.retry_exhausted"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
,
"max_retries"
,
s
.
cfg
.
MaxRetries
,
"error"
,
lastErr
,
)
return
lastErr
}
// errRefreshSkipped 表示刷新被跳过(锁竞争或已被其他路径刷新),不计入 failed 或 refreshed
var
errRefreshSkipped
=
fmt
.
Errorf
(
"refresh skipped"
)
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
// 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
...
...
backend/internal/service/token_refresh_service_test.go
View file @
ec82c37d
...
...
@@ -84,6 +84,10 @@ func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map
return
r
.
credentials
,
nil
}
func
(
r
*
tokenRefresherStub
)
CacheKey
(
account
*
Account
)
string
{
return
"test:stub:"
+
account
.
Platform
}
func
TestTokenRefreshService_RefreshWithRetry_InvalidatesCache
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
...
...
@@ -105,7 +109,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
...
...
@@ -133,7 +137,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
...
...
@@ -159,7 +163,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
}
...
...
@@ -186,7 +190,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
// Antigravity 也应触发缓存失效
...
...
@@ -214,7 +218,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
// 非 OAuth 不触发缓存失效
...
...
@@ -242,7 +246,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
// 所有 OAuth 账户刷新后触发缓存失效
...
...
@@ -270,7 +274,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"failed to save credentials"
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
...
...
@@ -297,7 +301,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
err
:
errors
.
New
(
"refresh failed"
),
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
// 刷新失败不应更新
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
// 刷新失败不应触发缓存失效
...
...
@@ -324,7 +328,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
err
:
errors
.
New
(
"network error"
),
// 可重试错误
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
...
...
@@ -351,7 +355,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
err
:
errors
.
New
(
"invalid_grant: token revoked"
),
// 不可重试错误
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
...
...
@@ -383,7 +387,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
},
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
1
,
repo
.
clearTempCalls
)
// DB 清除
...
...
@@ -422,7 +426,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
err
:
errors
.
New
(
"invalid_grant: token revoked"
),
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
setErrorCalls
)
// 所有平台不可重试错误都应 SetError
})
...
...
@@ -453,3 +457,212 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
})
}
}
// ========== Path A (refreshAPI) 测试用例 ==========
// mockTokenCacheForRefreshAPI 用于 Path A 测试的 GeminiTokenCache mock
type
mockTokenCacheForRefreshAPI
struct
{
lockResult
bool
lockErr
error
releaseCalls
int
}
func
(
m
*
mockTokenCacheForRefreshAPI
)
GetAccessToken
(
_
context
.
Context
,
_
string
)
(
string
,
error
)
{
return
""
,
errors
.
New
(
"not cached"
)
}
func
(
m
*
mockTokenCacheForRefreshAPI
)
SetAccessToken
(
_
context
.
Context
,
_
string
,
_
string
,
_
time
.
Duration
)
error
{
return
nil
}
func
(
m
*
mockTokenCacheForRefreshAPI
)
DeleteAccessToken
(
_
context
.
Context
,
_
string
)
error
{
return
nil
}
func
(
m
*
mockTokenCacheForRefreshAPI
)
AcquireRefreshLock
(
_
context
.
Context
,
_
string
,
_
time
.
Duration
)
(
bool
,
error
)
{
return
m
.
lockResult
,
m
.
lockErr
}
func
(
m
*
mockTokenCacheForRefreshAPI
)
ReleaseRefreshLock
(
_
context
.
Context
,
_
string
)
error
{
m
.
releaseCalls
++
return
nil
}
// buildPathAService 构建注入了 refreshAPI 的 service(Path A 测试辅助)
func
buildPathAService
(
repo
*
tokenRefreshAccountRepo
,
cache
GeminiTokenCache
,
invalidator
TokenCacheInvalidator
)
(
*
TokenRefreshService
,
*
tokenRefresherStub
)
{
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
1
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
nil
,
cfg
,
nil
)
refreshAPI
:=
NewOAuthRefreshAPI
(
repo
,
cache
)
service
.
SetRefreshAPI
(
refreshAPI
)
refresher
:=
&
tokenRefresherStub
{
credentials
:
map
[
string
]
any
{
"access_token"
:
"refreshed-token"
,
},
}
return
service
,
refresher
}
// TestPathA_Success 统一 API 路径正常成功:刷新 + DB 更新 + postRefreshActions
func
TestPathA_Success
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
100
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
repo
:=
&
tokenRefreshAccountRepo
{}
repo
.
accountsByID
=
map
[
int64
]
*
Account
{
account
.
ID
:
account
}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
cache
:=
&
mockTokenCacheForRefreshAPI
{
lockResult
:
true
}
service
,
refresher
:=
buildPathAService
(
repo
,
cache
,
invalidator
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
// DB 更新被调用
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
// 缓存失效被调用
require
.
Equal
(
t
,
1
,
cache
.
releaseCalls
)
// 锁被释放
}
// TestPathA_LockHeld 锁被其他 worker 持有 → 返回 errRefreshSkipped
func
TestPathA_LockHeld
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
101
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
repo
:=
&
tokenRefreshAccountRepo
{}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
cache
:=
&
mockTokenCacheForRefreshAPI
{
lockResult
:
false
}
// 锁获取失败(被占)
service
,
refresher
:=
buildPathAService
(
repo
,
cache
,
invalidator
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
ErrorIs
(
t
,
err
,
errRefreshSkipped
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
// 不应更新 DB
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
// 不应触发缓存失效
}
// TestPathA_AlreadyRefreshed 二次检查发现已被其他路径刷新 → 返回 errRefreshSkipped
func
TestPathA_AlreadyRefreshed
(
t
*
testing
.
T
)
{
// NeedsRefresh 返回 false → RefreshIfNeeded 返回 {Refreshed: false}
account
:=
&
Account
{
ID
:
102
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
repo
:=
&
tokenRefreshAccountRepo
{}
repo
.
accountsByID
=
map
[
int64
]
*
Account
{
account
.
ID
:
account
}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
cache
:=
&
mockTokenCacheForRefreshAPI
{
lockResult
:
true
}
service
,
_
:=
buildPathAService
(
repo
,
cache
,
invalidator
)
// 使用一个 NeedsRefresh 返回 false 的 stub
noRefreshNeeded
:=
&
tokenRefresherStub
{
credentials
:
map
[
string
]
any
{
"access_token"
:
"token"
},
}
// 覆盖 NeedsRefresh 行为 — 我们需要一个新的 stub 类型
alwaysFreshStub
:=
&
alwaysFreshRefresherStub
{}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
noRefreshNeeded
,
alwaysFreshStub
,
time
.
Hour
)
require
.
ErrorIs
(
t
,
err
,
errRefreshSkipped
)
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
}
// alwaysFreshRefresherStub 二次检查时认为不需要刷新(模拟已被其他路径刷新)
type
alwaysFreshRefresherStub
struct
{}
func
(
r
*
alwaysFreshRefresherStub
)
CanRefresh
(
_
*
Account
)
bool
{
return
true
}
func
(
r
*
alwaysFreshRefresherStub
)
NeedsRefresh
(
_
*
Account
,
_
time
.
Duration
)
bool
{
return
false
}
func
(
r
*
alwaysFreshRefresherStub
)
Refresh
(
_
context
.
Context
,
_
*
Account
)
(
map
[
string
]
any
,
error
)
{
return
nil
,
errors
.
New
(
"should not be called"
)
}
func
(
r
*
alwaysFreshRefresherStub
)
CacheKey
(
account
*
Account
)
string
{
return
"test:fresh:"
+
account
.
Platform
}
// TestPathA_NonRetryableError 统一 API 路径返回不可重试错误 → SetError
func
TestPathA_NonRetryableError
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
103
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
repo
:=
&
tokenRefreshAccountRepo
{}
repo
.
accountsByID
=
map
[
int64
]
*
Account
{
account
.
ID
:
account
}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
cache
:=
&
mockTokenCacheForRefreshAPI
{
lockResult
:
true
}
service
,
_
:=
buildPathAService
(
repo
,
cache
,
invalidator
)
refresher
:=
&
tokenRefresherStub
{
err
:
errors
.
New
(
"invalid_grant: token revoked"
),
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
setErrorCalls
)
// 应标记 error 状态
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
// 不应更新 credentials
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
// 不应触发缓存失效
}
// TestPathA_RetryableErrorExhausted 统一 API 路径可重试错误耗尽 → 不标记 error
func
TestPathA_RetryableErrorExhausted
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
104
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
repo
:=
&
tokenRefreshAccountRepo
{}
repo
.
accountsByID
=
map
[
int64
]
*
Account
{
account
.
ID
:
account
}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
cache
:=
&
mockTokenCacheForRefreshAPI
{
lockResult
:
true
}
cfg
:=
&
config
.
Config
{
TokenRefresh
:
config
.
TokenRefreshConfig
{
MaxRetries
:
2
,
RetryBackoffSeconds
:
0
,
},
}
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
nil
,
cfg
,
nil
)
refreshAPI
:=
NewOAuthRefreshAPI
(
repo
,
cache
)
service
.
SetRefreshAPI
(
refreshAPI
)
refresher
:=
&
tokenRefresherStub
{
err
:
errors
.
New
(
"network timeout"
),
}
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
0
,
repo
.
setErrorCalls
)
// 可重试错误不标记 error
require
.
Equal
(
t
,
0
,
repo
.
updateCalls
)
// 刷新失败不应更新
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
// 不应触发缓存失效
}
// TestPathA_DBUpdateFailed 统一 API 路径 DB 更新失败 → 返回 error,不执行 postRefreshActions
func
TestPathA_DBUpdateFailed
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
ID
:
105
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
repo
:=
&
tokenRefreshAccountRepo
{
updateErr
:
errors
.
New
(
"db connection lost"
)}
repo
.
accountsByID
=
map
[
int64
]
*
Account
{
account
.
ID
:
account
}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
cache
:=
&
mockTokenCacheForRefreshAPI
{
lockResult
:
true
}
service
,
refresher
:=
buildPathAService
(
repo
,
cache
,
invalidator
)
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
,
refresher
,
time
.
Hour
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"DB update failed"
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
// DB 更新被尝试
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
// DB 失败时不应触发缓存失效
}
backend/internal/service/token_refresher.go
View file @
ec82c37d
...
...
@@ -3,7 +3,6 @@ package service
import
(
"context"
"log"
"strconv"
"time"
)
...
...
@@ -33,6 +32,11 @@ func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher {
}
}
// CacheKey 返回用于分布式锁的缓存键
func
(
r
*
ClaudeTokenRefresher
)
CacheKey
(
account
*
Account
)
string
{
return
ClaudeTokenCacheKey
(
account
)
}
// CanRefresh 检查是否能处理此账号
// 只处理 anthropic 平台的 oauth 类型账号
// setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新
...
...
@@ -59,24 +63,8 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m
return
nil
,
err
}
// 保留现有credentials中的所有字段
newCredentials
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
account
.
Credentials
{
newCredentials
[
k
]
=
v
}
// 只更新token相关字段
// 注意:expires_at 和 expires_in 必须存为字符串,因为 GetCredential 只返回 string 类型
newCredentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
newCredentials
[
"token_type"
]
=
tokenInfo
.
TokenType
newCredentials
[
"expires_in"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresIn
,
10
)
newCredentials
[
"expires_at"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
)
if
tokenInfo
.
RefreshToken
!=
""
{
newCredentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
tokenInfo
.
Scope
!=
""
{
newCredentials
[
"scope"
]
=
tokenInfo
.
Scope
}
newCredentials
:=
BuildClaudeAccountCredentials
(
tokenInfo
)
newCredentials
=
MergeCredentials
(
account
.
Credentials
,
newCredentials
)
return
newCredentials
,
nil
}
...
...
@@ -97,6 +85,11 @@ func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService, accountRepo
}
}
// CacheKey 返回用于分布式锁的缓存键
func
(
r
*
OpenAITokenRefresher
)
CacheKey
(
account
*
Account
)
string
{
return
OpenAITokenCacheKey
(
account
)
}
// SetSoraAccountRepo 设置 Sora 账号扩展表仓储
// 用于在 Token 刷新时同步更新 sora_accounts 表
// 如果未设置,syncLinkedSoraAccounts 只会更新 accounts.credentials
...
...
@@ -137,13 +130,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
// 使用服务提供的方法构建新凭证,并保留原有字段
newCredentials
:=
r
.
openaiOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
// 保留原有credentials中非token相关字段
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
newCredentials
=
MergeCredentials
(
account
.
Credentials
,
newCredentials
)
// 异步同步关联的 Sora 账号(不阻塞主流程)
if
r
.
accountRepo
!=
nil
&&
r
.
syncLinkedSora
{
...
...
backend/internal/service/wire.go
View file @
ec82c37d
...
...
@@ -51,16 +51,77 @@ func ProvideTokenRefreshService(
tempUnschedCache
TempUnschedCache
,
privacyClientFactory
PrivacyClientFactory
,
proxyRepo
ProxyRepository
,
refreshAPI
*
OAuthRefreshAPI
,
)
*
TokenRefreshService
{
svc
:=
NewTokenRefreshService
(
accountRepo
,
oauthService
,
openaiOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
cacheInvalidator
,
schedulerCache
,
cfg
,
tempUnschedCache
)
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
svc
.
SetSoraAccountRepo
(
soraAccountRepo
)
// 注入 OpenAI privacy opt-out 依赖
svc
.
SetPrivacyDeps
(
privacyClientFactory
,
proxyRepo
)
// 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件)
svc
.
SetRefreshAPI
(
refreshAPI
)
// 调用侧显式注入后台刷新策略,避免策略漂移
svc
.
SetRefreshPolicy
(
DefaultBackgroundRefreshPolicy
())
svc
.
Start
()
return
svc
}
// ProvideClaudeTokenProvider creates ClaudeTokenProvider with OAuthRefreshAPI injection
func
ProvideClaudeTokenProvider
(
accountRepo
AccountRepository
,
tokenCache
GeminiTokenCache
,
oauthService
*
OAuthService
,
refreshAPI
*
OAuthRefreshAPI
,
)
*
ClaudeTokenProvider
{
p
:=
NewClaudeTokenProvider
(
accountRepo
,
tokenCache
,
oauthService
)
executor
:=
NewClaudeTokenRefresher
(
oauthService
)
p
.
SetRefreshAPI
(
refreshAPI
,
executor
)
p
.
SetRefreshPolicy
(
ClaudeProviderRefreshPolicy
())
return
p
}
// ProvideOpenAITokenProvider creates OpenAITokenProvider with OAuthRefreshAPI injection
func
ProvideOpenAITokenProvider
(
accountRepo
AccountRepository
,
tokenCache
GeminiTokenCache
,
openaiOAuthService
*
OpenAIOAuthService
,
refreshAPI
*
OAuthRefreshAPI
,
)
*
OpenAITokenProvider
{
p
:=
NewOpenAITokenProvider
(
accountRepo
,
tokenCache
,
openaiOAuthService
)
executor
:=
NewOpenAITokenRefresher
(
openaiOAuthService
,
accountRepo
)
p
.
SetRefreshAPI
(
refreshAPI
,
executor
)
p
.
SetRefreshPolicy
(
OpenAIProviderRefreshPolicy
())
return
p
}
// ProvideGeminiTokenProvider creates GeminiTokenProvider with OAuthRefreshAPI injection
func
ProvideGeminiTokenProvider
(
accountRepo
AccountRepository
,
tokenCache
GeminiTokenCache
,
geminiOAuthService
*
GeminiOAuthService
,
refreshAPI
*
OAuthRefreshAPI
,
)
*
GeminiTokenProvider
{
p
:=
NewGeminiTokenProvider
(
accountRepo
,
tokenCache
,
geminiOAuthService
)
executor
:=
NewGeminiTokenRefresher
(
geminiOAuthService
)
p
.
SetRefreshAPI
(
refreshAPI
,
executor
)
p
.
SetRefreshPolicy
(
GeminiProviderRefreshPolicy
())
return
p
}
// ProvideAntigravityTokenProvider creates AntigravityTokenProvider with OAuthRefreshAPI injection
func
ProvideAntigravityTokenProvider
(
accountRepo
AccountRepository
,
tokenCache
GeminiTokenCache
,
antigravityOAuthService
*
AntigravityOAuthService
,
refreshAPI
*
OAuthRefreshAPI
,
)
*
AntigravityTokenProvider
{
p
:=
NewAntigravityTokenProvider
(
accountRepo
,
tokenCache
,
antigravityOAuthService
)
executor
:=
NewAntigravityTokenRefresher
(
antigravityOAuthService
)
p
.
SetRefreshAPI
(
refreshAPI
,
executor
)
p
.
SetRefreshPolicy
(
AntigravityProviderRefreshPolicy
())
return
p
}
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
func
ProvideDashboardAggregationService
(
repo
DashboardAggregationRepository
,
timingWheel
*
TimingWheelService
,
cfg
*
config
.
Config
)
*
DashboardAggregationService
{
svc
:=
NewDashboardAggregationService
(
repo
,
timingWheel
,
cfg
)
...
...
@@ -375,11 +436,12 @@ var ProviderSet = wire.NewSet(
NewCompositeTokenCacheInvalidator
,
wire
.
Bind
(
new
(
TokenCacheInvalidator
),
new
(
*
CompositeTokenCacheInvalidator
)),
NewAntigravityOAuthService
,
NewGeminiTokenProvider
,
NewOAuthRefreshAPI
,
ProvideGeminiTokenProvider
,
NewGeminiMessagesCompatService
,
New
AntigravityTokenProvider
,
New
OpenAITokenProvider
,
New
ClaudeTokenProvider
,
Provide
AntigravityTokenProvider
,
Provide
OpenAITokenProvider
,
Provide
ClaudeTokenProvider
,
NewAntigravityGatewayService
,
ProvideRateLimitService
,
NewAccountUsageService
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment