Commit 3f6fa1e3 authored by weak-fox's avatar weak-fox
Browse files

fix: avoid temp unsched when refresh token is missing

parent fdd8499f
......@@ -502,6 +502,25 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" {
accessToken := account.GetCredential("access_token")
if accessToken != "" {
tokenInfo := &OpenAITokenInfo{
AccessToken: accessToken,
RefreshToken: "",
IDToken: account.GetCredential("id_token"),
ClientID: account.GetCredential("client_id"),
Email: account.GetCredential("email"),
ChatGPTAccountID: account.GetCredential("chatgpt_account_id"),
ChatGPTUserID: account.GetCredential("chatgpt_user_id"),
OrganizationID: account.GetCredential("organization_id"),
PlanType: account.GetCredential("plan_type"),
}
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
tokenInfo.ExpiresAt = expiresAt.Unix()
tokenInfo.ExpiresIn = int64(time.Until(*expiresAt).Seconds())
}
return tokenInfo, nil
}
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
}
......
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
type openaiOAuthClientRefreshStub struct {
refreshCalls int32
}
func (s *openaiOAuthClientRefreshStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func (s *openaiOAuthClientRefreshStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
atomic.AddInt32(&s.refreshCalls, 1)
return nil, errors.New("not implemented")
}
func (s *openaiOAuthClientRefreshStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
atomic.AddInt32(&s.refreshCalls, 1)
return nil, errors.New("not implemented")
}
func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccessToken(t *testing.T) {
client := &openaiOAuthClientRefreshStub{}
svc := NewOpenAIOAuthService(nil, client)
expiresAt := time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339)
account := &Account{
ID: 77,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "existing-access-token",
"expires_at": expiresAt,
"client_id": "client-id-1",
},
}
info, err := svc.RefreshAccountToken(context.Background(), account)
require.NoError(t, err)
require.NotNil(t, info)
require.Equal(t, "existing-access-token", info.AccessToken)
require.Equal(t, "client-id-1", info.ClientID)
require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without calling refresh")
}
......@@ -430,6 +430,7 @@ func isNonRetryableRefreshError(err error) bool {
"unauthorized_client", // 客户端未授权
"access_denied", // 访问被拒绝
"missing_project_id", // 缺少 project_id
"no refresh token available",
}
for _, needle := range nonRetryable {
if strings.Contains(msg, needle) {
......
......@@ -19,6 +19,7 @@ type tokenRefreshAccountRepo struct {
updateCredentialsCalls int
setErrorCalls int
clearTempCalls int
setTempUnschedCalls int
lastAccount *Account
updateErr error
}
......@@ -58,6 +59,11 @@ func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id
return nil
}
func (r *tokenRefreshAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.setTempUnschedCalls++
return nil
}
type tokenCacheInvalidatorStub struct {
calls int
err error
......@@ -490,6 +496,31 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
}
}
func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedule(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 2,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
account := &Account{
ID: 18,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("no refresh token available"),
}
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, repo.setTempUnschedCalls, "missing refresh token should not mark the account temp unschedulable")
require.Equal(t, 1, repo.setErrorCalls, "missing refresh token should be treated as a non-retryable credential state")
}
// TestIsNonRetryableRefreshError 测试不可重试错误判断
func TestIsNonRetryableRefreshError(t *testing.T) {
tests := []struct {
......@@ -503,6 +534,7 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
{name: "invalid_client", err: errors.New("invalid_client"), expected: true},
{name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
{name: "access_denied", err: errors.New("access_denied"), expected: true},
{name: "no_refresh_token", err: errors.New("no refresh token available"), expected: true},
{name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true},
{name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true},
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment