"git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "39ca192c41ea8e040aee99b015b7761bec6964d1"
Unverified Commit b36f3db9 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #300 from mt21625457/main

feat(网关): 引入 OpenAI/Claude OAuth token 缓存
parents 5f890e85 f862ddc9
...@@ -129,3 +129,4 @@ deploy/docker-compose.override.yml ...@@ -129,3 +129,4 @@ deploy/docker-compose.override.yml
.gocache/ .gocache/
vite.config.js vite.config.js
docs/* docs/*
.serena/
\ No newline at end of file
...@@ -100,8 +100,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -100,8 +100,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tempUnschedCache := repository.NewTempUnschedCache(redisClient) tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient) geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
tokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache) compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, tokenCacheInvalidator) rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
claudeUsageFetcher := repository.NewClaudeUsageFetcher() claudeUsageFetcher := repository.NewClaudeUsageFetcher()
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache() usageCache := service.NewUsageCache()
...@@ -136,8 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -136,8 +136,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityCache := repository.NewIdentityCache(redisClient) identityCache := repository.NewIdentityCache(redisClient)
identityService := service.NewIdentityService(identityCache) identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService) settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
...@@ -168,7 +170,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -168,7 +170,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, tokenCacheInvalidator, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository) accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{ application := &Application{
......
...@@ -11,8 +11,8 @@ import ( ...@@ -11,8 +11,8 @@ import (
) )
const ( const (
geminiTokenKeyPrefix = "gemini:token:" oauthTokenKeyPrefix = "oauth:token:"
geminiRefreshLockKeyPrefix = "gemini:refresh_lock:" oauthRefreshLockKeyPrefix = "oauth:refresh_lock:"
) )
type geminiTokenCache struct { type geminiTokenCache struct {
...@@ -24,26 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache { ...@@ -24,26 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
} }
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) { func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey) key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Get(ctx, key).Result() return c.rdb.Get(ctx, key).Result()
} }
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error { func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey) key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Set(ctx, key, token, ttl).Err() return c.rdb.Set(ctx, key, token, ttl).Err()
} }
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error { func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey) key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) { func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey) key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
return c.rdb.SetNX(ctx, key, 1, ttl).Result() return c.rdb.SetNX(ctx, key, 1, ttl).Result()
} }
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error { func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey) key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
package service
import (
"context"
"errors"
"log/slog"
"strconv"
"strings"
"time"
)
const (
claudeTokenRefreshSkew = 3 * time.Minute
claudeTokenCacheSkew = 5 * time.Minute
claudeLockWaitTime = 200 * time.Millisecond
)
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type ClaudeTokenCache = GeminiTokenCache
// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
type ClaudeTokenProvider struct {
accountRepo AccountRepository
tokenCache ClaudeTokenCache
oauthService *OAuthService
}
func NewClaudeTokenProvider(
accountRepo AccountRepository,
tokenCache ClaudeTokenCache,
oauthService *OAuthService,
) *ClaudeTokenProvider {
return &ClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
oauthService: oauthService,
}
}
// GetAccessToken 获取有效的 access_token
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
return "", errors.New("not an anthropic oauth account")
}
cacheKey := ClaudeTokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
return token, nil
} else if err != nil {
slog.Warn("claude_token_cache_get_failed", "account_id", account.ID, "error", err)
}
}
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
// 2. 如果即将过期则刷新
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
// 从数据库获取最新账户信息
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
if p.oauthService == nil {
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true // 无法刷新,标记失败
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
refreshFailed = true // 刷新失败,标记以使用短 TTL
} else {
// 构建新 credentials,保留原有字段
newCredentials := make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else if lockErr != nil {
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
// 检查 ctx 是否已取消
if ctx.Err() != nil {
return "", ctx.Err()
}
// 从数据库获取最新账户信息
if p.accountRepo != nil {
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
}
expiresAt = account.GetCredentialAsTime("expires_at")
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
if p.oauthService == nil {
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
refreshFailed = true
} else {
// 构建新 credentials,保留原有字段
newCredentials := make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else {
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
time.Sleep(claudeLockWaitTime)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
}
}
accessToken := account.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > claudeTokenCacheSkew:
ttl = until - claudeTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}
return accessToken, nil
}
This diff is collapsed.
...@@ -159,6 +159,7 @@ type GatewayService struct { ...@@ -159,6 +159,7 @@ type GatewayService struct {
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
deferredService *DeferredService deferredService *DeferredService
concurrencyService *ConcurrencyService concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
} }
// NewGatewayService creates a new GatewayService // NewGatewayService creates a new GatewayService
...@@ -178,6 +179,7 @@ func NewGatewayService( ...@@ -178,6 +179,7 @@ func NewGatewayService(
identityService *IdentityService, identityService *IdentityService,
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
deferredService *DeferredService, deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
) *GatewayService { ) *GatewayService {
return &GatewayService{ return &GatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -195,6 +197,7 @@ func NewGatewayService( ...@@ -195,6 +197,7 @@ func NewGatewayService(
identityService: identityService, identityService: identityService,
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
deferredService: deferredService, deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider,
} }
} }
...@@ -1079,6 +1082,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) ( ...@@ -1079,6 +1082,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
} }
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) { func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
// 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token
if account.Platform == PlatformAnthropic && account.Type == AccountTypeOAuth && s.claudeTokenProvider != nil {
accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
if err != nil {
return "", "", err
}
return accessToken, "oauth", nil
}
// 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
if accessToken == "" { if accessToken == "" {
return "", "", errors.New("access_token not found in credentials") return "", "", errors.New("access_token not found in credentials")
......
...@@ -154,7 +154,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou ...@@ -154,7 +154,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
func GeminiTokenCacheKey(account *Account) string { func GeminiTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id")) projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" { if projectID != "" {
return projectID return "gemini:" + projectID
} }
return "account:" + strconv.FormatInt(account.ID, 10) return "gemini:account:" + strconv.FormatInt(account.ID, 10)
} }
...@@ -93,6 +93,7 @@ type OpenAIGatewayService struct { ...@@ -93,6 +93,7 @@ type OpenAIGatewayService struct {
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
deferredService *DeferredService deferredService *DeferredService
openAITokenProvider *OpenAITokenProvider
} }
// NewOpenAIGatewayService creates a new OpenAIGatewayService // NewOpenAIGatewayService creates a new OpenAIGatewayService
...@@ -110,6 +111,7 @@ func NewOpenAIGatewayService( ...@@ -110,6 +111,7 @@ func NewOpenAIGatewayService(
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
deferredService *DeferredService, deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider,
) *OpenAIGatewayService { ) *OpenAIGatewayService {
return &OpenAIGatewayService{ return &OpenAIGatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -125,6 +127,7 @@ func NewOpenAIGatewayService( ...@@ -125,6 +127,7 @@ func NewOpenAIGatewayService(
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
deferredService: deferredService, deferredService: deferredService,
openAITokenProvider: openAITokenProvider,
} }
} }
...@@ -503,6 +506,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig ...@@ -503,6 +506,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type { switch account.Type {
case AccountTypeOAuth: case AccountTypeOAuth:
// 使用 TokenProvider 获取缓存的 token
if s.openAITokenProvider != nil {
accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account)
if err != nil {
return "", "", err
}
return accessToken, "oauth", nil
}
// 降级:TokenProvider 未配置时直接从账号读取
accessToken := account.GetOpenAIAccessToken() accessToken := account.GetOpenAIAccessToken()
if accessToken == "" { if accessToken == "" {
return "", "", errors.New("access_token not found in credentials") return "", "", errors.New("access_token not found in credentials")
......
package service
import (
"context"
"errors"
"log/slog"
"strings"
"time"
)
const (
openAITokenRefreshSkew = 3 * time.Minute
openAITokenCacheSkew = 5 * time.Minute
openAILockWaitTime = 200 * time.Millisecond
)
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type OpenAITokenCache = GeminiTokenCache
// OpenAITokenProvider 管理 OpenAI OAuth 账户的 access_token
type OpenAITokenProvider struct {
accountRepo AccountRepository
tokenCache OpenAITokenCache
openAIOAuthService *OpenAIOAuthService
}
func NewOpenAITokenProvider(
accountRepo AccountRepository,
tokenCache OpenAITokenCache,
openAIOAuthService *OpenAIOAuthService,
) *OpenAITokenProvider {
return &OpenAITokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
openAIOAuthService: openAIOAuthService,
}
}
// GetAccessToken 获取有效的 access_token
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
return "", errors.New("not an openai oauth account")
}
cacheKey := OpenAITokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("openai_token_cache_hit", "account_id", account.ID)
return token, nil
} else if err != nil {
slog.Warn("openai_token_cache_get_failed", "account_id", account.ID, "error", err)
}
}
slog.Debug("openai_token_cache_miss", "account_id", account.ID)
// 2. 如果即将过期则刷新
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
// 从数据库获取最新账户信息
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true // 无法刷新,标记失败
} else {
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
refreshFailed = true // 刷新失败,标记以使用短 TTL
} else {
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else if lockErr != nil {
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
// 检查 ctx 是否已取消
if ctx.Err() != nil {
return "", ctx.Err()
}
// 从数据库获取最新账户信息
if p.accountRepo != nil {
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
}
expiresAt = account.GetCredentialAsTime("expires_at")
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true
} else {
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
refreshFailed = true
} else {
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else {
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
time.Sleep(openAILockWaitTime)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
}
}
accessToken := account.GetOpenAIAccessToken()
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > openAITokenCacheSkew:
ttl = until - openAITokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}
return accessToken, nil
}
This diff is collapsed.
...@@ -85,13 +85,24 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -85,13 +85,24 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
switch statusCode { switch statusCode {
case 401: case 401:
if account.Type == AccountTypeOAuth && // 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
(account.Platform == PlatformAntigravity || account.Platform == PlatformGemini) { if account.Type == AccountTypeOAuth {
// 1. 失效缓存
if s.tokenCacheInvalidator != nil { if s.tokenCacheInvalidator != nil {
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil { if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err) slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
} }
} }
// 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
if err := s.accountRepo.Update(ctx, account); err != nil {
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
} else {
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
}
} }
msg := "Authentication failed (401): invalid or expired credentials" msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" { if upstreamMsg != "" {
......
...@@ -7,29 +7,35 @@ type TokenCacheInvalidator interface { ...@@ -7,29 +7,35 @@ type TokenCacheInvalidator interface {
} }
type CompositeTokenCacheInvalidator struct { type CompositeTokenCacheInvalidator struct {
geminiCache GeminiTokenCache cache GeminiTokenCache // 统一使用一个缓存接口,通过缓存键前缀区分平台
} }
func NewCompositeTokenCacheInvalidator(geminiCache GeminiTokenCache) *CompositeTokenCacheInvalidator { func NewCompositeTokenCacheInvalidator(cache GeminiTokenCache) *CompositeTokenCacheInvalidator {
return &CompositeTokenCacheInvalidator{ return &CompositeTokenCacheInvalidator{
geminiCache: geminiCache, cache: cache,
} }
} }
func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, account *Account) error { func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, account *Account) error {
if c == nil || c.geminiCache == nil || account == nil { if c == nil || c.cache == nil || account == nil {
return nil return nil
} }
if account.Type != AccountTypeOAuth { if account.Type != AccountTypeOAuth {
return nil return nil
} }
var cacheKey string
switch account.Platform { switch account.Platform {
case PlatformGemini: case PlatformGemini:
return c.geminiCache.DeleteAccessToken(ctx, GeminiTokenCacheKey(account)) cacheKey = GeminiTokenCacheKey(account)
case PlatformAntigravity: case PlatformAntigravity:
return c.geminiCache.DeleteAccessToken(ctx, AntigravityTokenCacheKey(account)) cacheKey = AntigravityTokenCacheKey(account)
case PlatformOpenAI:
cacheKey = OpenAITokenCacheKey(account)
case PlatformAnthropic:
cacheKey = ClaudeTokenCacheKey(account)
default: default:
return nil return nil
} }
return c.cache.DeleteAccessToken(ctx, cacheKey)
} }
...@@ -4,6 +4,7 @@ package service ...@@ -4,6 +4,7 @@ package service
import ( import (
"context" "context"
"errors"
"testing" "testing"
"time" "time"
...@@ -50,7 +51,7 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) { ...@@ -50,7 +51,7 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
err := invalidator.InvalidateToken(context.Background(), account) err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []string{"project-x"}, cache.deletedKeys) require.Equal(t, []string{"gemini:project-x"}, cache.deletedKeys)
} }
func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) { func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
...@@ -70,13 +71,99 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) { ...@@ -70,13 +71,99 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys) require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys)
} }
func TestCompositeTokenCacheInvalidator_OpenAI(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 500,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "openai-token",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"openai:account:500"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_Claude(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 600,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "claude-token",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"claude:account:600"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) { func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) {
cache := &geminiTokenCacheStub{} cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache) invalidator := NewCompositeTokenCacheInvalidator(cache)
tests := []struct {
name string
account *Account
}{
{
name: "gemini_api_key",
account: &Account{
ID: 1,
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
},
},
{
name: "openai_api_key",
account: &Account{
ID: 2,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
},
},
{
name: "claude_api_key",
account: &Account{
ID: 3,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
},
},
{
name: "claude_setup_token",
account: &Account{
ID: 4,
Platform: PlatformAnthropic,
Type: AccountTypeSetupToken,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache.deletedKeys = nil
err := invalidator.InvalidateToken(context.Background(), tt.account)
require.NoError(t, err)
require.Empty(t, cache.deletedKeys)
})
}
}
func TestCompositeTokenCacheInvalidator_SkipUnsupportedPlatform(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{ account := &Account{
ID: 1, ID: 100,
Platform: PlatformGemini, Platform: "unknown-platform",
Type: AccountTypeAPIKey, Type: AccountTypeOAuth,
} }
err := invalidator.InvalidateToken(context.Background(), account) err := invalidator.InvalidateToken(context.Background(), account)
...@@ -95,3 +182,87 @@ func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) { ...@@ -95,3 +182,87 @@ func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) {
err := invalidator.InvalidateToken(context.Background(), account) err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err) require.NoError(t, err)
} }
func TestCompositeTokenCacheInvalidator_NilAccount(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
err := invalidator.InvalidateToken(context.Background(), nil)
require.NoError(t, err)
require.Empty(t, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_NilInvalidator(t *testing.T) {
var invalidator *CompositeTokenCacheInvalidator
account := &Account{
ID: 5,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
}
func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) {
expectedErr := errors.New("redis connection failed")
cache := &geminiTokenCacheStub{deleteErr: expectedErr}
invalidator := NewCompositeTokenCacheInvalidator(cache)
tests := []struct {
name string
account *Account
}{
{
name: "openai_delete_error",
account: &Account{
ID: 700,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
},
},
{
name: "claude_delete_error",
account: &Account{
ID: 800,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := invalidator.InvalidateToken(context.Background(), tt.account)
require.Error(t, err)
require.Equal(t, expectedErr, err)
})
}
}
func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
// 测试所有平台的缓存键生成和删除
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
accounts := []*Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "gemini-proj"}},
{ID: 2, Platform: PlatformAntigravity, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "ag-proj"}},
{ID: 3, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
{ID: 4, Platform: PlatformAnthropic, Type: AccountTypeOAuth},
}
expectedKeys := []string{
"gemini:gemini-proj",
"ag:ag-proj",
"openai:account:3",
"claude:account:4",
}
for _, acc := range accounts {
err := invalidator.InvalidateToken(context.Background(), acc)
require.NoError(t, err)
}
require.Equal(t, expectedKeys, cache.deletedKeys)
}
package service
import "strconv"
// OpenAITokenCacheKey 生成 OpenAI OAuth 账号的缓存键
// 格式: "openai:account:{account_id}"
func OpenAITokenCacheKey(account *Account) string {
return "openai:account:" + strconv.FormatInt(account.ID, 10)
}
// ClaudeTokenCacheKey 生成 Claude (Anthropic) OAuth 账号的缓存键
// 格式: "claude:account:{account_id}"
func ClaudeTokenCacheKey(account *Account) string {
return "claude:account:" + strconv.FormatInt(account.ID, 10)
}
...@@ -22,7 +22,7 @@ func TestGeminiTokenCacheKey(t *testing.T) { ...@@ -22,7 +22,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
"project_id": "my-project-123", "project_id": "my-project-123",
}, },
}, },
expected: "my-project-123", expected: "gemini:my-project-123",
}, },
{ {
name: "project_id_with_whitespace", name: "project_id_with_whitespace",
...@@ -32,7 +32,7 @@ func TestGeminiTokenCacheKey(t *testing.T) { ...@@ -32,7 +32,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
"project_id": " project-with-spaces ", "project_id": " project-with-spaces ",
}, },
}, },
expected: "project-with-spaces", expected: "gemini:project-with-spaces",
}, },
{ {
name: "empty_project_id_fallback_to_account_id", name: "empty_project_id_fallback_to_account_id",
...@@ -42,7 +42,7 @@ func TestGeminiTokenCacheKey(t *testing.T) { ...@@ -42,7 +42,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
"project_id": "", "project_id": "",
}, },
}, },
expected: "account:102", expected: "gemini:account:102",
}, },
{ {
name: "whitespace_only_project_id_fallback_to_account_id", name: "whitespace_only_project_id_fallback_to_account_id",
...@@ -52,7 +52,7 @@ func TestGeminiTokenCacheKey(t *testing.T) { ...@@ -52,7 +52,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
"project_id": " ", "project_id": " ",
}, },
}, },
expected: "account:103", expected: "gemini:account:103",
}, },
{ {
name: "no_project_id_key_fallback_to_account_id", name: "no_project_id_key_fallback_to_account_id",
...@@ -60,7 +60,7 @@ func TestGeminiTokenCacheKey(t *testing.T) { ...@@ -60,7 +60,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
ID: 104, ID: 104,
Credentials: map[string]any{}, Credentials: map[string]any{},
}, },
expected: "account:104", expected: "gemini:account:104",
}, },
{ {
name: "nil_credentials_fallback_to_account_id", name: "nil_credentials_fallback_to_account_id",
...@@ -68,7 +68,7 @@ func TestGeminiTokenCacheKey(t *testing.T) { ...@@ -68,7 +68,7 @@ func TestGeminiTokenCacheKey(t *testing.T) {
ID: 105, ID: 105,
Credentials: nil, Credentials: nil,
}, },
expected: "account:105", expected: "gemini:account:105",
}, },
} }
...@@ -151,3 +151,109 @@ func TestAntigravityTokenCacheKey(t *testing.T) { ...@@ -151,3 +151,109 @@ func TestAntigravityTokenCacheKey(t *testing.T) {
}) })
} }
} }
func TestOpenAITokenCacheKey(t *testing.T) {
tests := []struct {
name string
account *Account
expected string
}{
{
name: "basic_account",
account: &Account{
ID: 300,
},
expected: "openai:account:300",
},
{
name: "account_with_credentials",
account: &Account{
ID: 301,
Credentials: map[string]any{
"access_token": "test-token",
},
},
expected: "openai:account:301",
},
{
name: "account_id_zero",
account: &Account{
ID: 0,
},
expected: "openai:account:0",
},
{
name: "large_account_id",
account: &Account{
ID: 9999999999,
},
expected: "openai:account:9999999999",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := OpenAITokenCacheKey(tt.account)
require.Equal(t, tt.expected, result)
})
}
}
func TestClaudeTokenCacheKey(t *testing.T) {
tests := []struct {
name string
account *Account
expected string
}{
{
name: "basic_account",
account: &Account{
ID: 400,
},
expected: "claude:account:400",
},
{
name: "account_with_credentials",
account: &Account{
ID: 401,
Credentials: map[string]any{
"access_token": "claude-token",
},
},
expected: "claude:account:401",
},
{
name: "account_id_zero",
account: &Account{
ID: 0,
},
expected: "claude:account:0",
},
{
name: "large_account_id",
account: &Account{
ID: 9999999999,
},
expected: "claude:account:9999999999",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ClaudeTokenCacheKey(tt.account)
require.Equal(t, tt.expected, result)
})
}
}
func TestCacheKeyUniqueness(t *testing.T) {
// 确保不同平台的缓存键不会冲突
account := &Account{ID: 123}
openaiKey := OpenAITokenCacheKey(account)
claudeKey := ClaudeTokenCacheKey(account)
require.NotEqual(t, openaiKey, claudeKey, "OpenAI and Claude cache keys should be different")
require.Contains(t, openaiKey, "openai:")
require.Contains(t, claudeKey, "claude:")
}
...@@ -172,8 +172,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc ...@@ -172,8 +172,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
if err := s.accountRepo.Update(ctx, account); err != nil { if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("failed to save credentials: %w", err) return fmt.Errorf("failed to save credentials: %w", err)
} }
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth && // 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
(account.Platform == PlatformGemini || account.Platform == PlatformAntigravity) { if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil { if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, err) log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, err)
} else { } else {
......
...@@ -197,7 +197,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) { ...@@ -197,7 +197,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效 require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
} }
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试其他平台的 OAuth 账号不触发缓存失效 // TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试所有 OAuth 平台都触发缓存失效
func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
repo := &tokenRefreshAccountRepo{} repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{} invalidator := &tokenCacheInvalidatorStub{}
...@@ -210,7 +210,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { ...@@ -210,7 +210,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg) service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
account := &Account{ account := &Account{
ID: 10, ID: 10,
Platform: PlatformOpenAI, // 其他平台 Platform: PlatformOpenAI, // OpenAI OAuth 账户
Type: AccountTypeOAuth, Type: AccountTypeOAuth,
} }
refresher := &tokenRefresherStub{ refresher := &tokenRefresherStub{
...@@ -222,7 +222,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { ...@@ -222,7 +222,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
err := service.refreshWithRetry(context.Background(), account, refresher) err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 0, invalidator.calls) // 其他平台不触发缓存失效 require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
} }
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况 // TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
......
...@@ -214,10 +214,13 @@ var ProviderSet = wire.NewSet( ...@@ -214,10 +214,13 @@ var ProviderSet = wire.NewSet(
NewGeminiOAuthService, NewGeminiOAuthService,
NewGeminiQuotaService, NewGeminiQuotaService,
NewCompositeTokenCacheInvalidator, NewCompositeTokenCacheInvalidator,
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
NewAntigravityOAuthService, NewAntigravityOAuthService,
NewGeminiTokenProvider, NewGeminiTokenProvider,
NewGeminiMessagesCompatService, NewGeminiMessagesCompatService,
NewAntigravityTokenProvider, NewAntigravityTokenProvider,
NewOpenAITokenProvider,
NewClaudeTokenProvider,
NewAntigravityGatewayService, NewAntigravityGatewayService,
ProvideRateLimitService, ProvideRateLimitService,
NewAccountUsageService, NewAccountUsageService,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment