Unverified Commit 27ffc7f3 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1828 from wx-11/main

使用codex的生图接口代替web2api（Use Codex's image-generation endpoint instead of web2api）
parents 0b85a8da 9e5a6351
...@@ -124,9 +124,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -124,9 +124,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient) tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
openAI403CounterCache := repository.NewOpenAI403CounterCache(redisClient)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient) geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache) compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator) rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator)
httpUpstream := repository.NewHTTPUpstream(configConfig) httpUpstream := repository.NewHTTPUpstream(configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream) claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
......
package repository
import (
"context"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// openAI403CounterPrefix namespaces per-account 403 counters in Redis.
const openAI403CounterPrefix = "openai_403_count:account:"

// openAI403CounterIncrScript atomically increments the counter and sets the
// key's TTL only when the counter is created (count == 1), so the window is
// anchored to the first failure and not extended by subsequent increments.
var openAI403CounterIncrScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local count = redis.call('INCR', key)
if count == 1 then
redis.call('EXPIRE', key, ttl)
end
return count
`)
// openAI403CounterCache is the Redis-backed implementation of
// service.OpenAI403CounterCache.
type openAI403CounterCache struct {
	rdb *redis.Client
}
// NewOpenAIL403CounterCache is intentionally NOT the name; keep the exported
// constructor below. NewOpenAI403CounterCache builds a Redis-backed
// implementation of service.OpenAI403CounterCache.
func NewOpenAI403CounterCache(rdb *redis.Client) service.OpenAI403CounterCache {
	cache := &openAI403CounterCache{rdb: rdb}
	return cache
}
// IncrementOpenAI403Count atomically bumps the account's 403 counter via a Lua
// script and returns the new count. The counter key expires after
// windowMinutes (clamped to at least one minute), measured from the first
// failure in the window.
func (c *openAI403CounterCache) IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
	const minTTLSeconds = 60 // never let the window collapse below one minute
	ttlSeconds := windowMinutes * 60
	if ttlSeconds < minTTLSeconds {
		ttlSeconds = minTTLSeconds
	}
	key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)
	count, err := openAI403CounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64()
	if err != nil {
		return 0, fmt.Errorf("increment openai 403 count: %w", err)
	}
	return count, nil
}
// ResetOpenAI403Count clears the account's 403 counter by deleting its Redis
// key; deleting a missing key is a no-op, so reset is safe to call eagerly.
func (c *openAI403CounterCache) ResetOpenAI403Count(ctx context.Context, accountID int64) error {
	return c.rdb.Del(ctx, fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)).Err()
}
...@@ -96,6 +96,7 @@ var ProviderSet = wire.NewSet( ...@@ -96,6 +96,7 @@ var ProviderSet = wire.NewSet(
NewAPIKeyCache, NewAPIKeyCache,
NewTempUnschedCache, NewTempUnschedCache,
NewTimeoutCounterCache, NewTimeoutCounterCache,
NewOpenAI403CounterCache,
NewInternal500CounterCache, NewInternal500CounterCache,
ProvideConcurrencyCache, ProvideConcurrencyCache,
ProvideSessionLimitCache, ProvideSessionLimitCache,
......
...@@ -930,10 +930,8 @@ func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapabilit ...@@ -930,10 +930,8 @@ func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapabilit
return false return false
} }
switch capability { switch capability {
case OpenAIImagesCapabilityBasic: case OpenAIImagesCapabilityBasic, OpenAIImagesCapabilityNative:
return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey
case OpenAIImagesCapabilityNative:
return a.Type == AccountTypeAPIKey
default: default:
return true return true
} }
......
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
...@@ -1138,7 +1137,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C ...@@ -1138,7 +1137,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C
return nil return nil
} }
// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via ChatGPT backend API. // testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via Codex /responses API.
func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error { func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error {
authToken := account.GetOpenAIAccessToken() authToken := account.GetOpenAIAccessToken()
if authToken == "" { if authToken == "" {
...@@ -1153,69 +1152,46 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co ...@@ -1153,69 +1152,46 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
c.Writer.Flush() c.Writer.Flush()
s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID}) s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID})
s.sendEvent(c, TestEvent{Type: "content", Text: "Initializing ChatGPT backend...\n"}) s.sendEvent(c, TestEvent{Type: "content", Text: "Calling Codex /responses image tool...\n"})
// Build headers (replicating buildOpenAIBackendAPIHeaders logic) parsed := &OpenAIImagesRequest{
headers := buildOpenAIBackendAPIHeadersForTest(ctx, account, authToken, s.accountRepo) Endpoint: openAIImagesGenerationsEndpoint,
Model: strings.TrimSpace(modelID),
proxyURL := "" Prompt: prompt,
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
} }
applyOpenAIImagesDefaults(parsed)
client, err := newOpenAIBackendAPIClient(proxyURL) responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, parsed.Model)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create client: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build image request: %s", err.Error()))
} }
// Bootstrap req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatgptCodexAPIURL, bytes.NewReader(responsesBody))
if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil {
log.Printf("OpenAI image test bootstrap warning: %v", bootstrapErr)
}
// Fetch chat requirements
s.sendEvent(c, TestEvent{Type: "content", Text: "Fetching chat requirements...\n"})
chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Chat requirements failed: %s", err.Error())) return s.sendErrorAndEnd(c, "Failed to create request")
} }
if chatReqs.Arkose.Required { req.Host = "chatgpt.com"
return s.sendErrorAndEnd(c, "Unsupported challenge: arkose required") req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("OpenAI-Beta", "responses=experimental")
req.Header.Set("originator", "opencode")
if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" {
req.Header.Set("User-Agent", customUA)
} else {
req.Header.Set("User-Agent", codexCLIUserAgent)
} }
if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" {
// Initialize and prepare conversation req.Header.Set("chatgpt-account-id", chatgptAccountID)
s.sendEvent(c, TestEvent{Type: "content", Text: "Preparing image conversation...\n"})
parentMessageID := uuid.NewString()
proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent"))
_ = initializeOpenAIImageConversation(ctx, client, headers)
conduitToken, err := prepareOpenAIImageConversation(ctx, client, headers, prompt, parentMessageID, chatReqs.Token, proofToken)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation prepare failed: %s", err.Error()))
} }
// Build simplified conversation request (no file uploads) proxyURL := ""
convReq := buildOpenAIImageTestConversationRequest(prompt, parentMessageID) if account.ProxyID != nil && account.Proxy != nil {
convHeaders := cloneHTTPHeader(headers) proxyURL = account.Proxy.URL()
convHeaders.Set("Accept", "text/event-stream")
convHeaders.Set("Content-Type", "application/json")
convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token)
if conduitToken != "" {
convHeaders.Set("x-conduit-token", conduitToken)
}
if proofToken != "" {
convHeaders.Set("openai-sentinel-proof-token", proofToken)
} }
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
s.sendEvent(c, TestEvent{Type: "content", Text: "Generating image...\n"})
resp, err := client.R().
SetContext(ctx).
DisableAutoReadResponse().
SetHeaders(headerToMap(convHeaders)).
SetBodyJsonMarshal(convReq).
Post(openAIChatGPTConversationURL)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Responses API request failed: %s", err.Error()))
} }
defer func() { defer func() {
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
...@@ -1223,49 +1199,35 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co ...@@ -1223,49 +1199,35 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
} }
}() }()
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation API returned %d", resp.StatusCode)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
message := strings.TrimSpace(extractUpstreamErrorMessage(body))
if message == "" {
message = fmt.Sprintf("Responses API returned %d", resp.StatusCode)
}
return s.sendErrorAndEnd(c, message)
} }
startTime := time.Now() body, err := io.ReadAll(resp.Body)
conversationID, pointerInfos, _, _, err := readOpenAIImageConversationStream(resp, startTime)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read image response: %s", err.Error()))
} }
pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil) results, _, _, _, _, err := collectOpenAIImagesFromResponsesBody(body)
if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) { if err != nil {
s.sendEvent(c, TestEvent{Type: "content", Text: "Waiting for image generation to complete...\n"}) return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse image response: %s", err.Error()))
polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID)
if pollErr != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Poll failed: %s", pollErr.Error()))
}
pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers)
} }
pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos) if len(results) == 0 {
if len(pointerInfos) == 0 { return s.sendErrorAndEnd(c, "No images returned from responses API")
return s.sendErrorAndEnd(c, "No images returned from conversation")
} }
s.sendEvent(c, TestEvent{Type: "content", Text: "Downloading generated image...\n"}) for _, item := range results {
if item.RevisedPrompt != "" {
// Download and encode each image s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt})
for _, pointer := range pointerInfos {
downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Download URL fetch failed: %s", err.Error()))
}
data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Image download failed: %s", err.Error()))
}
b64 := base64.StdEncoding.EncodeToString(data)
mimeType := http.DetectContentType(data)
if pointer.Prompt != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: pointer.Prompt})
} }
mimeType := openAIImageOutputMIMEType(item.OutputFormat)
s.sendEvent(c, TestEvent{ s.sendEvent(c, TestEvent{
Type: "image", Type: "image",
ImageURL: "data:" + mimeType + ";base64," + b64, ImageURL: "data:" + mimeType + ";base64," + item.Result,
MimeType: mimeType, MimeType: mimeType,
}) })
} }
...@@ -1274,107 +1236,6 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co ...@@ -1274,107 +1236,6 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co
return nil return nil
} }
// buildOpenAIBackendAPIHeadersForTest builds ChatGPT backend API headers for test purposes.
// Replicates the logic from OpenAIGatewayService.buildOpenAIBackendAPIHeaders without
// requiring the full gateway service dependency.
func buildOpenAIBackendAPIHeadersForTest(ctx context.Context, account *Account, token string, repo AccountRepository) http.Header {
// Ensure device and session IDs exist
deviceID := account.GetOpenAIDeviceID()
sessionID := account.GetOpenAISessionID()
if deviceID == "" || sessionID == "" {
updates := map[string]any{}
if deviceID == "" {
deviceID = uuid.NewString()
updates["openai_device_id"] = deviceID
}
if sessionID == "" {
sessionID = uuid.NewString()
updates["openai_session_id"] = sessionID
}
if account.Extra == nil {
account.Extra = map[string]any{}
}
for key, value := range updates {
account.Extra[key] = value
}
if repo != nil {
updateCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
_ = repo.UpdateExtra(updateCtx, account.ID, updates)
}
}
headers := make(http.Header)
headers.Set("Authorization", "Bearer "+token)
headers.Set("Accept", "application/json")
headers.Set("Origin", "https://chatgpt.com")
headers.Set("Referer", "https://chatgpt.com/")
headers.Set("Sec-Fetch-Dest", "empty")
headers.Set("Sec-Fetch-Mode", "cors")
headers.Set("Sec-Fetch-Site", "same-origin")
headers.Set("User-Agent", openAIImageBackendUserAgent)
if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" {
headers.Set("User-Agent", customUA)
}
if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" {
headers.Set("chatgpt-account-id", chatgptAccountID)
}
if deviceID != "" {
headers.Set("oai-device-id", deviceID)
headers.Set("Cookie", "oai-did="+deviceID)
}
if sessionID != "" {
headers.Set("oai-session-id", sessionID)
}
return headers
}
// buildOpenAIImageTestConversationRequest creates a simplified image generation conversation request.
func buildOpenAIImageTestConversationRequest(prompt, parentMessageID string) map[string]any {
promptText := strings.TrimSpace(prompt)
if promptText == "" {
promptText = "Generate an image."
}
metadata := map[string]any{
"developer_mode_connector_ids": []any{},
"selected_github_repos": []any{},
"selected_all_github_repos": false,
"system_hints": []string{"picture_v2"},
"serialization_metadata": map[string]any{
"custom_symbol_offsets": []any{},
},
}
message := map[string]any{
"id": uuid.NewString(),
"author": map[string]any{"role": "user"},
"content": map[string]any{
"content_type": "text",
"parts": []any{promptText},
},
"metadata": metadata,
"create_time": float64(time.Now().UnixMilli()) / 1000,
}
return map[string]any{
"action": "next",
"client_prepare_state": "sent",
"parent_message_id": parentMessageID,
"messages": []any{message},
"model": "auto",
"timezone_offset_min": openAITimezoneOffsetMinutes(),
"timezone": openAITimezoneName(),
"conversation_mode": map[string]any{"kind": "primary_assistant"},
"system_hints": []string{"picture_v2"},
"supports_buffering": true,
"supported_encodings": []string{"v1"},
"client_contextual_info": map[string]any{"app_name": "chatgpt.com"},
"force_nulligen": false,
"force_paragen": false,
"force_paragen_model_slug": "",
"force_rate_limit": false,
"websocket_request_id": uuid.NewString(),
}
}
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) { func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
eventJSON, _ := json.Marshal(event) eventJSON, _ := json.Marshal(event)
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil { if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
......
package service
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback verifies
// that testOpenAIImageOAuth extracts an image delivered via a
// "response.output_item.done" SSE event (instead of the final response output
// array) and streams it to the client as a base64 data URL, ending with a
// success event.
func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
	// Canned upstream SSE stream: an image_generation_call item completes with a
	// base64 result ("aGVsbG8=" == "hello"), then the response completes with an
	// EMPTY output array — forcing the output_item.done fallback path.
	upstream := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header: http.Header{
				"Content-Type": []string{"text/event-stream"},
			},
			Body: io.NopCloser(strings.NewReader(
				"data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\"}}\n\n" +
					"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000006,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" +
					"data: [DONE]\n\n",
			)),
		},
	}
	svc := &AccountTestService{httpUpstream: upstream}
	// OAuth account with only the access token credential set; no proxy.
	account := &Account{
		ID:       53,
		Name:     "openai-oauth",
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "token-123",
		},
	}
	err := svc.testOpenAIImageOAuth(c, context.Background(), account, "gpt-image-2", "draw a cat")
	require.NoError(t, err)
	// The SSE body written to the recorder must contain the progress message,
	// the decoded image as a png data URL, and the final success marker.
	require.Contains(t, rec.Body.String(), "Calling Codex /responses image tool")
	require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
	require.Contains(t, rec.Body.String(), "\"success\":true")
}
package service
import "context"
// OpenAI403CounterCache tracks consecutive 403 failures for OpenAI accounts.
type OpenAI403CounterCache interface {
	// IncrementOpenAI403Count atomically increments the 403 counter within the
	// given window (in minutes) and returns the current count.
	IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error)
	// ResetOpenAI403Count clears the counter after a successful request.
	ResetOpenAI403Count(ctx context.Context, accountID int64) error
}
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// openAI403CounterResetStub implements OpenAI403CounterCache for tests,
// recording the account IDs whose 403 counter was reset.
type openAI403CounterResetStub struct {
	resetCalls []int64 // account IDs passed to ResetOpenAI403Count, in call order
}
// IncrementOpenAI403Count is a no-op for this stub; it always reports a zero
// count and no error.
func (s *openAI403CounterResetStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) {
	return 0, nil
}
// ResetOpenAI403Count records the account ID so the test can assert which
// accounts were reset; it never fails.
func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
	s.resetCalls = append(s.resetCalls, accountID)
	return nil
}
// TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn
// verifies that RecordUsage resets an OpenAI account's 403 counter even when the
// forward result carries no usage tokens — i.e. the reset must happen before the
// early zero-usage return, not after it.
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) {
	counter := &openAI403CounterResetStub{}
	rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil)
	rateLimitSvc.SetOpenAI403CounterCache(counter)
	svc := &OpenAIGatewayService{
		rateLimitService: rateLimitSvc,
	}
	// Empty OpenAIForwardResult: all usage tokens are zero, so RecordUsage takes
	// its early return — the counter reset must still fire for account 777.
	err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
		Result:  &OpenAIForwardResult{},
		Account: &Account{ID: 777, Platform: PlatformOpenAI},
	})
	require.NoError(t, err)
	require.Equal(t, []int64{777}, counter.resetCalls)
}
...@@ -1098,3 +1098,50 @@ func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing. ...@@ -1098,3 +1098,50 @@ func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing.
require.NotNil(t, usageRepo.lastLog.BillingMode) require.NotNil(t, usageRepo.lastLog.BillingMode)
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
} }
// TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTokens
// verifies that when a forward result reports both image counts AND usage
// tokens, the cost is billed per image (group ImagePrice1K * ImageCount) and
// all token-based cost fields remain zero.
func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTokens(t *testing.T) {
	imagePrice := 0.02
	groupID := int64(12)
	usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	userRepo := &openAIRecordUsageUserRepoStub{}
	subRepo := &openAIRecordUsageSubRepoStub{}
	svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
	// Result carries non-zero token usage alongside ImageCount=2 — the tokens
	// must NOT influence the billed cost.
	err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID: "resp_image_per_request",
			Model:     "gpt-image-2",
			Usage: OpenAIUsage{
				InputTokens:       1110,
				OutputTokens:      1756,
				ImageOutputTokens: 1756,
			},
			ImageCount: 2,
			ImageSize:  "1K",
			Duration:   time.Second,
		},
		APIKey: &APIKey{
			ID:      1008,
			GroupID: i64p(groupID),
			Group: &Group{
				ID:             groupID,
				RateMultiplier: 1.0,
				ImagePrice1K:   &imagePrice,
			},
		},
		User:    &User{ID: 2008},
		Account: &Account{ID: 3008},
	})
	require.NoError(t, err)
	require.NotNil(t, usageRepo.lastLog)
	require.NotNil(t, usageRepo.lastLog.BillingMode)
	require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
	require.Equal(t, 2, usageRepo.lastLog.ImageCount)
	// 2 images * 0.02 per image = 0.04; every token-cost component stays zero.
	require.InDelta(t, 0.04, usageRepo.lastLog.TotalCost, 1e-12)
	require.InDelta(t, 0.04, usageRepo.lastLog.ActualCost, 1e-12)
	require.InDelta(t, 0.0, usageRepo.lastLog.InputCost, 1e-12)
	require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12)
	require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12)
}
...@@ -4425,6 +4425,9 @@ type OpenAIRecordUsageInput struct { ...@@ -4425,6 +4425,9 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance // RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result result := input.Result
if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
}
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库 // 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 && if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
...@@ -4622,12 +4625,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost( ...@@ -4622,12 +4625,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
serviceTier string, serviceTier string,
) (*CostBreakdown, error) { ) (*CostBreakdown, error) {
if result != nil && result.ImageCount > 0 { if result != nil && result.ImageCount > 0 {
if hasOpenAIImageUsageTokens(result) {
cost, err := s.calculateOpenAIImageTokenCost(ctx, apiKey, billingModel, multiplier, tokens, serviceTier, result.ImageSize)
if err == nil {
return cost, nil
}
}
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
} }
if s.resolver != nil && apiKey.Group != nil { if s.resolver != nil && apiKey.Group != nil {
...@@ -4679,7 +4676,8 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost( ...@@ -4679,7 +4676,8 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost(
result *OpenAIForwardResult, result *OpenAIForwardResult,
multiplier float64, multiplier float64,
) *CostBreakdown { ) *CostBreakdown {
if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil { if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil &&
(resolved.Mode == BillingModePerRequest || resolved.Mode == BillingModeImage) {
gid := apiKey.Group.ID gid := apiKey.Group.ID
cost, err := s.billingService.CalculateCostUnified(CostInput{ cost, err := s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx, Ctx: ctx,
......
...@@ -50,6 +50,7 @@ const ( ...@@ -50,6 +50,7 @@ const (
openAIImageLifecycleTimeout = 2 * time.Minute openAIImageLifecycleTimeout = 2 * time.Minute
openAIImageMaxDownloadBytes = 20 << 20 // 20MB per image download openAIImageMaxDownloadBytes = 20 << 20 // 20MB per image download
openAIImageMaxUploadPartSize = 20 << 20 // 20MB per multipart upload part openAIImageMaxUploadPartSize = 20 << 20 // 20MB per multipart upload part
openAIImagesResponsesMainModel = "gpt-5.4-mini"
) )
type OpenAIImagesCapability string type OpenAIImagesCapability string
...@@ -81,10 +82,21 @@ type OpenAIImagesRequest struct { ...@@ -81,10 +82,21 @@ type OpenAIImagesRequest struct {
ExplicitSize bool ExplicitSize bool
SizeTier string SizeTier string
ResponseFormat string ResponseFormat string
Quality string
Background string
OutputFormat string
Moderation string
InputFidelity string
Style string
OutputCompression *int
PartialImages *int
HasMask bool HasMask bool
HasNativeOptions bool HasNativeOptions bool
RequiredCapability OpenAIImagesCapability RequiredCapability OpenAIImagesCapability
InputImageURLs []string
MaskImageURL string
Uploads []OpenAIImagesUpload Uploads []OpenAIImagesUpload
MaskUpload *OpenAIImagesUpload
Body []byte Body []byte
bodyHash string bodyHash string
} }
...@@ -188,7 +200,54 @@ func parseOpenAIImagesJSONRequest(body []byte, req *OpenAIImagesRequest) error { ...@@ -188,7 +200,54 @@ func parseOpenAIImagesJSONRequest(body []byte, req *OpenAIImagesRequest) error {
req.ExplicitSize = req.Size != "" req.ExplicitSize = req.Size != ""
} }
req.ResponseFormat = strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "response_format").String())) req.ResponseFormat = strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "response_format").String()))
req.Quality = strings.TrimSpace(gjson.GetBytes(body, "quality").String())
req.Background = strings.TrimSpace(gjson.GetBytes(body, "background").String())
req.OutputFormat = strings.TrimSpace(gjson.GetBytes(body, "output_format").String())
req.Moderation = strings.TrimSpace(gjson.GetBytes(body, "moderation").String())
req.InputFidelity = strings.TrimSpace(gjson.GetBytes(body, "input_fidelity").String())
req.Style = strings.TrimSpace(gjson.GetBytes(body, "style").String())
req.HasMask = gjson.GetBytes(body, "mask").Exists() req.HasMask = gjson.GetBytes(body, "mask").Exists()
if outputCompression := gjson.GetBytes(body, "output_compression"); outputCompression.Exists() {
if outputCompression.Type != gjson.Number {
return fmt.Errorf("invalid output_compression field type")
}
v := int(outputCompression.Int())
req.OutputCompression = &v
}
if partialImages := gjson.GetBytes(body, "partial_images"); partialImages.Exists() {
if partialImages.Type != gjson.Number {
return fmt.Errorf("invalid partial_images field type")
}
v := int(partialImages.Int())
req.PartialImages = &v
}
if req.IsEdits() {
images := gjson.GetBytes(body, "images")
if images.Exists() {
if !images.IsArray() {
return fmt.Errorf("invalid images field type")
}
for _, item := range images.Array() {
if imageURL := strings.TrimSpace(item.Get("image_url").String()); imageURL != "" {
req.InputImageURLs = append(req.InputImageURLs, imageURL)
continue
}
if item.Get("file_id").Exists() {
return fmt.Errorf("images[].file_id is not supported (use images[].image_url instead)")
}
}
}
if maskImageURL := strings.TrimSpace(gjson.GetBytes(body, "mask.image_url").String()); maskImageURL != "" {
req.MaskImageURL = maskImageURL
req.HasMask = true
}
if gjson.GetBytes(body, "mask.file_id").Exists() {
return fmt.Errorf("mask.file_id is not supported (use mask.image_url instead)")
}
if len(req.InputImageURLs) == 0 {
return fmt.Errorf("images[].image_url is required")
}
}
req.HasNativeOptions = hasOpenAINativeImageOptions(func(path string) bool { req.HasNativeOptions = hasOpenAINativeImageOptions(func(path string) bool {
return gjson.GetBytes(body, path).Exists() return gjson.GetBytes(body, path).Exists()
}) })
...@@ -231,6 +290,16 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope ...@@ -231,6 +290,16 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope
partContentType := strings.TrimSpace(part.Header.Get("Content-Type")) partContentType := strings.TrimSpace(part.Header.Get("Content-Type"))
if name == "mask" && len(data) > 0 { if name == "mask" && len(data) > 0 {
req.HasMask = true req.HasMask = true
width, height := parseOpenAIImageDimensions(part.Header)
maskUpload := OpenAIImagesUpload{
FieldName: name,
FileName: fileName,
ContentType: partContentType,
Data: data,
Width: width,
Height: height,
}
req.MaskUpload = &maskUpload
} }
if name == "image" || strings.HasPrefix(name, "image[") { if name == "image" || strings.HasPrefix(name, "image[") {
width, height := parseOpenAIImageDimensions(part.Header) width, height := parseOpenAIImageDimensions(part.Header)
...@@ -270,6 +339,38 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope ...@@ -270,6 +339,38 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope
return fmt.Errorf("n must be a positive integer") return fmt.Errorf("n must be a positive integer")
} }
req.N = n req.N = n
case "quality":
req.Quality = value
req.HasNativeOptions = true
case "background":
req.Background = value
req.HasNativeOptions = true
case "output_format":
req.OutputFormat = value
req.HasNativeOptions = true
case "moderation":
req.Moderation = value
req.HasNativeOptions = true
case "input_fidelity":
req.InputFidelity = value
req.HasNativeOptions = true
case "style":
req.Style = value
req.HasNativeOptions = true
case "output_compression":
n, err := strconv.Atoi(value)
if err != nil {
return fmt.Errorf("invalid output_compression field value")
}
req.OutputCompression = &n
req.HasNativeOptions = true
case "partial_images":
n, err := strconv.Atoi(value)
if err != nil {
return fmt.Errorf("invalid partial_images field value")
}
req.PartialImages = &n
req.HasNativeOptions = true
default: default:
if isOpenAINativeImageOption(name) && value != "" { if isOpenAINativeImageOption(name) && value != "" {
req.HasNativeOptions = true req.HasNativeOptions = true
...@@ -359,6 +460,8 @@ func hasOpenAINativeImageOptions(exists func(path string) bool) bool { ...@@ -359,6 +460,8 @@ func hasOpenAINativeImageOptions(exists func(path string) bool) bool {
"output_format", "output_format",
"output_compression", "output_compression",
"moderation", "moderation",
"input_fidelity",
"partial_images",
} { } {
if exists(path) { if exists(path) {
return true return true
...@@ -369,7 +472,7 @@ func hasOpenAINativeImageOptions(exists func(path string) bool) bool { ...@@ -369,7 +472,7 @@ func hasOpenAINativeImageOptions(exists func(path string) bool) bool {
func isOpenAINativeImageOption(name string) bool { func isOpenAINativeImageOption(name string) bool {
switch strings.TrimSpace(strings.ToLower(name)) { switch strings.TrimSpace(strings.ToLower(name)) {
case "background", "quality", "style", "output_format", "output_compression", "moderation": case "background", "quality", "style", "output_format", "output_compression", "moderation", "input_fidelity", "partial_images":
return true return true
default: default:
return false return false
...@@ -782,156 +885,6 @@ func extractOpenAIImageCountFromJSONBytes(body []byte) int { ...@@ -782,156 +885,6 @@ func extractOpenAIImageCountFromJSONBytes(body []byte) int {
return 0 return 0
} }
// forwardOpenAIImagesOAuth fulfills an OpenAI images request through the
// ChatGPT web backend using the account's OAuth access token. High-level flow:
// resolve the effective model, bootstrap the backend session, negotiate
// chat-requirements (sentinel token plus optional proof-of-work), open a
// conversation carrying the prompt and any uploads, stream the conversation
// response, resolve the generated image assets (polling when the stream did
// not surface file-service assets), and write the final JSON payload straight
// to the client via c.Data.
//
// channelMappedModel, when non-empty, overrides parsed.Model. On success the
// HTTP response has already been written; the returned OpenAIForwardResult
// carries accounting/metrics data only.
func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	parsed *OpenAIImagesRequest,
	channelMappedModel string,
) (*OpenAIForwardResult, error) {
	startTime := time.Now()
	// Channel-level model mapping, when present, wins over the request model.
	requestModel := strings.TrimSpace(parsed.Model)
	if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
		requestModel = mapped
	}
	if err := validateOpenAIImagesModel(requestModel); err != nil {
		return nil, err
	}
	logger.LegacyPrintf(
		"service.openai_gateway",
		"[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d",
		requestModel,
		parsed.Endpoint,
		account.Type,
		len(parsed.Uploads),
	)
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, err
	}
	client, err := newOpenAIBackendAPIClient(resolveOpenAIProxyURL(account))
	if err != nil {
		return nil, err
	}
	headers, err := s.buildOpenAIBackendAPIHeaders(account, token)
	if err != nil {
		return nil, err
	}
	// Bootstrap is best-effort: a failure is logged but does not abort the request.
	if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil {
		logger.LegacyPrintf("service.openai_gateway", "OpenAI image bootstrap failed: %v", bootstrapErr)
	}
	chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers)
	if err != nil {
		return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
	}
	// An Arkose challenge cannot be solved here; surface a synthetic 403 instead.
	if chatReqs.Arkose.Required {
		return nil, s.wrapOpenAIImageBackendError(
			ctx,
			c,
			account,
			newOpenAIImageSyntheticStatusError(
				http.StatusForbidden,
				"chat-requirements requires unsupported challenge (arkose)",
				openAIChatGPTChatRequirementsURL,
			),
		)
	}
	parentMessageID := uuid.NewString()
	proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent"))
	// Conversation init is best-effort (error deliberately ignored); prepare must
	// succeed because it yields the conduit token used on the conversation call.
	_ = initializeOpenAIImageConversation(ctx, client, headers)
	conduitToken, err := prepareOpenAIImageConversation(ctx, client, headers, parsed.Prompt, parentMessageID, chatReqs.Token, proofToken)
	if err != nil {
		return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
	}
	uploads, err := uploadOpenAIImageFiles(ctx, client, headers, parsed.Uploads)
	if err != nil {
		return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
	}
	convReq := buildOpenAIImageConversationRequest(parsed, parentMessageID, uploads)
	// Record the upstream request body for ops visibility; marshal failure is ignored.
	if parsedContent, err := json.Marshal(convReq); err == nil {
		setOpsUpstreamRequestBody(c, parsedContent)
	}
	convHeaders := cloneHTTPHeader(headers)
	convHeaders.Set("Accept", "text/event-stream")
	convHeaders.Set("Content-Type", "application/json")
	convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token)
	if conduitToken != "" {
		convHeaders.Set("x-conduit-token", conduitToken)
	}
	if proofToken != "" {
		convHeaders.Set("openai-sentinel-proof-token", proofToken)
	}
	resp, err := client.R().
		SetContext(ctx).
		DisableAutoReadResponse().
		SetHeaders(headerToMap(convHeaders)).
		SetBodyJsonMarshal(convReq).
		Post(openAIChatGPTConversationURL)
	if err != nil {
		return nil, fmt.Errorf("openai image conversation request failed: %w", err)
	}
	// Body is read manually (auto-read disabled), so it must be closed here.
	defer func() {
		if resp != nil && resp.Body != nil {
			_ = resp.Body.Close()
		}
	}()
	if resp.StatusCode >= 400 {
		return nil, s.wrapOpenAIImageBackendError(ctx, c, account, handleOpenAIImageBackendError(resp))
	}
	conversationID, pointerInfos, usage, firstTokenMs, err := readOpenAIImageConversationStream(resp, startTime)
	if err != nil {
		return nil, err
	}
	pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil)
	logger.LegacyPrintf(
		"service.openai_gateway",
		"[OpenAI] Image extraction stream conversation_id=%s total_assets=%d file_service_assets=%d direct_assets=%d",
		conversationID,
		len(pointerInfos),
		countOpenAIFileServicePointerInfos(pointerInfos),
		countOpenAIDirectImageAssets(pointerInfos),
	)
	// Asset polling/download may outlive the inbound request context, so run it
	// on a detached context bounded by the image lifecycle timeout.
	lifecycleCtx, releaseLifecycleCtx := detachOpenAIImageLifecycleContext(ctx, openAIImageLifecycleTimeout)
	defer releaseLifecycleCtx()
	// If the stream produced no file-service assets, poll the conversation for them.
	if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) {
		polledPointers, pollErr := pollOpenAIImageConversation(lifecycleCtx, client, headers, conversationID)
		if pollErr != nil {
			return nil, s.wrapOpenAIImageBackendError(ctx, c, account, pollErr)
		}
		pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers)
	}
	pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos)
	if len(pointerInfos) == 0 {
		logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Image extraction yielded no assets conversation_id=%s", conversationID)
		return nil, fmt.Errorf("openai image conversation returned no downloadable images")
	}
	responseBody, imageCount, err := buildOpenAIImageResponse(lifecycleCtx, client, headers, conversationID, pointerInfos)
	if err != nil {
		return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
	}
	c.Data(http.StatusOK, "application/json; charset=utf-8", responseBody)
	return &OpenAIForwardResult{
		RequestID: resp.Header.Get("x-request-id"),
		Usage: usage,
		Model: requestModel,
		UpstreamModel: requestModel,
		Stream: false,
		Duration: time.Since(startTime),
		FirstTokenMs: firstTokenMs,
		ImageCount: imageCount,
		ImageSize: parsed.SizeTier,
	}, nil
}
func resolveOpenAIProxyURL(account *Account) string { func resolveOpenAIProxyURL(account *Account) string {
if account != nil && account.ProxyID != nil && account.Proxy != nil { if account != nil && account.ProxyID != nil && account.Proxy != nil {
return account.Proxy.URL() return account.Proxy.URL()
......
This diff is collapsed.
package service package service
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"log/slog" "log/slog"
"net/http" "net/http"
"strconv" "strconv"
...@@ -23,6 +25,7 @@ type RateLimitService struct { ...@@ -23,6 +25,7 @@ type RateLimitService struct {
geminiQuotaService *GeminiQuotaService geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache tempUnschedCache TempUnschedCache
timeoutCounterCache TimeoutCounterCache timeoutCounterCache TimeoutCounterCache
openAI403CounterCache OpenAI403CounterCache
settingService *SettingService settingService *SettingService
tokenCacheInvalidator TokenCacheInvalidator tokenCacheInvalidator TokenCacheInvalidator
usageCacheMu sync.RWMutex usageCacheMu sync.RWMutex
...@@ -52,6 +55,12 @@ type geminiUsageTotalsBatchProvider interface { ...@@ -52,6 +55,12 @@ type geminiUsageTotalsBatchProvider interface {
const geminiPrecheckCacheTTL = time.Minute const geminiPrecheckCacheTTL = time.Minute
const (
	// openAI403CooldownMinutesDefault is the temporary unschedulable window,
	// in minutes, applied to an OpenAI account after a 403 that is still
	// below the disable threshold.
	openAI403CooldownMinutesDefault = 10
	// openAI403DisableThreshold is the consecutive-403 count at which the
	// account is permanently disabled instead of temporarily cooled down.
	openAI403DisableThreshold = 3
	// openAI403CounterWindowMinutes is the sliding-window lifetime, in
	// minutes, of the consecutive-403 counter.
	openAI403CounterWindowMinutes = 180
)
// NewRateLimitService 创建RateLimitService实例 // NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService { func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService {
return &RateLimitService{ return &RateLimitService{
...@@ -69,6 +78,11 @@ func (s *RateLimitService) SetTimeoutCounterCache(cache TimeoutCounterCache) { ...@@ -69,6 +78,11 @@ func (s *RateLimitService) SetTimeoutCounterCache(cache TimeoutCounterCache) {
s.timeoutCounterCache = cache s.timeoutCounterCache = cache
} }
// SetOpenAI403CounterCache sets the consecutive OpenAI 403 failure counter
// cache (optional dependency). When left unset, a 403 on an OpenAI account
// falls straight through to the generic auth-error handling.
func (s *RateLimitService) SetOpenAI403CounterCache(cache OpenAI403CounterCache) {
	s.openAI403CounterCache = cache
}
// SetSettingService 设置系统设置服务(可选依赖) // SetSettingService 设置系统设置服务(可选依赖)
func (s *RateLimitService) SetSettingService(settingService *SettingService) { func (s *RateLimitService) SetSettingService(settingService *SettingService) {
s.settingService = settingService s.settingService = settingService
...@@ -655,6 +669,30 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account ...@@ -655,6 +669,30 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg) slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
} }
// buildForbiddenErrorMessage assembles a human-readable 403 error message.
// Precedence: the trimmed upstream message; otherwise the raw response body
// (compacted when it is valid JSON, truncated to 512 bytes either way);
// otherwise the supplied fallback. The prefix is normalized to end with a
// single space when non-empty.
func buildForbiddenErrorMessage(prefix string, upstreamMsg string, responseBody []byte, fallback string) string {
	head := strings.TrimSpace(prefix)
	if head != "" && !strings.HasSuffix(head, " ") {
		head = head + " "
	}
	if trimmed := strings.TrimSpace(upstreamMsg); trimmed != "" {
		return head + trimmed
	}
	body := bytes.TrimSpace(responseBody)
	if len(body) == 0 {
		return head + fallback
	}
	if json.Valid(body) {
		var compacted bytes.Buffer
		if json.Compact(&compacted, body) == nil {
			return head + truncateForLog(compacted.Bytes(), 512)
		}
	}
	return head + truncateForLog(body, 512)
}
// handle403 处理 403 Forbidden 错误 // handle403 处理 403 Forbidden 错误
// Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用; // Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用;
// 其他平台保持原有 SetError 行为。 // 其他平台保持原有 SetError 行为。
...@@ -662,15 +700,64 @@ func (s *RateLimitService) handle403(ctx context.Context, account *Account, upst ...@@ -662,15 +700,64 @@ func (s *RateLimitService) handle403(ctx context.Context, account *Account, upst
if account.Platform == PlatformAntigravity { if account.Platform == PlatformAntigravity {
return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody) return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody)
} }
// 非 Antigravity 平台:保持原有行为 if account.Platform == PlatformOpenAI {
msg := "Access forbidden (403): account may be suspended or lack permissions" return s.handleOpenAI403(ctx, account, upstreamMsg, responseBody)
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
} }
// 非 Antigravity 平台:保持原有行为
msg := buildForbiddenErrorMessage(
"Access forbidden (403):",
upstreamMsg,
responseBody,
"account may be suspended or lack permissions",
)
s.handleAuthError(ctx, account, msg) s.handleAuthError(ctx, account, msg)
return true return true
} }
// handleOpenAI403 applies the OpenAI-specific 403 policy. Each 403 bumps a
// sliding-window counter; below openAI403DisableThreshold the account is only
// parked temporarily (SetTempUnschedulable), and once the threshold is reached
// it is disabled through the standard auth-error path. If the counter cache is
// missing or fails, behavior degrades to immediate disable. The return value
// is always true: the account leaves scheduling either way.
func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
	baseMsg := buildForbiddenErrorMessage(
		"Access forbidden (403):",
		upstreamMsg,
		responseBody,
		"account may be suspended or lack permissions",
	)
	// Without a counter we cannot tell transient from persistent 403s.
	if s.openAI403CounterCache == nil {
		s.handleAuthError(ctx, account, baseMsg)
		return true
	}
	count, err := s.openAI403CounterCache.IncrementOpenAI403Count(ctx, account.ID, openAI403CounterWindowMinutes)
	switch {
	case err != nil:
		slog.Warn("openai_403_increment_failed", "account_id", account.ID, "error", err)
		s.handleAuthError(ctx, account, baseMsg)
		return true
	case count >= openAI403DisableThreshold:
		s.handleAuthError(ctx, account, fmt.Sprintf("%s | consecutive_403=%d/%d", baseMsg, count, openAI403DisableThreshold))
		return true
	}
	cooldownUntil := time.Now().Add(openAI403CooldownMinutesDefault * time.Minute)
	reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, baseMsg)
	if setErr := s.accountRepo.SetTempUnschedulable(ctx, account.ID, cooldownUntil, reason); setErr != nil {
		// Cooldown could not be persisted; fall back to disabling the account.
		slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", setErr)
		s.handleAuthError(ctx, account, baseMsg)
		return true
	}
	slog.Warn(
		"openai_403_temp_unschedulable",
		"account_id", account.ID,
		"until", cooldownUntil,
		"count", count,
		"threshold", openAI403DisableThreshold,
	)
	return true
}
// handleAntigravity403 处理 Antigravity 平台的 403 错误 // handleAntigravity403 处理 Antigravity 平台的 403 错误
// validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复) // validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复)
// violation(违规封号)→ 永久 SetError(需人工处理) // violation(违规封号)→ 永久 SetError(需人工处理)
...@@ -681,10 +768,12 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac ...@@ -681,10 +768,12 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
switch fbType { switch fbType {
case forbiddenTypeValidation: case forbiddenTypeValidation:
// VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复 // VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复
msg := "Validation required (403): account needs Google verification" msg := buildForbiddenErrorMessage(
if upstreamMsg != "" { "Validation required (403):",
msg = "Validation required (403): " + upstreamMsg upstreamMsg,
} responseBody,
"account needs Google verification",
)
if validationURL := extractValidationURL(string(responseBody)); validationURL != "" { if validationURL := extractValidationURL(string(responseBody)); validationURL != "" {
msg += " | validation_url: " + validationURL msg += " | validation_url: " + validationURL
} }
...@@ -693,19 +782,23 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac ...@@ -693,19 +782,23 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
case forbiddenTypeViolation: case forbiddenTypeViolation:
// 违规封号: 永久禁用,需人工处理 // 违规封号: 永久禁用,需人工处理
msg := "Account violation (403): terms of service violation" msg := buildForbiddenErrorMessage(
if upstreamMsg != "" { "Account violation (403):",
msg = "Account violation (403): " + upstreamMsg upstreamMsg,
} responseBody,
"terms of service violation",
)
s.handleAuthError(ctx, account, msg) s.handleAuthError(ctx, account, msg)
return true return true
default: default:
// 通用 403: 保持原有行为 // 通用 403: 保持原有行为
msg := "Access forbidden (403): account may be suspended or lack permissions" msg := buildForbiddenErrorMessage(
if upstreamMsg != "" { "Access forbidden (403):",
msg = "Access forbidden (403): " + upstreamMsg upstreamMsg,
} responseBody,
"account may be suspended or lack permissions",
)
s.handleAuthError(ctx, account, msg) s.handleAuthError(ctx, account, msg)
return true return true
} }
...@@ -1221,9 +1314,19 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) ...@@ -1221,9 +1314,19 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err) slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
} }
} }
s.ResetOpenAI403Counter(ctx, accountID)
return nil return nil
} }
// ResetOpenAI403Counter clears the consecutive OpenAI 403 counter for the
// given account. It is a no-op on a nil receiver, when no counter cache is
// configured, or for non-positive account IDs; cache failures are logged and
// swallowed (best-effort cleanup).
func (s *RateLimitService) ResetOpenAI403Counter(ctx context.Context, accountID int64) {
	if s == nil {
		return
	}
	cache := s.openAI403CounterCache
	if cache == nil || accountID <= 0 {
		return
	}
	if err := cache.ResetOpenAI403Count(ctx, accountID); err != nil {
		slog.Warn("openai_403_reset_failed", "account_id", accountID, "error", err)
	}
}
// RecoverAccountState 按需恢复账号的可恢复运行时状态。 // RecoverAccountState 按需恢复账号的可恢复运行时状态。
func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) { func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
...@@ -1250,6 +1353,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in ...@@ -1250,6 +1353,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in
} }
result.ClearedRateLimit = true result.ClearedRateLimit = true
} }
if result.ClearedError || result.ClearedRateLimit {
s.ResetOpenAI403Counter(ctx, accountID)
}
return result, nil return result, nil
} }
......
...@@ -20,6 +20,7 @@ type rateLimitAccountRepoStub struct { ...@@ -20,6 +20,7 @@ type rateLimitAccountRepoStub struct {
updateCredentialsCalls int updateCredentialsCalls int
lastCredentials map[string]any lastCredentials map[string]any
lastErrorMsg string lastErrorMsg string
lastTempReason string
} }
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
...@@ -30,6 +31,7 @@ func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, error ...@@ -30,6 +31,7 @@ func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, error
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++ r.tempCalls++
r.lastTempReason = reason
return nil return nil
} }
...@@ -44,6 +46,29 @@ type tokenCacheInvalidatorRecorder struct { ...@@ -44,6 +46,29 @@ type tokenCacheInvalidatorRecorder struct {
err error err error
} }
type openAI403CounterCacheStub struct {
counts []int64
resetCalls []int64
err error
}
func (s *openAI403CounterCacheStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) {
if s.err != nil {
return 0, s.err
}
if len(s.counts) == 0 {
return 1, nil
}
count := s.counts[0]
s.counts = s.counts[1:]
return count, nil
}
func (s *openAI403CounterCacheStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
s.resetCalls = append(s.resetCalls, accountID)
return nil
}
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error { func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
r.accounts = append(r.accounts, account) r.accounts = append(r.accounts, account)
return r.err return r.err
......
//go:build unit
package service
import (
"context"
"net/http"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// Verifies that the first OpenAI 403 inside the counting window only parks
// the account temporarily (no SetError), and that the cooldown reason carries
// both the upstream message and the counter progress.
func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) {
	repo := &rateLimitAccountRepoStub{}
	counter := &openAI403CounterCacheStub{counts: []int64{1}}
	svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
	svc.SetOpenAI403CounterCache(counter)

	acct := &Account{
		ID:       301,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
	}

	disabled := svc.HandleUpstreamError(
		context.Background(),
		acct,
		http.StatusForbidden,
		http.Header{},
		[]byte(`{"error":{"message":"temporary edge rejection"}}`),
	)

	require.True(t, disabled)
	require.Zero(t, repo.setErrorCalls)
	require.Equal(t, 1, repo.tempCalls)
	require.Contains(t, repo.lastTempReason, "temporary edge rejection")
	require.Contains(t, repo.lastTempReason, "(1/3)")
}
// Verifies that reaching the consecutive-403 threshold disables the account
// permanently (SetError) instead of applying another cooldown, and that the
// stored error keeps the upstream message plus the counter tag.
func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) {
	repo := &rateLimitAccountRepoStub{}
	counter := &openAI403CounterCacheStub{counts: []int64{3}}
	svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
	svc.SetOpenAI403CounterCache(counter)

	acct := &Account{
		ID:       302,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
	}

	disabled := svc.HandleUpstreamError(
		context.Background(),
		acct,
		http.StatusForbidden,
		http.Header{},
		[]byte(`{"error":{"message":"workspace forbidden by policy"}}`),
	)

	require.True(t, disabled)
	require.Equal(t, 1, repo.setErrorCalls)
	require.Zero(t, repo.tempCalls)
	require.Contains(t, repo.lastErrorMsg, "workspace forbidden by policy")
	require.Contains(t, repo.lastErrorMsg, "consecutive_403=3/3")
}
...@@ -7,6 +7,9 @@ import ( ...@@ -7,6 +7,9 @@ import (
"net/http" "net/http"
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
) )
func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) { func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) {
...@@ -259,6 +262,53 @@ func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) { ...@@ -259,6 +262,53 @@ func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) {
} }
} }
// Verifies that a 403 carrying a structured upstream message stores that
// message verbatim instead of the generic fallback text.
func TestRateLimitService_HandleUpstreamError_403PreservesOriginalUpstreamMessage(t *testing.T) {
	repo := &rateLimitAccountRepoStub{}
	svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)

	acct := &Account{
		ID:       201,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
	}

	disabled := svc.HandleUpstreamError(
		context.Background(),
		acct,
		http.StatusForbidden,
		http.Header{},
		[]byte(`{"error":{"message":"workspace forbidden by policy","type":"invalid_request_error"}}`),
	)

	require.True(t, disabled)
	require.Equal(t, 1, repo.setErrorCalls)
	require.Contains(t, repo.lastErrorMsg, "workspace forbidden by policy")
	require.NotContains(t, repo.lastErrorMsg, "account may be suspended or lack permissions")
}
// Verifies that when the upstream payload has no message field, the compacted
// JSON body is embedded in the stored error instead of the generic fallback.
func TestRateLimitService_HandleUpstreamError_403FallsBackToRawBody(t *testing.T) {
	repo := &rateLimitAccountRepoStub{}
	svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)

	acct := &Account{
		ID:       202,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
	}

	disabled := svc.HandleUpstreamError(
		context.Background(),
		acct,
		http.StatusForbidden,
		http.Header{},
		[]byte(`{"error":{"type":"access_denied","details":{"reason":"ip_blocked"}}}`),
	)

	require.True(t, disabled)
	require.Equal(t, 1, repo.setErrorCalls)
	require.Contains(t, repo.lastErrorMsg, `"access_denied"`)
	require.Contains(t, repo.lastErrorMsg, `"ip_blocked"`)
	require.NotContains(t, repo.lastErrorMsg, "account may be suspended or lack permissions")
}
func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) { func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) {
// Test when only secondary has data, no window_minutes // Test when only secondary has data, no window_minutes
sUsed := 60.0 sUsed := 60.0
......
...@@ -210,11 +210,13 @@ func ProvideRateLimitService( ...@@ -210,11 +210,13 @@ func ProvideRateLimitService(
geminiQuotaService *GeminiQuotaService, geminiQuotaService *GeminiQuotaService,
tempUnschedCache TempUnschedCache, tempUnschedCache TempUnschedCache,
timeoutCounterCache TimeoutCounterCache, timeoutCounterCache TimeoutCounterCache,
openAI403CounterCache OpenAI403CounterCache,
settingService *SettingService, settingService *SettingService,
tokenCacheInvalidator TokenCacheInvalidator, tokenCacheInvalidator TokenCacheInvalidator,
) *RateLimitService { ) *RateLimitService {
svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache) svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache)
svc.SetTimeoutCounterCache(timeoutCounterCache) svc.SetTimeoutCounterCache(timeoutCounterCache)
svc.SetOpenAI403CounterCache(openAI403CounterCache)
svc.SetSettingService(settingService) svc.SetSettingService(settingService)
svc.SetTokenCacheInvalidator(tokenCacheInvalidator) svc.SetTokenCacheInvalidator(tokenCacheInvalidator)
return svc return svc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment