Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
3ba4d535
Commit
3ba4d535
authored
Jan 15, 2026
by
yangjianbo
Browse files
Merge branch 'dev'
parents
a65fd9de
5b37e9ae
Changes
18
Hide whitespace changes
Inline
Side-by-side
.gitignore
View file @
3ba4d535
...
...
@@ -127,3 +127,4 @@ deploy/docker-compose.override.yml
.gocache/
vite.config.js
docs/*
.serena/
\ No newline at end of file
backend/cmd/server/wire_gen.go
View file @
3ba4d535
...
...
@@ -100,8 +100,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tempUnschedCache
:=
repository
.
NewTempUnschedCache
(
redisClient
)
timeoutCounterCache
:=
repository
.
NewTimeoutCounterCache
(
redisClient
)
geminiTokenCache
:=
repository
.
NewGeminiTokenCache
(
redisClient
)
t
okenCacheInvalidator
:=
service
.
NewCompositeTokenCacheInvalidator
(
geminiTokenCache
)
rateLimitService
:=
service
.
ProvideRateLimitService
(
accountRepository
,
usageLogRepository
,
configConfig
,
geminiQuotaService
,
tempUnschedCache
,
timeoutCounterCache
,
settingService
,
t
okenCacheInvalidator
)
compositeT
okenCacheInvalidator
:=
service
.
NewCompositeTokenCacheInvalidator
(
geminiTokenCache
)
rateLimitService
:=
service
.
ProvideRateLimitService
(
accountRepository
,
usageLogRepository
,
configConfig
,
geminiQuotaService
,
tempUnschedCache
,
timeoutCounterCache
,
settingService
,
compositeT
okenCacheInvalidator
)
claudeUsageFetcher
:=
repository
.
NewClaudeUsageFetcher
()
antigravityQuotaFetcher
:=
service
.
NewAntigravityQuotaFetcher
(
proxyRepository
)
usageCache
:=
service
.
NewUsageCache
()
...
...
@@ -136,8 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityCache
:=
repository
.
NewIdentityCache
(
redisClient
)
identityService
:=
service
.
NewIdentityService
(
identityCache
)
deferredService
:=
service
.
ProvideDeferredService
(
accountRepository
,
timingWheelService
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
)
claudeTokenProvider
:=
service
.
NewClaudeTokenProvider
(
accountRepository
,
geminiTokenCache
,
oAuthService
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
,
claudeTokenProvider
)
openAITokenProvider
:=
service
.
NewOpenAITokenProvider
(
accountRepository
,
geminiTokenCache
,
openAIOAuthService
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
,
openAITokenProvider
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
groupRepository
,
gatewayCache
,
schedulerSnapshotService
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
,
configConfig
)
opsService
:=
service
.
NewOpsService
(
opsRepository
,
settingRepository
,
configConfig
,
accountRepository
,
concurrencyService
,
gatewayService
,
openAIGatewayService
,
geminiMessagesCompatService
,
antigravityGatewayService
)
settingHandler
:=
admin
.
NewSettingHandler
(
settingService
,
emailService
,
turnstileService
,
opsService
)
...
...
@@ -168,7 +170,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService
:=
service
.
ProvideOpsAlertEvaluatorService
(
opsService
,
opsRepository
,
emailService
,
redisClient
,
configConfig
)
opsCleanupService
:=
service
.
ProvideOpsCleanupService
(
opsRepository
,
db
,
redisClient
,
configConfig
)
opsScheduledReportService
:=
service
.
ProvideOpsScheduledReportService
(
opsService
,
userService
,
emailService
,
redisClient
,
configConfig
)
tokenRefreshService
:=
service
.
ProvideTokenRefreshService
(
accountRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
t
okenCacheInvalidator
,
configConfig
)
tokenRefreshService
:=
service
.
ProvideTokenRefreshService
(
accountRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
compositeT
okenCacheInvalidator
,
configConfig
)
accountExpiryService
:=
service
.
ProvideAccountExpiryService
(
accountRepository
)
v
:=
provideCleanup
(
client
,
redisClient
,
opsMetricsCollector
,
opsAggregationService
,
opsAlertEvaluatorService
,
opsCleanupService
,
opsScheduledReportService
,
schedulerSnapshotService
,
tokenRefreshService
,
accountExpiryService
,
pricingService
,
emailQueueService
,
billingCacheService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
)
application
:=
&
Application
{
...
...
backend/internal/repository/gemini_token_cache.go
View file @
3ba4d535
...
...
@@ -11,8 +11,8 @@ import (
)
const
(
gemini
TokenKeyPrefix
=
"
gemini
:token:"
gemini
RefreshLockKeyPrefix
=
"
gemini
:refresh_lock:"
oauth
TokenKeyPrefix
=
"
oauth
:token:"
oauth
RefreshLockKeyPrefix
=
"
oauth
:refresh_lock:"
)
type
geminiTokenCache
struct
{
...
...
@@ -24,26 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
}
func
(
c
*
geminiTokenCache
)
GetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
(
string
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
gemini
TokenKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauth
TokenKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
Get
(
ctx
,
key
)
.
Result
()
}
func
(
c
*
geminiTokenCache
)
SetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
,
token
string
,
ttl
time
.
Duration
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
gemini
TokenKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauth
TokenKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
Set
(
ctx
,
key
,
token
,
ttl
)
.
Err
()
}
func
(
c
*
geminiTokenCache
)
DeleteAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
gemini
TokenKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauth
TokenKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
func
(
c
*
geminiTokenCache
)
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
gemini
RefreshLockKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauth
RefreshLockKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
SetNX
(
ctx
,
key
,
1
,
ttl
)
.
Result
()
}
func
(
c
*
geminiTokenCache
)
ReleaseRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
key
:=
fmt
.
Sprintf
(
"%s%s"
,
gemini
RefreshLockKeyPrefix
,
cacheKey
)
key
:=
fmt
.
Sprintf
(
"%s%s"
,
oauth
RefreshLockKeyPrefix
,
cacheKey
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
backend/internal/service/claude_token_provider.go
0 → 100644
View file @
3ba4d535
package
service
import
(
"context"
"errors"
"log/slog"
"strconv"
"strings"
"time"
)
const
(
claudeTokenRefreshSkew
=
3
*
time
.
Minute
claudeTokenCacheSkew
=
5
*
time
.
Minute
claudeLockWaitTime
=
200
*
time
.
Millisecond
)
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type
ClaudeTokenCache
=
GeminiTokenCache
// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
type
ClaudeTokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
ClaudeTokenCache
oauthService
*
OAuthService
}
func
NewClaudeTokenProvider
(
accountRepo
AccountRepository
,
tokenCache
ClaudeTokenCache
,
oauthService
*
OAuthService
,
)
*
ClaudeTokenProvider
{
return
&
ClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
tokenCache
,
oauthService
:
oauthService
,
}
}
// GetAccessToken 获取有效的 access_token
func
(
p
*
ClaudeTokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
account
.
Platform
!=
PlatformAnthropic
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an anthropic oauth account"
)
}
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
// 1. 先尝试缓存
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"claude_token_cache_hit"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
else
if
err
!=
nil
{
slog
.
Warn
(
"claude_token_cache_get_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
slog
.
Debug
(
"claude_token_cache_miss"
,
"account_id"
,
account
.
ID
)
// 2. 如果即将过期则刷新
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
refreshFailed
:=
false
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
lockErr
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
lockErr
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
// 从数据库获取最新账户信息
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
{
if
p
.
oauthService
==
nil
{
slog
.
Warn
(
"claude_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
// 无法刷新,标记失败
}
else
{
tokenInfo
,
err
:=
p
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog
.
Warn
(
"claude_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
// 刷新失败,标记以使用短 TTL
}
else
{
// 构建新 credentials,保留原有字段
newCredentials
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
account
.
Credentials
{
newCredentials
[
k
]
=
v
}
newCredentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
newCredentials
[
"token_type"
]
=
tokenInfo
.
TokenType
newCredentials
[
"expires_in"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresIn
,
10
)
newCredentials
[
"expires_at"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
)
if
tokenInfo
.
RefreshToken
!=
""
{
newCredentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
tokenInfo
.
Scope
!=
""
{
newCredentials
[
"scope"
]
=
tokenInfo
.
Scope
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"claude_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
if
lockErr
!=
nil
{
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
slog
.
Warn
(
"claude_token_lock_failed_degraded_refresh"
,
"account_id"
,
account
.
ID
,
"error"
,
lockErr
)
// 检查 ctx 是否已取消
if
ctx
.
Err
()
!=
nil
{
return
""
,
ctx
.
Err
()
}
// 从数据库获取最新账户信息
if
p
.
accountRepo
!=
nil
{
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
{
if
p
.
oauthService
==
nil
{
slog
.
Warn
(
"claude_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
}
else
{
tokenInfo
,
err
:=
p
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
slog
.
Warn
(
"claude_token_refresh_failed_degraded"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
}
else
{
// 构建新 credentials,保留原有字段
newCredentials
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
account
.
Credentials
{
newCredentials
[
k
]
=
v
}
newCredentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
newCredentials
[
"token_type"
]
=
tokenInfo
.
TokenType
newCredentials
[
"expires_in"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresIn
,
10
)
newCredentials
[
"expires_at"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
)
if
tokenInfo
.
RefreshToken
!=
""
{
newCredentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
tokenInfo
.
Scope
!=
""
{
newCredentials
[
"scope"
]
=
tokenInfo
.
Scope
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"claude_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
{
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
time
.
Sleep
(
claudeLockWaitTime
)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"claude_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
}
}
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. 存入缓存
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
if
refreshFailed
{
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl
=
time
.
Minute
slog
.
Debug
(
"claude_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed"
)
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
claudeTokenCacheSkew
:
ttl
=
until
-
claudeTokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
if
err
:=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
);
err
!=
nil
{
slog
.
Warn
(
"claude_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
return
accessToken
,
nil
}
backend/internal/service/claude_token_provider_test.go
0 → 100644
View file @
3ba4d535
//go:build unit
package
service
import
(
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// claudeTokenCacheStub implements ClaudeTokenCache for testing
type
claudeTokenCacheStub
struct
{
mu
sync
.
Mutex
tokens
map
[
string
]
string
getErr
error
setErr
error
deleteErr
error
lockAcquired
bool
lockErr
error
releaseLockErr
error
getCalled
int32
setCalled
int32
lockCalled
int32
unlockCalled
int32
simulateLockRace
bool
}
func
newClaudeTokenCacheStub
()
*
claudeTokenCacheStub
{
return
&
claudeTokenCacheStub
{
tokens
:
make
(
map
[
string
]
string
),
lockAcquired
:
true
,
}
}
func
(
s
*
claudeTokenCacheStub
)
GetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
(
string
,
error
)
{
atomic
.
AddInt32
(
&
s
.
getCalled
,
1
)
if
s
.
getErr
!=
nil
{
return
""
,
s
.
getErr
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
return
s
.
tokens
[
cacheKey
],
nil
}
func
(
s
*
claudeTokenCacheStub
)
SetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
,
token
string
,
ttl
time
.
Duration
)
error
{
atomic
.
AddInt32
(
&
s
.
setCalled
,
1
)
if
s
.
setErr
!=
nil
{
return
s
.
setErr
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
tokens
[
cacheKey
]
=
token
return
nil
}
func
(
s
*
claudeTokenCacheStub
)
DeleteAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
if
s
.
deleteErr
!=
nil
{
return
s
.
deleteErr
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
delete
(
s
.
tokens
,
cacheKey
)
return
nil
}
func
(
s
*
claudeTokenCacheStub
)
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
{
atomic
.
AddInt32
(
&
s
.
lockCalled
,
1
)
if
s
.
lockErr
!=
nil
{
return
false
,
s
.
lockErr
}
if
s
.
simulateLockRace
{
return
false
,
nil
}
return
s
.
lockAcquired
,
nil
}
func
(
s
*
claudeTokenCacheStub
)
ReleaseRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
atomic
.
AddInt32
(
&
s
.
unlockCalled
,
1
)
return
s
.
releaseLockErr
}
// claudeAccountRepoStub is a minimal stub implementing only the methods used by ClaudeTokenProvider
type
claudeAccountRepoStub
struct
{
account
*
Account
getErr
error
updateErr
error
getCalled
int32
updateCalled
int32
}
func
(
r
*
claudeAccountRepoStub
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
atomic
.
AddInt32
(
&
r
.
getCalled
,
1
)
if
r
.
getErr
!=
nil
{
return
nil
,
r
.
getErr
}
return
r
.
account
,
nil
}
func
(
r
*
claudeAccountRepoStub
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
atomic
.
AddInt32
(
&
r
.
updateCalled
,
1
)
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
r
.
account
=
account
return
nil
}
// claudeOAuthServiceStub implements OAuthService methods for testing
type
claudeOAuthServiceStub
struct
{
tokenInfo
*
TokenInfo
refreshErr
error
refreshCalled
int32
}
func
(
s
*
claudeOAuthServiceStub
)
RefreshAccountToken
(
ctx
context
.
Context
,
account
*
Account
)
(
*
TokenInfo
,
error
)
{
atomic
.
AddInt32
(
&
s
.
refreshCalled
,
1
)
if
s
.
refreshErr
!=
nil
{
return
nil
,
s
.
refreshErr
}
return
s
.
tokenInfo
,
nil
}
// testClaudeTokenProvider is a test version that uses the stub OAuth service
type
testClaudeTokenProvider
struct
{
accountRepo
*
claudeAccountRepoStub
tokenCache
*
claudeTokenCacheStub
oauthService
*
claudeOAuthServiceStub
}
func
(
p
*
testClaudeTokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
account
.
Platform
!=
PlatformAnthropic
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an anthropic oauth account"
)
}
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
// 1. Check cache
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
token
!=
""
{
return
token
,
nil
}
}
// 2. Check if refresh needed
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
refreshFailed
:=
false
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
err
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
err
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// Check cache again after acquiring lock
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
token
!=
""
{
return
token
,
nil
}
// Get fresh account from DB
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
{
if
p
.
oauthService
==
nil
{
refreshFailed
=
true
// 无法刷新,标记失败
}
else
{
tokenInfo
,
err
:=
p
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
refreshFailed
=
true
// 刷新失败,标记以使用短 TTL
}
else
{
// Build new credentials
newCredentials
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
account
.
Credentials
{
newCredentials
[
k
]
=
v
}
newCredentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
newCredentials
[
"token_type"
]
=
tokenInfo
.
TokenType
newCredentials
[
"expires_at"
]
=
time
.
Now
()
.
Add
(
time
.
Duration
(
tokenInfo
.
ExpiresIn
)
*
time
.
Second
)
.
Format
(
time
.
RFC3339
)
if
tokenInfo
.
RefreshToken
!=
""
{
newCredentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
account
.
Credentials
=
newCredentials
_
=
p
.
accountRepo
.
Update
(
ctx
,
account
)
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
if
p
.
tokenCache
.
simulateLockRace
{
// Wait and retry cache
time
.
Sleep
(
10
*
time
.
Millisecond
)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
token
!=
""
{
return
token
,
nil
}
}
}
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
accessToken
==
""
{
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. Store in cache
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
if
refreshFailed
{
ttl
=
time
.
Minute
// 刷新失败时使用短 TTL
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
if
until
>
claudeTokenCacheSkew
{
ttl
=
until
-
claudeTokenCacheSkew
}
else
if
until
>
0
{
ttl
=
until
}
else
{
ttl
=
time
.
Minute
}
}
_
=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
)
}
return
accessToken
,
nil
}
func
TestClaudeTokenProvider_CacheHit
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
account
:=
&
Account
{
ID
:
100
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"db-token"
,
},
}
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
cache
.
tokens
[
cacheKey
]
=
"cached-token"
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"cached-token"
,
token
)
require
.
Equal
(
t
,
int32
(
1
),
atomic
.
LoadInt32
(
&
cache
.
getCalled
))
require
.
Equal
(
t
,
int32
(
0
),
atomic
.
LoadInt32
(
&
cache
.
setCalled
))
}
func
TestClaudeTokenProvider_CacheMiss_FromCredentials
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
// Token expires in far future, no refresh needed
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
101
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"credential-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"credential-token"
,
token
)
// Should have stored in cache
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
require
.
Equal
(
t
,
"credential-token"
,
cache
.
tokens
[
cacheKey
])
}
func
TestClaudeTokenProvider_TokenRefresh
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{}
oauthService
:=
&
claudeOAuthServiceStub
{
tokenInfo
:
&
TokenInfo
{
AccessToken
:
"refreshed-token"
,
RefreshToken
:
"new-refresh-token"
,
TokenType
:
"Bearer"
,
ExpiresIn
:
3600
,
ExpiresAt
:
time
.
Now
()
.
Add
(
time
.
Hour
)
.
Unix
(),
},
}
// Token expires soon (within refresh skew)
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
102
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"refresh_token"
:
"old-refresh-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"refreshed-token"
,
token
)
require
.
Equal
(
t
,
int32
(
1
),
atomic
.
LoadInt32
(
&
oauthService
.
refreshCalled
))
}
func
TestClaudeTokenProvider_LockRaceCondition
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
simulateLockRace
=
true
accountRepo
:=
&
claudeAccountRepoStub
{}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
103
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"race-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
// Simulate another worker already refreshed and cached
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
go
func
()
{
time
.
Sleep
(
5
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"winner-token"
cache
.
mu
.
Unlock
()
}()
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
token
)
}
func
TestClaudeTokenProvider_NilAccount
(
t
*
testing
.
T
)
{
provider
:=
NewClaudeTokenProvider
(
nil
,
nil
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"account is nil"
)
require
.
Empty
(
t
,
token
)
}
func
TestClaudeTokenProvider_WrongPlatform
(
t
*
testing
.
T
)
{
provider
:=
NewClaudeTokenProvider
(
nil
,
nil
,
nil
)
account
:=
&
Account
{
ID
:
104
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an anthropic oauth account"
)
require
.
Empty
(
t
,
token
)
}
func
TestClaudeTokenProvider_WrongAccountType
(
t
*
testing
.
T
)
{
provider
:=
NewClaudeTokenProvider
(
nil
,
nil
,
nil
)
account
:=
&
Account
{
ID
:
105
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an anthropic oauth account"
)
require
.
Empty
(
t
,
token
)
}
func
TestClaudeTokenProvider_SetupTokenType
(
t
*
testing
.
T
)
{
provider
:=
NewClaudeTokenProvider
(
nil
,
nil
,
nil
)
account
:=
&
Account
{
ID
:
106
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeSetupToken
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an anthropic oauth account"
)
require
.
Empty
(
t
,
token
)
}
func
TestClaudeTokenProvider_NilCache
(
t
*
testing
.
T
)
{
// Token doesn't need refresh
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
107
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"nocache-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
nil
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"nocache-token"
,
token
)
}
func
TestClaudeTokenProvider_CacheGetError
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
getErr
=
errors
.
New
(
"redis connection failed"
)
// Token doesn't need refresh
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
108
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"fallback-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
// Should gracefully degrade and return from credentials
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"fallback-token"
,
token
)
}
func
TestClaudeTokenProvider_CacheSetError
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
setErr
=
errors
.
New
(
"redis write failed"
)
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
109
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"still-works-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
// Should still work even if cache set fails
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"still-works-token"
,
token
)
}
func
TestClaudeTokenProvider_MissingAccessToken
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
110
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"expires_at"
:
expiresAt
,
// missing access_token
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"access_token not found"
)
require
.
Empty
(
t
,
token
)
}
func
TestClaudeTokenProvider_RefreshError
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{}
oauthService
:=
&
claudeOAuthServiceStub
{
refreshErr
:
errors
.
New
(
"oauth refresh failed"
),
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
111
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"refresh_token"
:
"old-refresh-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
// Now with fallback behavior, should return existing token even if refresh fails
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"old-token"
,
token
)
// Fallback to existing token
}
func
TestClaudeTokenProvider_OAuthServiceNotConfigured
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
112
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
nil
,
// not configured
}
// Now with fallback behavior, should return existing token even if oauth service not configured
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"old-token"
,
token
)
// Fallback to existing token
}
func
TestClaudeTokenProvider_TTLCalculation
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
expiresIn
time
.
Duration
}{
{
name
:
"far_future_expiry"
,
expiresIn
:
1
*
time
.
Hour
,
},
{
name
:
"medium_expiry"
,
expiresIn
:
10
*
time
.
Minute
,
},
{
name
:
"near_expiry"
,
expiresIn
:
6
*
time
.
Minute
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
expiresAt
:=
time
.
Now
()
.
Add
(
tt
.
expiresIn
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
200
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"test-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
_
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// Verify token was cached
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
require
.
Equal
(
t
,
"test-token"
,
cache
.
tokens
[
cacheKey
])
})
}
}
func
TestClaudeTokenProvider_AccountRepoGetError
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{
getErr
:
errors
.
New
(
"db connection failed"
),
}
oauthService
:=
&
claudeOAuthServiceStub
{
tokenInfo
:
&
TokenInfo
{
AccessToken
:
"refreshed-token"
,
RefreshToken
:
"new-refresh"
,
TokenType
:
"Bearer"
,
ExpiresIn
:
3600
,
},
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
113
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"refresh_token"
:
"old-refresh"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
// Should still work, just using the passed-in account
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"refreshed-token"
,
token
)
}
func
TestClaudeTokenProvider_AccountUpdateError
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{
updateErr
:
errors
.
New
(
"db write failed"
),
}
oauthService
:=
&
claudeOAuthServiceStub
{
tokenInfo
:
&
TokenInfo
{
AccessToken
:
"refreshed-token"
,
RefreshToken
:
"new-refresh"
,
TokenType
:
"Bearer"
,
ExpiresIn
:
3600
,
},
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
114
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"refresh_token"
:
"old-refresh"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
// Should still return token even if update fails
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"refreshed-token"
,
token
)
}
func
TestClaudeTokenProvider_RefreshPreservesExistingCredentials
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{}
oauthService
:=
&
claudeOAuthServiceStub
{
tokenInfo
:
&
TokenInfo
{
AccessToken
:
"new-access-token"
,
RefreshToken
:
"new-refresh-token"
,
TokenType
:
"Bearer"
,
ExpiresIn
:
3600
,
},
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
115
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-access-token"
,
"refresh_token"
:
"old-refresh-token"
,
"expires_at"
:
expiresAt
,
"custom_field"
:
"should-be-preserved"
,
"organization"
:
"test-org"
,
},
}
accountRepo
.
account
=
account
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"new-access-token"
,
token
)
// Verify existing fields are preserved
require
.
Equal
(
t
,
"should-be-preserved"
,
accountRepo
.
account
.
Credentials
[
"custom_field"
])
require
.
Equal
(
t
,
"test-org"
,
accountRepo
.
account
.
Credentials
[
"organization"
])
// Verify new fields are updated
require
.
Equal
(
t
,
"new-access-token"
,
accountRepo
.
account
.
Credentials
[
"access_token"
])
require
.
Equal
(
t
,
"new-refresh-token"
,
accountRepo
.
account
.
Credentials
[
"refresh_token"
])
}
func
TestClaudeTokenProvider_DoubleCheckCacheAfterLock
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{}
oauthService
:=
&
claudeOAuthServiceStub
{
tokenInfo
:
&
TokenInfo
{
AccessToken
:
"refreshed-token"
,
RefreshToken
:
"new-refresh"
,
TokenType
:
"Bearer"
,
ExpiresIn
:
3600
,
},
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
116
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
// After lock is acquired, cache should have the token (simulating another worker)
go
func
()
{
time
.
Sleep
(
5
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"cached-by-other-worker"
cache
.
mu
.
Unlock
()
}()
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
token
)
}
// Tests for real provider - to increase coverage
func
TestClaudeTokenProvider_Real_LockFailedWait
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
lockAcquired
=
false
// Lock acquisition fails
// Token expires soon (within refresh skew) to trigger lock attempt
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
300
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"fallback-token"
,
"expires_at"
:
expiresAt
,
},
}
// Set token in cache after lock wait period (simulate other worker refreshing)
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
go
func
()
{
time
.
Sleep
(
100
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"refreshed-by-other"
cache
.
mu
.
Unlock
()
}()
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
token
)
}
func
TestClaudeTokenProvider_Real_CacheHitAfterWait
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
lockAcquired
=
false
// Lock acquisition fails
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
301
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"original-token"
,
"expires_at"
:
expiresAt
,
},
}
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
// Set token in cache immediately after wait starts
go
func
()
{
time
.
Sleep
(
50
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"winner-token"
cache
.
mu
.
Unlock
()
}()
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
token
)
}
func
TestClaudeTokenProvider_Real_NoExpiresAt
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
lockAcquired
=
false
// Prevent entering refresh logic
// Token with nil expires_at (no expiry set)
account
:=
&
Account
{
ID
:
302
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"no-expiry-token"
,
},
}
// After lock wait, return token from credentials
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"no-expiry-token"
,
token
)
}
func
TestClaudeTokenProvider_Real_WhitespaceToken
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cacheKey
:=
"claude:account:303"
cache
.
tokens
[
cacheKey
]
=
" "
// Whitespace only - should be treated as empty
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
303
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"real-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"real-token"
,
token
)
}
func
TestClaudeTokenProvider_Real_EmptyCredentialToken
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
304
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
" "
,
// Whitespace only
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"access_token not found"
)
require
.
Empty
(
t
,
token
)
}
func
TestClaudeTokenProvider_Real_LockError
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
lockErr
=
errors
.
New
(
"redis lock failed"
)
// Token expires soon (within refresh skew)
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
305
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"fallback-on-lock-error"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"fallback-on-lock-error"
,
token
)
}
func
TestClaudeTokenProvider_Real_NilCredentials
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
306
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"expires_at"
:
expiresAt
,
// No access_token
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"access_token not found"
)
require
.
Empty
(
t
,
token
)
}
backend/internal/service/gateway_service.go
View file @
3ba4d535
...
...
@@ -144,21 +144,22 @@ func (e *UpstreamFailoverError) Error() string {
// GatewayService handles API gateway operations
type
GatewayService
struct
{
accountRepo
AccountRepository
groupRepo
GroupRepository
usageLogRepo
UsageLogRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
cache
GatewayCache
cfg
*
config
.
Config
schedulerSnapshot
*
SchedulerSnapshotService
billingService
*
BillingService
rateLimitService
*
RateLimitService
billingCacheService
*
BillingCacheService
identityService
*
IdentityService
httpUpstream
HTTPUpstream
deferredService
*
DeferredService
concurrencyService
*
ConcurrencyService
accountRepo
AccountRepository
groupRepo
GroupRepository
usageLogRepo
UsageLogRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
cache
GatewayCache
cfg
*
config
.
Config
schedulerSnapshot
*
SchedulerSnapshotService
billingService
*
BillingService
rateLimitService
*
RateLimitService
billingCacheService
*
BillingCacheService
identityService
*
IdentityService
httpUpstream
HTTPUpstream
deferredService
*
DeferredService
concurrencyService
*
ConcurrencyService
claudeTokenProvider
*
ClaudeTokenProvider
}
// NewGatewayService creates a new GatewayService
...
...
@@ -178,23 +179,25 @@ func NewGatewayService(
identityService
*
IdentityService
,
httpUpstream
HTTPUpstream
,
deferredService
*
DeferredService
,
claudeTokenProvider
*
ClaudeTokenProvider
,
)
*
GatewayService
{
return
&
GatewayService
{
accountRepo
:
accountRepo
,
groupRepo
:
groupRepo
,
usageLogRepo
:
usageLogRepo
,
userRepo
:
userRepo
,
userSubRepo
:
userSubRepo
,
cache
:
cache
,
cfg
:
cfg
,
schedulerSnapshot
:
schedulerSnapshot
,
concurrencyService
:
concurrencyService
,
billingService
:
billingService
,
rateLimitService
:
rateLimitService
,
billingCacheService
:
billingCacheService
,
identityService
:
identityService
,
httpUpstream
:
httpUpstream
,
deferredService
:
deferredService
,
accountRepo
:
accountRepo
,
groupRepo
:
groupRepo
,
usageLogRepo
:
usageLogRepo
,
userRepo
:
userRepo
,
userSubRepo
:
userSubRepo
,
cache
:
cache
,
cfg
:
cfg
,
schedulerSnapshot
:
schedulerSnapshot
,
concurrencyService
:
concurrencyService
,
billingService
:
billingService
,
rateLimitService
:
rateLimitService
,
billingCacheService
:
billingCacheService
,
identityService
:
identityService
,
httpUpstream
:
httpUpstream
,
deferredService
:
deferredService
,
claudeTokenProvider
:
claudeTokenProvider
,
}
}
...
...
@@ -1079,6 +1082,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
}
func
(
s
*
GatewayService
)
getOAuthToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
string
,
error
)
{
// 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token
if
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
==
AccountTypeOAuth
&&
s
.
claudeTokenProvider
!=
nil
{
accessToken
,
err
:=
s
.
claudeTokenProvider
.
GetAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
""
,
err
}
return
accessToken
,
"oauth"
,
nil
}
// 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
accessToken
==
""
{
return
""
,
""
,
errors
.
New
(
"access_token not found in credentials"
)
...
...
backend/internal/service/gemini_token_provider.go
View file @
3ba4d535
...
...
@@ -154,7 +154,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
func
GeminiTokenCacheKey
(
account
*
Account
)
string
{
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
if
projectID
!=
""
{
return
projectID
return
"gemini:"
+
projectID
}
return
"account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
return
"
gemini:
account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
}
backend/internal/service/openai_gateway_service.go
View file @
3ba4d535
...
...
@@ -80,19 +80,20 @@ type OpenAIForwardResult struct {
// OpenAIGatewayService handles OpenAI API gateway operations
type
OpenAIGatewayService
struct
{
accountRepo
AccountRepository
usageLogRepo
UsageLogRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
cache
GatewayCache
cfg
*
config
.
Config
schedulerSnapshot
*
SchedulerSnapshotService
concurrencyService
*
ConcurrencyService
billingService
*
BillingService
rateLimitService
*
RateLimitService
billingCacheService
*
BillingCacheService
httpUpstream
HTTPUpstream
deferredService
*
DeferredService
accountRepo
AccountRepository
usageLogRepo
UsageLogRepository
userRepo
UserRepository
userSubRepo
UserSubscriptionRepository
cache
GatewayCache
cfg
*
config
.
Config
schedulerSnapshot
*
SchedulerSnapshotService
concurrencyService
*
ConcurrencyService
billingService
*
BillingService
rateLimitService
*
RateLimitService
billingCacheService
*
BillingCacheService
httpUpstream
HTTPUpstream
deferredService
*
DeferredService
openAITokenProvider
*
OpenAITokenProvider
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
...
...
@@ -110,21 +111,23 @@ func NewOpenAIGatewayService(
billingCacheService
*
BillingCacheService
,
httpUpstream
HTTPUpstream
,
deferredService
*
DeferredService
,
openAITokenProvider
*
OpenAITokenProvider
,
)
*
OpenAIGatewayService
{
return
&
OpenAIGatewayService
{
accountRepo
:
accountRepo
,
usageLogRepo
:
usageLogRepo
,
userRepo
:
userRepo
,
userSubRepo
:
userSubRepo
,
cache
:
cache
,
cfg
:
cfg
,
schedulerSnapshot
:
schedulerSnapshot
,
concurrencyService
:
concurrencyService
,
billingService
:
billingService
,
rateLimitService
:
rateLimitService
,
billingCacheService
:
billingCacheService
,
httpUpstream
:
httpUpstream
,
deferredService
:
deferredService
,
accountRepo
:
accountRepo
,
usageLogRepo
:
usageLogRepo
,
userRepo
:
userRepo
,
userSubRepo
:
userSubRepo
,
cache
:
cache
,
cfg
:
cfg
,
schedulerSnapshot
:
schedulerSnapshot
,
concurrencyService
:
concurrencyService
,
billingService
:
billingService
,
rateLimitService
:
rateLimitService
,
billingCacheService
:
billingCacheService
,
httpUpstream
:
httpUpstream
,
deferredService
:
deferredService
,
openAITokenProvider
:
openAITokenProvider
,
}
}
...
...
@@ -503,6 +506,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
func
(
s
*
OpenAIGatewayService
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
string
,
error
)
{
switch
account
.
Type
{
case
AccountTypeOAuth
:
// 使用 TokenProvider 获取缓存的 token
if
s
.
openAITokenProvider
!=
nil
{
accessToken
,
err
:=
s
.
openAITokenProvider
.
GetAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
""
,
err
}
return
accessToken
,
"oauth"
,
nil
}
// 降级:TokenProvider 未配置时直接从账号读取
accessToken
:=
account
.
GetOpenAIAccessToken
()
if
accessToken
==
""
{
return
""
,
""
,
errors
.
New
(
"access_token not found in credentials"
)
...
...
backend/internal/service/openai_token_provider.go
0 → 100644
View file @
3ba4d535
package
service
import
(
"context"
"errors"
"log/slog"
"strings"
"time"
)
const
(
openAITokenRefreshSkew
=
3
*
time
.
Minute
openAITokenCacheSkew
=
5
*
time
.
Minute
openAILockWaitTime
=
200
*
time
.
Millisecond
)
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type
OpenAITokenCache
=
GeminiTokenCache
// OpenAITokenProvider 管理 OpenAI OAuth 账户的 access_token
type
OpenAITokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
OpenAITokenCache
openAIOAuthService
*
OpenAIOAuthService
}
func
NewOpenAITokenProvider
(
accountRepo
AccountRepository
,
tokenCache
OpenAITokenCache
,
openAIOAuthService
*
OpenAIOAuthService
,
)
*
OpenAITokenProvider
{
return
&
OpenAITokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
tokenCache
,
openAIOAuthService
:
openAIOAuthService
,
}
}
// GetAccessToken 获取有效的 access_token
func
(
p
*
OpenAITokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
account
.
Platform
!=
PlatformOpenAI
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an openai oauth account"
)
}
cacheKey
:=
OpenAITokenCacheKey
(
account
)
// 1. 先尝试缓存
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"openai_token_cache_hit"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
else
if
err
!=
nil
{
slog
.
Warn
(
"openai_token_cache_get_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
slog
.
Debug
(
"openai_token_cache_miss"
,
"account_id"
,
account
.
ID
)
// 2. 如果即将过期则刷新
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
openAITokenRefreshSkew
refreshFailed
:=
false
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
lockErr
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
lockErr
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
// 从数据库获取最新账户信息
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
openAITokenRefreshSkew
{
if
p
.
openAIOAuthService
==
nil
{
slog
.
Warn
(
"openai_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
// 无法刷新,标记失败
}
else
{
tokenInfo
,
err
:=
p
.
openAIOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog
.
Warn
(
"openai_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
// 刷新失败,标记以使用短 TTL
}
else
{
newCredentials
:=
p
.
openAIOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"openai_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
if
lockErr
!=
nil
{
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
slog
.
Warn
(
"openai_token_lock_failed_degraded_refresh"
,
"account_id"
,
account
.
ID
,
"error"
,
lockErr
)
// 检查 ctx 是否已取消
if
ctx
.
Err
()
!=
nil
{
return
""
,
ctx
.
Err
()
}
// 从数据库获取最新账户信息
if
p
.
accountRepo
!=
nil
{
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
openAITokenRefreshSkew
{
if
p
.
openAIOAuthService
==
nil
{
slog
.
Warn
(
"openai_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
}
else
{
tokenInfo
,
err
:=
p
.
openAIOAuthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
slog
.
Warn
(
"openai_token_refresh_failed_degraded"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
}
else
{
newCredentials
:=
p
.
openAIOAuthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"openai_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
{
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
time
.
Sleep
(
openAILockWaitTime
)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"openai_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
}
}
accessToken
:=
account
.
GetOpenAIAccessToken
()
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. 存入缓存
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
if
refreshFailed
{
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl
=
time
.
Minute
slog
.
Debug
(
"openai_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed"
)
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
openAITokenCacheSkew
:
ttl
=
until
-
openAITokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
if
err
:=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
);
err
!=
nil
{
slog
.
Warn
(
"openai_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
return
accessToken
,
nil
}
backend/internal/service/openai_token_provider_test.go
0 → 100644
View file @
3ba4d535
//go:build unit
package
service
import
(
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// openAITokenCacheStub implements OpenAITokenCache for testing
type
openAITokenCacheStub
struct
{
mu
sync
.
Mutex
tokens
map
[
string
]
string
getErr
error
setErr
error
deleteErr
error
lockAcquired
bool
lockErr
error
releaseLockErr
error
getCalled
int32
setCalled
int32
lockCalled
int32
unlockCalled
int32
simulateLockRace
bool
}
func
newOpenAITokenCacheStub
()
*
openAITokenCacheStub
{
return
&
openAITokenCacheStub
{
tokens
:
make
(
map
[
string
]
string
),
lockAcquired
:
true
,
}
}
func
(
s
*
openAITokenCacheStub
)
GetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
(
string
,
error
)
{
atomic
.
AddInt32
(
&
s
.
getCalled
,
1
)
if
s
.
getErr
!=
nil
{
return
""
,
s
.
getErr
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
return
s
.
tokens
[
cacheKey
],
nil
}
func
(
s
*
openAITokenCacheStub
)
SetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
,
token
string
,
ttl
time
.
Duration
)
error
{
atomic
.
AddInt32
(
&
s
.
setCalled
,
1
)
if
s
.
setErr
!=
nil
{
return
s
.
setErr
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
tokens
[
cacheKey
]
=
token
return
nil
}
func
(
s
*
openAITokenCacheStub
)
DeleteAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
if
s
.
deleteErr
!=
nil
{
return
s
.
deleteErr
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
delete
(
s
.
tokens
,
cacheKey
)
return
nil
}
func
(
s
*
openAITokenCacheStub
)
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
{
atomic
.
AddInt32
(
&
s
.
lockCalled
,
1
)
if
s
.
lockErr
!=
nil
{
return
false
,
s
.
lockErr
}
if
s
.
simulateLockRace
{
return
false
,
nil
}
return
s
.
lockAcquired
,
nil
}
func
(
s
*
openAITokenCacheStub
)
ReleaseRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
atomic
.
AddInt32
(
&
s
.
unlockCalled
,
1
)
return
s
.
releaseLockErr
}
// openAIAccountRepoStub is a minimal stub implementing only the methods used by OpenAITokenProvider
type
openAIAccountRepoStub
struct
{
account
*
Account
getErr
error
updateErr
error
getCalled
int32
updateCalled
int32
}
func
(
r
*
openAIAccountRepoStub
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
atomic
.
AddInt32
(
&
r
.
getCalled
,
1
)
if
r
.
getErr
!=
nil
{
return
nil
,
r
.
getErr
}
return
r
.
account
,
nil
}
func
(
r
*
openAIAccountRepoStub
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
atomic
.
AddInt32
(
&
r
.
updateCalled
,
1
)
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
r
.
account
=
account
return
nil
}
// openAIOAuthServiceStub implements OpenAIOAuthService methods for testing
type
openAIOAuthServiceStub
struct
{
tokenInfo
*
OpenAITokenInfo
refreshErr
error
refreshCalled
int32
}
func
(
s
*
openAIOAuthServiceStub
)
RefreshAccountToken
(
ctx
context
.
Context
,
account
*
Account
)
(
*
OpenAITokenInfo
,
error
)
{
atomic
.
AddInt32
(
&
s
.
refreshCalled
,
1
)
if
s
.
refreshErr
!=
nil
{
return
nil
,
s
.
refreshErr
}
return
s
.
tokenInfo
,
nil
}
func
(
s
*
openAIOAuthServiceStub
)
BuildAccountCredentials
(
info
*
OpenAITokenInfo
)
map
[
string
]
any
{
now
:=
time
.
Now
()
return
map
[
string
]
any
{
"access_token"
:
info
.
AccessToken
,
"refresh_token"
:
info
.
RefreshToken
,
"expires_at"
:
now
.
Add
(
time
.
Duration
(
info
.
ExpiresIn
)
*
time
.
Second
)
.
Format
(
time
.
RFC3339
),
}
}
func
TestOpenAITokenProvider_CacheHit
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
account
:=
&
Account
{
ID
:
100
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"db-token"
,
},
}
cacheKey
:=
OpenAITokenCacheKey
(
account
)
cache
.
tokens
[
cacheKey
]
=
"cached-token"
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"cached-token"
,
token
)
require
.
Equal
(
t
,
int32
(
1
),
atomic
.
LoadInt32
(
&
cache
.
getCalled
))
require
.
Equal
(
t
,
int32
(
0
),
atomic
.
LoadInt32
(
&
cache
.
setCalled
))
}
func
TestOpenAITokenProvider_CacheMiss_FromCredentials
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
// Token expires in far future, no refresh needed
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
101
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"credential-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"credential-token"
,
token
)
// Should have stored in cache
cacheKey
:=
OpenAITokenCacheKey
(
account
)
require
.
Equal
(
t
,
"credential-token"
,
cache
.
tokens
[
cacheKey
])
}
func
TestOpenAITokenProvider_TokenRefresh
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
accountRepo
:=
&
openAIAccountRepoStub
{}
oauthService
:=
&
openAIOAuthServiceStub
{
tokenInfo
:
&
OpenAITokenInfo
{
AccessToken
:
"refreshed-token"
,
RefreshToken
:
"new-refresh-token"
,
ExpiresIn
:
3600
,
},
}
// Token expires soon (within refresh skew)
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
102
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"refresh_token"
:
"old-refresh-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
// We need to directly test with the stub - create a custom provider
customProvider
:=
&
testOpenAITokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
token
,
err
:=
customProvider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"refreshed-token"
,
token
)
require
.
Equal
(
t
,
int32
(
1
),
atomic
.
LoadInt32
(
&
oauthService
.
refreshCalled
))
}
// testOpenAITokenProvider is a test version that uses the stub OAuth service
type
testOpenAITokenProvider
struct
{
accountRepo
*
openAIAccountRepoStub
tokenCache
*
openAITokenCacheStub
oauthService
*
openAIOAuthServiceStub
}
func
(
p
*
testOpenAITokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
account
.
Platform
!=
PlatformOpenAI
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an openai oauth account"
)
}
cacheKey
:=
OpenAITokenCacheKey
(
account
)
// 1. Check cache
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
token
!=
""
{
return
token
,
nil
}
}
// 2. Check if refresh needed
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
openAITokenRefreshSkew
refreshFailed
:=
false
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
err
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
err
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// Check cache again after acquiring lock
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
token
!=
""
{
return
token
,
nil
}
// Get fresh account from DB
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
openAITokenRefreshSkew
{
if
p
.
oauthService
==
nil
{
refreshFailed
=
true
// 无法刷新,标记失败
}
else
{
tokenInfo
,
err
:=
p
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
refreshFailed
=
true
// 刷新失败,标记以使用短 TTL
}
else
{
newCredentials
:=
p
.
oauthService
.
BuildAccountCredentials
(
tokenInfo
)
for
k
,
v
:=
range
account
.
Credentials
{
if
_
,
exists
:=
newCredentials
[
k
];
!
exists
{
newCredentials
[
k
]
=
v
}
}
account
.
Credentials
=
newCredentials
_
=
p
.
accountRepo
.
Update
(
ctx
,
account
)
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
if
p
.
tokenCache
.
simulateLockRace
{
// Wait and retry cache
time
.
Sleep
(
10
*
time
.
Millisecond
)
// Short wait for test
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
token
!=
""
{
return
token
,
nil
}
}
}
accessToken
:=
account
.
GetOpenAIAccessToken
()
if
accessToken
==
""
{
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. Store in cache
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
if
refreshFailed
{
ttl
=
time
.
Minute
// 刷新失败时使用短 TTL
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
if
until
>
openAITokenCacheSkew
{
ttl
=
until
-
openAITokenCacheSkew
}
else
if
until
>
0
{
ttl
=
until
}
else
{
ttl
=
time
.
Minute
}
}
_
=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
)
}
return
accessToken
,
nil
}
func
TestOpenAITokenProvider_LockRaceCondition
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
cache
.
simulateLockRace
=
true
accountRepo
:=
&
openAIAccountRepoStub
{}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
103
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"race-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
// Simulate another worker already refreshed and cached
cacheKey
:=
OpenAITokenCacheKey
(
account
)
go
func
()
{
time
.
Sleep
(
5
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"winner-token"
cache
.
mu
.
Unlock
()
}()
provider
:=
&
testOpenAITokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// Should get the token set by the "winner" or the original
require
.
NotEmpty
(
t
,
token
)
}
func
TestOpenAITokenProvider_NilAccount
(
t
*
testing
.
T
)
{
provider
:=
NewOpenAITokenProvider
(
nil
,
nil
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"account is nil"
)
require
.
Empty
(
t
,
token
)
}
func
TestOpenAITokenProvider_WrongPlatform
(
t
*
testing
.
T
)
{
provider
:=
NewOpenAITokenProvider
(
nil
,
nil
,
nil
)
account
:=
&
Account
{
ID
:
104
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an openai oauth account"
)
require
.
Empty
(
t
,
token
)
}
func
TestOpenAITokenProvider_WrongAccountType
(
t
*
testing
.
T
)
{
provider
:=
NewOpenAITokenProvider
(
nil
,
nil
,
nil
)
account
:=
&
Account
{
ID
:
105
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an openai oauth account"
)
require
.
Empty
(
t
,
token
)
}
func
TestOpenAITokenProvider_NilCache
(
t
*
testing
.
T
)
{
// Token doesn't need refresh
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
106
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"nocache-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewOpenAITokenProvider
(
nil
,
nil
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"nocache-token"
,
token
)
}
func
TestOpenAITokenProvider_CacheGetError
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
cache
.
getErr
=
errors
.
New
(
"redis connection failed"
)
// Token doesn't need refresh
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
107
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"fallback-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
// Should gracefully degrade and return from credentials
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"fallback-token"
,
token
)
}
func
TestOpenAITokenProvider_CacheSetError
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
cache
.
setErr
=
errors
.
New
(
"redis write failed"
)
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
108
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"still-works-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
// Should still work even if cache set fails
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"still-works-token"
,
token
)
}
func
TestOpenAITokenProvider_MissingAccessToken
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
109
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"expires_at"
:
expiresAt
,
// missing access_token
},
}
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"access_token not found"
)
require
.
Empty
(
t
,
token
)
}
func
TestOpenAITokenProvider_RefreshError
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
accountRepo
:=
&
openAIAccountRepoStub
{}
oauthService
:=
&
openAIOAuthServiceStub
{
refreshErr
:
errors
.
New
(
"oauth refresh failed"
),
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
110
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"refresh_token"
:
"old-refresh-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
provider
:=
&
testOpenAITokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
// Now with fallback behavior, should return existing token even if refresh fails
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"old-token"
,
token
)
// Fallback to existing token
}
func
TestOpenAITokenProvider_OAuthServiceNotConfigured
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
accountRepo
:=
&
openAIAccountRepoStub
{}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
111
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
provider
:=
&
testOpenAITokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
nil
,
// not configured
}
// Now with fallback behavior, should return existing token even if oauth service not configured
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"old-token"
,
token
)
// Fallback to existing token
}
func
TestOpenAITokenProvider_TTLCalculation
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
expiresIn
time
.
Duration
}{
{
name
:
"far_future_expiry"
,
expiresIn
:
1
*
time
.
Hour
,
},
{
name
:
"medium_expiry"
,
expiresIn
:
10
*
time
.
Minute
,
},
{
name
:
"near_expiry"
,
expiresIn
:
6
*
time
.
Minute
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
expiresAt
:=
time
.
Now
()
.
Add
(
tt
.
expiresIn
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
200
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"test-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
_
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// Verify token was cached
cacheKey
:=
OpenAITokenCacheKey
(
account
)
require
.
Equal
(
t
,
"test-token"
,
cache
.
tokens
[
cacheKey
])
})
}
}
func
TestOpenAITokenProvider_DoubleCheckAfterLock
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
accountRepo
:=
&
openAIAccountRepoStub
{}
oauthService
:=
&
openAIOAuthServiceStub
{
tokenInfo
:
&
OpenAITokenInfo
{
AccessToken
:
"refreshed-token"
,
RefreshToken
:
"new-refresh"
,
ExpiresIn
:
3600
,
},
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
112
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
cacheKey
:=
OpenAITokenCacheKey
(
account
)
// Simulate: first GetAccessToken returns empty, but after lock acquired, cache has token
originalGet
:=
int32
(
0
)
cache
.
tokens
[
cacheKey
]
=
""
// Empty initially
provider
:=
&
testOpenAITokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
// In a goroutine, set the cached token after a small delay (simulating race)
go
func
()
{
time
.
Sleep
(
5
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"cached-by-other"
cache
.
mu
.
Unlock
()
}()
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// Should get either the refreshed token or the cached one
require
.
NotEmpty
(
t
,
token
)
_
=
originalGet
// Suppress unused warning
}
// Tests for real provider - to increase coverage
func
TestOpenAITokenProvider_Real_LockFailedWait
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
cache
.
lockAcquired
=
false
// Lock acquisition fails
// Token expires soon (within refresh skew) to trigger lock attempt
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
200
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"fallback-token"
,
"expires_at"
:
expiresAt
,
},
}
// Set token in cache after lock wait period (simulate other worker refreshing)
cacheKey
:=
OpenAITokenCacheKey
(
account
)
go
func
()
{
time
.
Sleep
(
100
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"refreshed-by-other"
cache
.
mu
.
Unlock
()
}()
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// Should get either the fallback token or the refreshed one
require
.
NotEmpty
(
t
,
token
)
}
func
TestOpenAITokenProvider_Real_CacheHitAfterWait
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
cache
.
lockAcquired
=
false
// Lock acquisition fails
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
201
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"original-token"
,
"expires_at"
:
expiresAt
,
},
}
cacheKey
:=
OpenAITokenCacheKey
(
account
)
// Set token in cache immediately after wait starts
go
func
()
{
time
.
Sleep
(
50
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"winner-token"
cache
.
mu
.
Unlock
()
}()
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
token
)
}
func
TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
cache
.
lockAcquired
=
false
// Prevent entering refresh logic
// Token with nil expires_at (no expiry set) - should use credentials
account
:=
&
Account
{
ID
:
202
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"no-expiry-token"
,
},
}
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
// Without OAuth service, refresh will fail but token should be returned from credentials
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"no-expiry-token"
,
token
)
}
func
TestOpenAITokenProvider_Real_WhitespaceToken
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
cacheKey
:=
"openai:account:203"
cache
.
tokens
[
cacheKey
]
=
" "
// Whitespace only - should be treated as empty
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
203
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"real-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"real-token"
,
token
)
// Should fall back to credentials
}
func
TestOpenAITokenProvider_Real_LockError
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
cache
.
lockErr
=
errors
.
New
(
"redis lock failed"
)
// Token expires soon (within refresh skew)
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
204
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"fallback-on-lock-error"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"fallback-on-lock-error"
,
token
)
}
func
TestOpenAITokenProvider_Real_WhitespaceCredentialToken
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
205
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
" "
,
// Whitespace only
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"access_token not found"
)
require
.
Empty
(
t
,
token
)
}
func
TestOpenAITokenProvider_Real_NilCredentials
(
t
*
testing
.
T
)
{
cache
:=
newOpenAITokenCacheStub
()
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
206
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"expires_at"
:
expiresAt
,
// No access_token
},
}
provider
:=
NewOpenAITokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"access_token not found"
)
require
.
Empty
(
t
,
token
)
}
backend/internal/service/ratelimit_service.go
View file @
3ba4d535
...
...
@@ -85,13 +85,24 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
switch
statusCode
{
case
401
:
if
account
.
Type
==
AccountTypeOAuth
&&
(
account
.
Platform
==
PlatformAntigravity
||
account
.
Platform
==
PlatformGemini
)
{
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
if
account
.
Type
==
AccountTypeOAuth
{
// 1. 失效缓存
if
s
.
tokenCacheInvalidator
!=
nil
{
if
err
:=
s
.
tokenCacheInvalidator
.
InvalidateToken
(
ctx
,
account
);
err
!=
nil
{
slog
.
Warn
(
"oauth_401_invalidate_cache_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
// 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
if
account
.
Credentials
==
nil
{
account
.
Credentials
=
make
(
map
[
string
]
any
)
}
account
.
Credentials
[
"expires_at"
]
=
time
.
Now
()
.
Format
(
time
.
RFC3339
)
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
slog
.
Warn
(
"oauth_401_force_refresh_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
else
{
slog
.
Info
(
"oauth_401_force_refresh_set"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
)
}
}
msg
:=
"Authentication failed (401): invalid or expired credentials"
if
upstreamMsg
!=
""
{
...
...
backend/internal/service/token_cache_invalidator.go
View file @
3ba4d535
...
...
@@ -7,29 +7,35 @@ type TokenCacheInvalidator interface {
}
type
CompositeTokenCacheInvalidator
struct
{
geminiC
ache
GeminiTokenCache
c
ache
GeminiTokenCache
// 统一使用一个缓存接口,通过缓存键前缀区分平台
}
func
NewCompositeTokenCacheInvalidator
(
geminiC
ache
GeminiTokenCache
)
*
CompositeTokenCacheInvalidator
{
func
NewCompositeTokenCacheInvalidator
(
c
ache
GeminiTokenCache
)
*
CompositeTokenCacheInvalidator
{
return
&
CompositeTokenCacheInvalidator
{
geminiCache
:
geminiC
ache
,
cache
:
c
ache
,
}
}
func
(
c
*
CompositeTokenCacheInvalidator
)
InvalidateToken
(
ctx
context
.
Context
,
account
*
Account
)
error
{
if
c
==
nil
||
c
.
geminiC
ache
==
nil
||
account
==
nil
{
if
c
==
nil
||
c
.
c
ache
==
nil
||
account
==
nil
{
return
nil
}
if
account
.
Type
!=
AccountTypeOAuth
{
return
nil
}
var
cacheKey
string
switch
account
.
Platform
{
case
PlatformGemini
:
return
c
.
geminiCache
.
DeleteAccessToken
(
ctx
,
GeminiTokenCacheKey
(
account
)
)
cacheKey
=
GeminiTokenCacheKey
(
account
)
case
PlatformAntigravity
:
return
c
.
geminiCache
.
DeleteAccessToken
(
ctx
,
AntigravityTokenCacheKey
(
account
))
cacheKey
=
AntigravityTokenCacheKey
(
account
)
case
PlatformOpenAI
:
cacheKey
=
OpenAITokenCacheKey
(
account
)
case
PlatformAnthropic
:
cacheKey
=
ClaudeTokenCacheKey
(
account
)
default
:
return
nil
}
return
c
.
cache
.
DeleteAccessToken
(
ctx
,
cacheKey
)
}
backend/internal/service/token_cache_invalidator_test.go
View file @
3ba4d535
...
...
@@ -4,6 +4,7 @@ package service
import
(
"context"
"errors"
"testing"
"time"
...
...
@@ -50,7 +51,7 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"project-x"
},
cache
.
deletedKeys
)
require
.
Equal
(
t
,
[]
string
{
"
gemini:
project-x"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_Antigravity
(
t
*
testing
.
T
)
{
...
...
@@ -70,13 +71,99 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
require
.
Equal
(
t
,
[]
string
{
"ag:ag-project"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_OpenAI
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
ID
:
500
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"openai-token"
,
},
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"openai:account:500"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_Claude
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
ID
:
600
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"claude-token"
,
},
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
string
{
"claude:account:600"
},
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_SkipNonOAuth
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
tests
:=
[]
struct
{
name
string
account
*
Account
}{
{
name
:
"gemini_api_key"
,
account
:
&
Account
{
ID
:
1
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeAPIKey
,
},
},
{
name
:
"openai_api_key"
,
account
:
&
Account
{
ID
:
2
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
},
},
{
name
:
"claude_api_key"
,
account
:
&
Account
{
ID
:
3
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
},
},
{
name
:
"claude_setup_token"
,
account
:
&
Account
{
ID
:
4
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeSetupToken
,
},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
cache
.
deletedKeys
=
nil
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
tt
.
account
)
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
cache
.
deletedKeys
)
})
}
}
func
TestCompositeTokenCacheInvalidator_SkipUnsupportedPlatform
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
account
:=
&
Account
{
ID
:
1
,
Platform
:
P
latform
Gemini
,
Type
:
AccountType
APIKey
,
ID
:
1
00
,
Platform
:
"unknown-p
latform
"
,
Type
:
AccountType
OAuth
,
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
...
...
@@ -95,3 +182,87 @@ func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) {
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
}
func
TestCompositeTokenCacheInvalidator_NilAccount
(
t
*
testing
.
T
)
{
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
nil
)
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
cache
.
deletedKeys
)
}
func
TestCompositeTokenCacheInvalidator_NilInvalidator
(
t
*
testing
.
T
)
{
var
invalidator
*
CompositeTokenCacheInvalidator
account
:=
&
Account
{
ID
:
5
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
}
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
}
func
TestCompositeTokenCacheInvalidator_DeleteError
(
t
*
testing
.
T
)
{
expectedErr
:=
errors
.
New
(
"redis connection failed"
)
cache
:=
&
geminiTokenCacheStub
{
deleteErr
:
expectedErr
}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
tests
:=
[]
struct
{
name
string
account
*
Account
}{
{
name
:
"openai_delete_error"
,
account
:
&
Account
{
ID
:
700
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
},
},
{
name
:
"claude_delete_error"
,
account
:
&
Account
{
ID
:
800
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
tt
.
account
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
expectedErr
,
err
)
})
}
}
func
TestCompositeTokenCacheInvalidator_AllPlatformsIntegration
(
t
*
testing
.
T
)
{
// 测试所有平台的缓存键生成和删除
cache
:=
&
geminiTokenCacheStub
{}
invalidator
:=
NewCompositeTokenCacheInvalidator
(
cache
)
accounts
:=
[]
*
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
"gemini-proj"
}},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"project_id"
:
"ag-proj"
}},
{
ID
:
3
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
},
{
ID
:
4
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
},
}
expectedKeys
:=
[]
string
{
"gemini:gemini-proj"
,
"ag:ag-proj"
,
"openai:account:3"
,
"claude:account:4"
,
}
for
_
,
acc
:=
range
accounts
{
err
:=
invalidator
.
InvalidateToken
(
context
.
Background
(),
acc
)
require
.
NoError
(
t
,
err
)
}
require
.
Equal
(
t
,
expectedKeys
,
cache
.
deletedKeys
)
}
backend/internal/service/token_cache_key.go
0 → 100644
View file @
3ba4d535
package
service
import
"strconv"
// OpenAITokenCacheKey 生成 OpenAI OAuth 账号的缓存键
// 格式: "openai:account:{account_id}"
func
OpenAITokenCacheKey
(
account
*
Account
)
string
{
return
"openai:account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
}
// ClaudeTokenCacheKey 生成 Claude (Anthropic) OAuth 账号的缓存键
// 格式: "claude:account:{account_id}"
func
ClaudeTokenCacheKey
(
account
*
Account
)
string
{
return
"claude:account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
}
backend/internal/service/token_cache_key_test.go
View file @
3ba4d535
...
...
@@ -22,7 +22,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
"project_id"
:
"my-project-123"
,
},
},
expected
:
"my-project-123"
,
expected
:
"
gemini:
my-project-123"
,
},
{
name
:
"project_id_with_whitespace"
,
...
...
@@ -32,7 +32,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
"project_id"
:
" project-with-spaces "
,
},
},
expected
:
"project-with-spaces"
,
expected
:
"
gemini:
project-with-spaces"
,
},
{
name
:
"empty_project_id_fallback_to_account_id"
,
...
...
@@ -42,7 +42,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
"project_id"
:
""
,
},
},
expected
:
"account:102"
,
expected
:
"
gemini:
account:102"
,
},
{
name
:
"whitespace_only_project_id_fallback_to_account_id"
,
...
...
@@ -52,7 +52,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
"project_id"
:
" "
,
},
},
expected
:
"account:103"
,
expected
:
"
gemini:
account:103"
,
},
{
name
:
"no_project_id_key_fallback_to_account_id"
,
...
...
@@ -60,7 +60,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
ID
:
104
,
Credentials
:
map
[
string
]
any
{},
},
expected
:
"account:104"
,
expected
:
"
gemini:
account:104"
,
},
{
name
:
"nil_credentials_fallback_to_account_id"
,
...
...
@@ -68,7 +68,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
ID
:
105
,
Credentials
:
nil
,
},
expected
:
"account:105"
,
expected
:
"
gemini:
account:105"
,
},
}
...
...
@@ -151,3 +151,109 @@ func TestAntigravityTokenCacheKey(t *testing.T) {
})
}
}
func
TestOpenAITokenCacheKey
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
expected
string
}{
{
name
:
"basic_account"
,
account
:
&
Account
{
ID
:
300
,
},
expected
:
"openai:account:300"
,
},
{
name
:
"account_with_credentials"
,
account
:
&
Account
{
ID
:
301
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"test-token"
,
},
},
expected
:
"openai:account:301"
,
},
{
name
:
"account_id_zero"
,
account
:
&
Account
{
ID
:
0
,
},
expected
:
"openai:account:0"
,
},
{
name
:
"large_account_id"
,
account
:
&
Account
{
ID
:
9999999999
,
},
expected
:
"openai:account:9999999999"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
OpenAITokenCacheKey
(
tt
.
account
)
require
.
Equal
(
t
,
tt
.
expected
,
result
)
})
}
}
func
TestClaudeTokenCacheKey
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
expected
string
}{
{
name
:
"basic_account"
,
account
:
&
Account
{
ID
:
400
,
},
expected
:
"claude:account:400"
,
},
{
name
:
"account_with_credentials"
,
account
:
&
Account
{
ID
:
401
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"claude-token"
,
},
},
expected
:
"claude:account:401"
,
},
{
name
:
"account_id_zero"
,
account
:
&
Account
{
ID
:
0
,
},
expected
:
"claude:account:0"
,
},
{
name
:
"large_account_id"
,
account
:
&
Account
{
ID
:
9999999999
,
},
expected
:
"claude:account:9999999999"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
ClaudeTokenCacheKey
(
tt
.
account
)
require
.
Equal
(
t
,
tt
.
expected
,
result
)
})
}
}
func
TestCacheKeyUniqueness
(
t
*
testing
.
T
)
{
// 确保不同平台的缓存键不会冲突
account
:=
&
Account
{
ID
:
123
}
openaiKey
:=
OpenAITokenCacheKey
(
account
)
claudeKey
:=
ClaudeTokenCacheKey
(
account
)
require
.
NotEqual
(
t
,
openaiKey
,
claudeKey
,
"OpenAI and Claude cache keys should be different"
)
require
.
Contains
(
t
,
openaiKey
,
"openai:"
)
require
.
Contains
(
t
,
claudeKey
,
"claude:"
)
}
backend/internal/service/token_refresh_service.go
View file @
3ba4d535
...
...
@@ -172,8 +172,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
err
)
}
if
s
.
cacheInvalidator
!=
nil
&&
account
.
Type
==
AccountTypeOAuth
&&
(
account
.
Pl
at
f
or
m
==
PlatformGemini
||
account
.
Platform
==
PlatformAntigravity
)
{
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
if
s
.
cacheInvalid
ator
!=
nil
&&
account
.
Type
==
AccountTypeOAuth
{
if
err
:=
s
.
cacheInvalidator
.
InvalidateToken
(
ctx
,
account
);
err
!=
nil
{
log
.
Printf
(
"[TokenRefresh] Failed to invalidate token cache for account %d: %v"
,
account
.
ID
,
err
)
}
else
{
...
...
backend/internal/service/token_refresh_service_test.go
View file @
3ba4d535
...
...
@@ -197,7 +197,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
// 非 OAuth 不触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试
其他平台的
OAuth
账号不
触发缓存失效
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试
所有
OAuth
平台都
触发缓存失效
func
TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth
(
t
*
testing
.
T
)
{
repo
:=
&
tokenRefreshAccountRepo
{}
invalidator
:=
&
tokenCacheInvalidatorStub
{}
...
...
@@ -210,7 +210,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
service
:=
NewTokenRefreshService
(
repo
,
nil
,
nil
,
nil
,
nil
,
invalidator
,
cfg
)
account
:=
&
Account
{
ID
:
10
,
Platform
:
PlatformOpenAI
,
//
其他平台
Platform
:
PlatformOpenAI
,
//
OpenAI OAuth 账户
Type
:
AccountTypeOAuth
,
}
refresher
:=
&
tokenRefresherStub
{
...
...
@@ -222,7 +222,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
err
:=
service
.
refreshWithRetry
(
context
.
Background
(),
account
,
refresher
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
repo
.
updateCalls
)
require
.
Equal
(
t
,
0
,
invalidator
.
calls
)
//
其他平台不
触发缓存失效
require
.
Equal
(
t
,
1
,
invalidator
.
calls
)
//
所有 OAuth 账户刷新后
触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
...
...
backend/internal/service/wire.go
View file @
3ba4d535
...
...
@@ -214,10 +214,13 @@ var ProviderSet = wire.NewSet(
NewGeminiOAuthService
,
NewGeminiQuotaService
,
NewCompositeTokenCacheInvalidator
,
wire
.
Bind
(
new
(
TokenCacheInvalidator
),
new
(
*
CompositeTokenCacheInvalidator
)),
NewAntigravityOAuthService
,
NewGeminiTokenProvider
,
NewGeminiMessagesCompatService
,
NewAntigravityTokenProvider
,
NewOpenAITokenProvider
,
NewClaudeTokenProvider
,
NewAntigravityGatewayService
,
ProvideRateLimitService
,
NewAccountUsageService
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment