Unverified Commit 6bccb8a8 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge branch 'main' into feature/antigravity-user-agent-configurable

parents 1fc6ef3d 3de1e0e4
......@@ -696,6 +696,51 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
return false
}
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
//
// 新字段:accounts.extra.openai_passthrough。
// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func (a *Account) IsOpenAIPassthroughEnabled() bool {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
return false
}
if enabled, ok := a.Extra["openai_passthrough"].(bool); ok {
return enabled
}
if enabled, ok := a.Extra["openai_oauth_passthrough"].(bool); ok {
return enabled
}
return false
}
// IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。
func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
}
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)”。
// 字段:accounts.extra.anthropic_passthrough。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil {
return false
}
enabled, ok := a.Extra["anthropic_passthrough"].(bool)
return ok && enabled
}
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。
// 字段:accounts.extra.codex_cli_only。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func (a *Account) IsCodexCLIOnlyEnabled() bool {
if a == nil || !a.IsOpenAIOAuth() || a.Extra == nil {
return false
}
enabled, ok := a.Extra["codex_cli_only"].(bool)
return ok && enabled
}
// WindowCostSchedulability 窗口费用调度状态
type WindowCostSchedulability int
......
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAccount_IsAnthropicAPIKeyPassthroughEnabled(t *testing.T) {
t.Run("Anthropic API Key 开启", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"anthropic_passthrough": true,
},
}
require.True(t, account.IsAnthropicAPIKeyPassthroughEnabled())
})
t.Run("Anthropic API Key 关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"anthropic_passthrough": false,
},
}
require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled())
})
t.Run("字段类型非法默认关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"anthropic_passthrough": "true",
},
}
require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled())
})
t.Run("非 Anthropic API Key 账号始终关闭", func(t *testing.T) {
oauth := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"anthropic_passthrough": true,
},
}
require.False(t, oauth.IsAnthropicAPIKeyPassthroughEnabled())
openai := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"anthropic_passthrough": true,
},
}
require.False(t, openai.IsAnthropicAPIKeyPassthroughEnabled())
})
}
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAccount_IsOpenAIPassthroughEnabled(t *testing.T) {
t.Run("新字段开启", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"openai_passthrough": true,
},
}
require.True(t, account.IsOpenAIPassthroughEnabled())
})
t.Run("兼容旧字段", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_passthrough": true,
},
}
require.True(t, account.IsOpenAIPassthroughEnabled())
})
t.Run("非OpenAI账号始终关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_passthrough": true,
},
}
require.False(t, account.IsOpenAIPassthroughEnabled())
})
t.Run("空额外配置默认关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
}
require.False(t, account.IsOpenAIPassthroughEnabled())
})
}
func TestAccount_IsOpenAIOAuthPassthroughEnabled(t *testing.T) {
t.Run("仅OAuth类型允许返回开启", func(t *testing.T) {
oauthAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_passthrough": true,
},
}
require.True(t, oauthAccount.IsOpenAIOAuthPassthroughEnabled())
apiKeyAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"openai_passthrough": true,
},
}
require.False(t, apiKeyAccount.IsOpenAIOAuthPassthroughEnabled())
})
}
func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) {
t.Run("OpenAI OAuth 开启", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"codex_cli_only": true,
},
}
require.True(t, account.IsCodexCLIOnlyEnabled())
})
t.Run("OpenAI OAuth 关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"codex_cli_only": false,
},
}
require.False(t, account.IsCodexCLIOnlyEnabled())
})
t.Run("字段缺失默认关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{},
}
require.False(t, account.IsCodexCLIOnlyEnabled())
})
t.Run("类型非法默认关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"codex_cli_only": "true",
},
}
require.False(t, account.IsCodexCLIOnlyEnabled())
})
t.Run("非 OAuth 账号始终关闭", func(t *testing.T) {
apiKeyAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"codex_cli_only": true,
},
}
require.False(t, apiKeyAccount.IsCodexCLIOnlyEnabled())
otherPlatform := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"codex_cli_only": true,
},
}
require.False(t, otherPlatform.IsCodexCLIOnlyEnabled())
})
}
......@@ -25,6 +25,9 @@ type AccountRepository interface {
// GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora')
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
// for all accounts that have been synced from CRS.
ListCRSAccountIDs(ctx context.Context) (map[string]int64, error)
......
......@@ -54,6 +54,10 @@ func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID st
panic("unexpected GetByCRSAccountID call")
}
func (s *accountRepoStub) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) {
panic("unexpected FindByExtraField call")
}
func (s *accountRepoStub) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
panic("unexpected ListCRSAccountIDs call")
}
......
......@@ -12,13 +12,17 @@ import (
"io"
"log"
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
......@@ -31,6 +35,11 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine"
soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap"
soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check"
)
// TestEvent represents a SSE event for account testing
......@@ -38,6 +47,9 @@ type TestEvent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Status string `json:"status,omitempty"`
Code string `json:"code,omitempty"`
Data any `json:"data,omitempty"`
Success bool `json:"success,omitempty"`
Error string `json:"error,omitempty"`
}
......@@ -49,8 +61,13 @@ type AccountTestService struct {
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
cfg *config.Config
soraTestGuardMu sync.Mutex
soraTestLastRun map[int64]time.Time
soraTestCooldown time.Duration
}
const defaultSoraTestCooldown = 10 * time.Second
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(
accountRepo AccountRepository,
......@@ -65,6 +82,8 @@ func NewAccountTestService(
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
cfg: cfg,
soraTestLastRun: make(map[int64]time.Time),
soraTestCooldown: defaultSoraTestCooldown,
}
}
......@@ -163,6 +182,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
return s.testAntigravityAccountConnection(c, account, modelID)
}
if account.Platform == PlatformSora {
return s.testSoraAccountConnection(c, account)
}
return s.testClaudeAccountConnection(c, account, modelID)
}
......@@ -462,6 +485,604 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
return s.processGeminiStream(c, resp.Body)
}
type soraProbeStep struct {
Name string `json:"name"`
Status string `json:"status"`
HTTPStatus int `json:"http_status,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
Message string `json:"message,omitempty"`
}
type soraProbeSummary struct {
Status string `json:"status"`
Steps []soraProbeStep `json:"steps"`
}
type soraProbeRecorder struct {
steps []soraProbeStep
}
func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
r.steps = append(r.steps, soraProbeStep{
Name: name,
Status: status,
HTTPStatus: httpStatus,
ErrorCode: strings.TrimSpace(errorCode),
Message: strings.TrimSpace(message),
})
}
func (r *soraProbeRecorder) finalize() soraProbeSummary {
meSuccess := false
partial := false
for _, step := range r.steps {
if step.Name == "me" {
meSuccess = strings.EqualFold(step.Status, "success")
continue
}
if strings.EqualFold(step.Status, "failed") {
partial = true
}
}
status := "success"
if !meSuccess {
status = "failed"
} else if partial {
status = "partial_success"
}
return soraProbeSummary{
Status: status,
Steps: append([]soraProbeStep(nil), r.steps...),
}
}
func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
if rec == nil {
return
}
summary := rec.finalize()
code := ""
for _, step := range summary.Steps {
if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
code = step.ErrorCode
break
}
}
s.sendEvent(c, TestEvent{
Type: "sora_test_result",
Status: summary.Status,
Code: code,
Data: summary,
})
}
func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
if accountID <= 0 {
return 0, true
}
s.soraTestGuardMu.Lock()
defer s.soraTestGuardMu.Unlock()
if s.soraTestLastRun == nil {
s.soraTestLastRun = make(map[int64]time.Time)
}
cooldown := s.soraTestCooldown
if cooldown <= 0 {
cooldown = defaultSoraTestCooldown
}
now := time.Now()
if lastRun, ok := s.soraTestLastRun[accountID]; ok {
elapsed := now.Sub(lastRun)
if elapsed < cooldown {
return cooldown - elapsed, false
}
}
s.soraTestLastRun[accountID] = now
return 0, true
}
func ceilSeconds(d time.Duration) int {
if d <= 0 {
return 1
}
sec := int(d / time.Second)
if d%time.Second != 0 {
sec++
}
if sec < 1 {
sec = 1
}
return sec
}
// testSoraAccountConnection 测试 Sora 账号的连接
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token)
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
ctx := c.Request.Context()
recorder := &soraProbeRecorder{}
authToken := account.GetCredential("access_token")
if authToken == "" {
recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "No access token available")
}
// Set SSE headers
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, msg)
}
// Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
if err != nil {
recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Failed to create request")
}
// 使用 Sora 客户端标准请求头
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
// Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint()
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
if err != nil {
recorder.addStep("me", "failed", 0, "network_error", err.Error())
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
s.emitSoraProbeSummary(c, recorder)
s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
switch {
case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
case strings.EqualFold(upstreamCode, "unsupported_country_code"):
recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
case strings.TrimSpace(upstreamMessage) != "":
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
default:
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
}
}
recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
// 解析 /me 响应,提取用户信息
var meResp map[string]any
if err := json.Unmarshal(body, &meResp); err != nil {
// 能收到 200 就说明 token 有效
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora connection OK (token valid)"})
} else {
// 尝试提取用户名或邮箱信息
info := "Sora connection OK"
if name, ok := meResp["name"].(string); ok && name != "" {
info = fmt.Sprintf("Sora connection OK - User: %s", name)
} else if email, ok := meResp["email"].(string); ok && email != "" {
info = fmt.Sprintf("Sora connection OK - Email: %s", email)
}
s.sendEvent(c, TestEvent{Type: "content", Text: info})
}
// 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
if err == nil {
subReq.Header.Set("Authorization", "Bearer "+authToken)
subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
subReq.Header.Set("Accept", "application/json")
subReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
subReq.Header.Set("Origin", "https://sora.chatgpt.com")
subReq.Header.Set("Referer", "https://sora.chatgpt.com/")
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
if subErr != nil {
recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
} else {
subBody, _ := io.ReadAll(subResp.Body)
_ = subResp.Body.Close()
if subResp.StatusCode == http.StatusOK {
recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
}
} else {
if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
} else {
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
}
}
}
}
// 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint, recorder)
s.emitSoraProbeSummary(c, recorder)
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
func (s *AccountTestService) testSora2Capabilities(
c *gin.Context,
ctx context.Context,
account *Account,
authToken string,
proxyURL string,
enableTLSFingerprint bool,
recorder *soraProbeRecorder,
) {
inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraInviteMineURL,
proxyURL,
enableTLSFingerprint,
)
if err != nil {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
return
}
if inviteStatus == http.StatusUnauthorized {
bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraBootstrapURL,
proxyURL,
enableTLSFingerprint,
)
if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
if recorder != nil {
recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
}
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraInviteMineURL,
proxyURL,
enableTLSFingerprint,
)
if err != nil {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
return
}
} else if recorder != nil {
code := ""
msg := ""
if bootstrapErr != nil {
code = "network_error"
msg = bootstrapErr.Error()
}
recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
}
}
if inviteStatus != http.StatusOK {
if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
}
s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
return
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
if recorder != nil {
recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
return
}
if recorder != nil {
recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
}
if summary := parseSoraInviteSummary(inviteBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"})
}
remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraRemainingURL,
proxyURL,
enableTLSFingerprint,
)
if remainingErr != nil {
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
return
}
if remainingStatus != http.StatusOK {
if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
}
s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
return
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
return
}
if recorder != nil {
recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
}
if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"})
}
}
func (s *AccountTestService) fetchSoraTestEndpoint(
ctx context.Context,
account *Account,
authToken string,
url string,
proxyURL string,
enableTLSFingerprint bool,
) (int, http.Header, []byte, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return 0, nil, nil, err
}
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableTLSFingerprint)
if err != nil {
return 0, nil, nil, err
}
defer func() { _ = resp.Body.Close() }()
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return resp.StatusCode, resp.Header, nil, readErr
}
return resp.StatusCode, resp.Header, body, nil
}
func parseSoraSubscriptionSummary(body []byte) string {
var subResp struct {
Data []struct {
Plan struct {
ID string `json:"id"`
Title string `json:"title"`
} `json:"plan"`
EndTS string `json:"end_ts"`
} `json:"data"`
}
if err := json.Unmarshal(body, &subResp); err != nil {
return ""
}
if len(subResp.Data) == 0 {
return ""
}
first := subResp.Data[0]
parts := make([]string, 0, 3)
if first.Plan.Title != "" {
parts = append(parts, first.Plan.Title)
}
if first.Plan.ID != "" {
parts = append(parts, first.Plan.ID)
}
if first.EndTS != "" {
parts = append(parts, "end="+first.EndTS)
}
if len(parts) == 0 {
return ""
}
return "Subscription: " + strings.Join(parts, " | ")
}
func parseSoraInviteSummary(body []byte) string {
var inviteResp struct {
InviteCode string `json:"invite_code"`
RedeemedCount int64 `json:"redeemed_count"`
TotalCount int64 `json:"total_count"`
}
if err := json.Unmarshal(body, &inviteResp); err != nil {
return ""
}
parts := []string{"Sora2: supported"}
if inviteResp.InviteCode != "" {
parts = append(parts, "invite="+inviteResp.InviteCode)
}
if inviteResp.TotalCount > 0 {
parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount))
}
return strings.Join(parts, " | ")
}
func parseSoraRemainingSummary(body []byte) string {
var remainingResp struct {
RateLimitAndCreditBalance struct {
EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"`
RateLimitReached bool `json:"rate_limit_reached"`
AccessResetsInSeconds int64 `json:"access_resets_in_seconds"`
} `json:"rate_limit_and_credit_balance"`
}
if err := json.Unmarshal(body, &remainingResp); err != nil {
return ""
}
info := remainingResp.RateLimitAndCreditBalance
parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)}
if info.RateLimitReached {
parts = append(parts, "rate_limited=true")
}
if info.AccessResetsInSeconds > 0 {
parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds))
}
return strings.Join(parts, " | ")
}
func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool {
if s == nil || s.cfg == nil {
return true
}
return !s.cfg.Sora.Client.DisableTLSFingerprint
}
func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractCloudflareRayID(headers http.Header, body []byte) string {
return soraerror.ExtractCloudflareRayID(headers, body)
}
func extractSoraEgressIPHint(headers http.Header) string {
if headers == nil {
return "unknown"
}
candidates := []string{
"x-openai-public-ip",
"x-envoy-external-address",
"cf-connecting-ip",
"x-forwarded-for",
}
for _, key := range candidates {
if value := strings.TrimSpace(headers.Get(key)); value != "" {
return value
}
}
return "unknown"
}
func sanitizeProxyURLForLog(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
u, err := url.Parse(raw)
if err != nil {
return "<invalid_proxy_url>"
}
if u.User != nil {
u.User = nil
}
return u.String()
}
func endpointPathForLog(endpoint string) string {
parsed, err := url.Parse(strings.TrimSpace(endpoint))
if err != nil || parsed.Path == "" {
return endpoint
}
return parsed.Path
}
func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) {
accountID := int64(0)
platform := ""
proxyID := "none"
if account != nil {
accountID = account.ID
platform = account.Platform
if account.ProxyID != nil {
proxyID = fmt.Sprintf("%d", *account.ProxyID)
}
}
cfRay := extractCloudflareRayID(headers, body)
if cfRay == "" {
cfRay = "unknown"
}
log.Printf(
"[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s",
accountID,
platform,
endpoint,
endpointPathForLog(endpoint),
proxyID,
sanitizeProxyURLForLog(proxyURL),
cfRay,
extractSoraEgressIPHint(headers),
)
}
func truncateSoraErrorBody(body []byte, max int) string {
return soraerror.TruncateBody(body, max)
}
// testAntigravityAccountConnection tests an Antigravity account's connection
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
......
package service
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type queuedHTTPUpstream struct {
responses []*http.Response
requests []*http.Request
tlsFlags []bool
}
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return nil, fmt.Errorf("unexpected Do call")
}
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) {
u.requests = append(u.requests, req)
u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint)
if len(u.responses) == 0 {
return nil, fmt.Errorf("no mocked response")
}
resp := u.responses[0]
u.responses = u.responses[1:]
return resp, nil
}
func newJSONResponse(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
resp := newJSONResponse(status, body)
resp.Header.Set(key, value)
return resp
}
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
return c, rec
}
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`),
newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`),
},
}
svc := &AccountTestService{
httpUpstream: upstream,
cfg: &config.Config{
Gateway: config.GatewayConfig{
TLSFingerprint: config.TLSFingerprintConfig{
Enabled: true,
},
},
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
DisableTLSFingerprint: false,
},
},
},
}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
require.Len(t, upstream.requests, 4)
require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String())
require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String())
require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags)
body := rec.Body.String()
require.Contains(t, body, `"type":"test_start"`)
require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`),
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
require.Len(t, upstream.requests, 4)
body := rec.Body.String()
require.Contains(t, body, "Sora connection OK - User: demo-user")
require.Contains(t, body, "Subscription check returned 403")
require.Contains(t, body, "Sora2 invite check returned 401")
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"partial_success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponseWithHeader(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`, "cf-ray", "9cff2d62d83bb98d"),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "Cloudflare challenge")
require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
body := rec.Body.String()
require.Contains(t, body, `"type":"error"`)
require.Contains(t, body, "Cloudflare challenge")
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
}
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponseWithHeader(http.StatusTooManyRequests, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body></body></html>`, "cf-mitigated", "challenge"),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "Cloudflare challenge")
require.Contains(t, err.Error(), "HTTP 429")
body := rec.Body.String()
require.Contains(t, body, "Cloudflare challenge")
}
func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "token_invalidated")
body := rec.Body.String()
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"failed"`)
require.Contains(t, body, "token_invalidated")
require.NotContains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
},
}
svc := &AccountTestService{
httpUpstream: upstream,
soraTestCooldown: time.Hour,
}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c1, _ := newSoraTestContext()
err := svc.testSoraAccountConnection(c1, account)
require.NoError(t, err)
c2, rec2 := newSoraTestContext()
err = svc.testSoraAccountConnection(c2, account)
require.Error(t, err)
require.Contains(t, err.Error(), "测试过于频繁")
body := rec2.Body.String()
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"code":"test_rate_limited"`)
require.Contains(t, body, `"status":"failed"`)
require.NotContains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
body := rec.Body.String()
require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)")
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestSanitizeProxyURLForLog(t *testing.T) {
require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080"))
require.Equal(t, "", sanitizeProxyURLForLog(""))
require.Equal(t, "<invalid_proxy_url>", sanitizeProxyURLForLog("://invalid"))
}
func TestExtractSoraEgressIPHint(t *testing.T) {
h := make(http.Header)
h.Set("x-openai-public-ip", "203.0.113.10")
require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h))
h2 := make(http.Header)
h2.Set("x-envoy-external-address", "198.51.100.9")
require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2))
require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
}
......@@ -36,8 +36,8 @@ type UsageLogRepository interface {
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
// User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
......
......@@ -4,11 +4,15 @@ import (
"context"
"errors"
"fmt"
"log"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
)
// AdminService interface defines admin management operations
......@@ -65,6 +69,7 @@ type AdminService interface {
GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error)
// Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
......@@ -111,11 +116,16 @@ type CreateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
// Sora 按次计费配置
SoraImagePrice360 *float64
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
......@@ -140,11 +150,16 @@ type UpdateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
// Sora 按次计费配置
SoraImagePrice360 *float64
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
......@@ -278,6 +293,32 @@ type ProxyTestResult struct {
CountryCode string `json:"country_code,omitempty"`
}
type ProxyQualityCheckResult struct {
ProxyID int64 `json:"proxy_id"`
Score int `json:"score"`
Grade string `json:"grade"`
Summary string `json:"summary"`
ExitIP string `json:"exit_ip,omitempty"`
Country string `json:"country,omitempty"`
CountryCode string `json:"country_code,omitempty"`
BaseLatencyMs int64 `json:"base_latency_ms,omitempty"`
PassedCount int `json:"passed_count"`
WarnCount int `json:"warn_count"`
FailedCount int `json:"failed_count"`
ChallengeCount int `json:"challenge_count"`
CheckedAt int64 `json:"checked_at"`
Items []ProxyQualityCheckItem `json:"items"`
}
type ProxyQualityCheckItem struct {
Target string `json:"target"`
Status string `json:"status"` // pass/warn/fail/challenge
HTTPStatus int `json:"http_status,omitempty"`
LatencyMs int64 `json:"latency_ms,omitempty"`
Message string `json:"message,omitempty"`
CFRay string `json:"cf_ray,omitempty"`
}
// ProxyExitInfo represents proxy exit information from ip-api.com
type ProxyExitInfo struct {
IP string
......@@ -292,11 +333,64 @@ type ProxyExitInfoProber interface {
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
}
type proxyQualityTarget struct {
Target string
URL string
Method string
AllowedStatuses map[int]struct{}
}
var proxyQualityTargets = []proxyQualityTarget{
{
Target: "openai",
URL: "https://api.openai.com/v1/models",
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
http.StatusUnauthorized: {},
},
},
{
Target: "anthropic",
URL: "https://api.anthropic.com/v1/messages",
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
http.StatusUnauthorized: {},
http.StatusMethodNotAllowed: {},
http.StatusNotFound: {},
http.StatusBadRequest: {},
},
},
{
Target: "gemini",
URL: "https://generativelanguage.googleapis.com/$discovery/rest?version=v1beta",
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
http.StatusOK: {},
},
},
{
Target: "sora",
URL: "https://sora.chatgpt.com/backend/me",
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
http.StatusUnauthorized: {},
},
},
}
const (
proxyQualityRequestTimeout = 15 * time.Second
proxyQualityResponseHeaderTimeout = 10 * time.Second
proxyQualityMaxBodyBytes = int64(8 * 1024)
proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
)
// adminServiceImpl implements AdminService
type adminServiceImpl struct {
userRepo UserRepository
groupRepo GroupRepository
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
......@@ -312,6 +406,7 @@ func NewAdminService(
userRepo UserRepository,
groupRepo GroupRepository,
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository,
proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
......@@ -325,6 +420,7 @@ func NewAdminService(
userRepo: userRepo,
groupRepo: groupRepo,
accountRepo: accountRepo,
soraAccountRepo: soraAccountRepo,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
......@@ -348,7 +444,7 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
for i := range users {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
continue
}
users[i].GroupRates = rates
......@@ -366,7 +462,7 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
if s.userGroupRateRepo != nil {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", id, err)
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", id, err)
} else {
user.GroupRates = rates
}
......@@ -444,7 +540,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
// 同步用户专属分组倍率
if input.GroupRates != nil && s.userGroupRateRepo != nil {
if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil {
log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err)
logger.LegacyPrintf("service.admin", "failed to sync user group rates: user_id=%d err=%v", user.ID, err)
}
}
......@@ -458,7 +554,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if concurrencyDiff != 0 {
code, err := GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &RedeemCode{
......@@ -471,7 +567,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
now := time.Now()
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
log.Printf("failed to create concurrency adjustment redeem code: %v", err)
logger.LegacyPrintf("service.admin", "failed to create concurrency adjustment redeem code: %v", err)
}
}
......@@ -488,7 +584,7 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
return errors.New("cannot delete admin user")
}
if err := s.userRepo.Delete(ctx, id); err != nil {
log.Printf("delete user failed: user_id=%d err=%v", id, err)
logger.LegacyPrintf("service.admin", "delete user failed: user_id=%d err=%v", id, err)
return err
}
if s.authCacheInvalidator != nil {
......@@ -531,7 +627,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
logger.LegacyPrintf("service.admin", "invalidate user balance cache failed: user_id=%d err=%v", userID, err)
}
}()
}
......@@ -539,7 +635,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
if balanceDiff != 0 {
code, err := GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err)
return user, nil
}
......@@ -555,7 +651,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
log.Printf("failed to create balance adjustment redeem code: %v", err)
logger.LegacyPrintf("service.admin", "failed to create balance adjustment redeem code: %v", err)
}
}
......@@ -639,6 +735,10 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
imagePrice1K := normalizePrice(input.ImagePrice1K)
imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K)
soraImagePrice360 := normalizePrice(input.SoraImagePrice360)
soraImagePrice540 := normalizePrice(input.SoraImagePrice540)
soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest)
soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD)
// 校验降级分组
if input.FallbackGroupID != nil {
......@@ -709,6 +809,10 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K,
SoraImagePrice360: soraImagePrice360,
SoraImagePrice540: soraImagePrice540,
SoraVideoPricePerRequest: soraVideoPrice,
SoraVideoPricePerRequestHD: soraVideoPriceHD,
ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
......@@ -865,6 +969,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.ImagePrice4K != nil {
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
}
if input.SoraImagePrice360 != nil {
group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360)
}
if input.SoraImagePrice540 != nil {
group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540)
}
if input.SoraVideoPricePerRequest != nil {
group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest)
}
if input.SoraVideoPricePerRequestHD != nil {
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
}
// Claude Code 客户端限制
if input.ClaudeCodeOnly != nil {
......@@ -993,7 +1109,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
defer cancel()
for _, userID := range affectedUserIDs {
if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil {
log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
logger.LegacyPrintf("service.admin", "invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
}
}
}()
......@@ -1103,6 +1219,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
return nil, err
}
// 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录
if account.Platform == PlatformSora && s.soraAccountRepo != nil {
soraUpdates := map[string]any{
"access_token": account.GetCredential("access_token"),
"refresh_token": account.GetCredential("refresh_token"),
}
if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil {
// 只记录警告日志,不阻塞账号创建
logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err)
}
}
// 绑定分组
if len(groupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
......@@ -1200,7 +1328,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
}
// 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象)
return s.accountRepo.GetByID(ctx, id)
updated, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
return updated, nil
}
// BulkUpdateAccounts updates multiple accounts in one request.
......@@ -1216,16 +1348,21 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
return result, nil
}
// Preload account platforms for mixed channel risk checks if group bindings are requested.
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
platformByID := map[int64]string{}
if input.GroupIDs != nil && !input.SkipMixedChannelCheck {
if needMixedChannelCheck {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
return nil, err
}
for _, account := range accounts {
if account != nil {
platformByID[account.ID] = account.Platform
if needMixedChannelCheck {
return nil, err
}
} else {
for _, account := range accounts {
if account != nil {
platformByID[account.ID] = account.Platform
}
}
}
}
......@@ -1318,7 +1455,10 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id)
if err := s.accountRepo.Delete(ctx, id); err != nil {
return err
}
return nil
}
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) {
......@@ -1351,7 +1491,11 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err
}
return s.accountRepo.GetByID(ctx, id)
updated, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
return updated, nil
}
// Proxy management implementations
......@@ -1629,6 +1773,270 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
}, nil
}
func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
result := &ProxyQualityCheckResult{
ProxyID: id,
Score: 100,
Grade: "A",
CheckedAt: time.Now().Unix(),
Items: make([]ProxyQualityCheckItem, 0, len(proxyQualityTargets)+1),
}
proxyURL := proxy.URL()
if s.proxyProber == nil {
result.Items = append(result.Items, ProxyQualityCheckItem{
Target: "base_connectivity",
Status: "fail",
Message: "代理探测服务未配置",
})
result.FailedCount++
finalizeProxyQualityResult(result)
s.saveProxyQualitySnapshot(ctx, id, result, nil)
return result, nil
}
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
if err != nil {
result.Items = append(result.Items, ProxyQualityCheckItem{
Target: "base_connectivity",
Status: "fail",
LatencyMs: latencyMs,
Message: err.Error(),
})
result.FailedCount++
finalizeProxyQualityResult(result)
s.saveProxyQualitySnapshot(ctx, id, result, nil)
return result, nil
}
result.ExitIP = exitInfo.IP
result.Country = exitInfo.Country
result.CountryCode = exitInfo.CountryCode
result.BaseLatencyMs = latencyMs
result.Items = append(result.Items, ProxyQualityCheckItem{
Target: "base_connectivity",
Status: "pass",
LatencyMs: latencyMs,
Message: "代理出口连通正常",
})
result.PassedCount++
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: proxyQualityRequestTimeout,
ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout,
ProxyStrict: true,
})
if err != nil {
result.Items = append(result.Items, ProxyQualityCheckItem{
Target: "http_client",
Status: "fail",
Message: fmt.Sprintf("创建检测客户端失败: %v", err),
})
result.FailedCount++
finalizeProxyQualityResult(result)
s.saveProxyQualitySnapshot(ctx, id, result, exitInfo)
return result, nil
}
for _, target := range proxyQualityTargets {
item := runProxyQualityTarget(ctx, client, target)
result.Items = append(result.Items, item)
switch item.Status {
case "pass":
result.PassedCount++
case "warn":
result.WarnCount++
case "challenge":
result.ChallengeCount++
default:
result.FailedCount++
}
}
finalizeProxyQualityResult(result)
s.saveProxyQualitySnapshot(ctx, id, result, exitInfo)
return result, nil
}
func runProxyQualityTarget(ctx context.Context, client *http.Client, target proxyQualityTarget) ProxyQualityCheckItem {
item := ProxyQualityCheckItem{
Target: target.Target,
}
req, err := http.NewRequestWithContext(ctx, target.Method, target.URL, nil)
if err != nil {
item.Status = "fail"
item.Message = fmt.Sprintf("构建请求失败: %v", err)
return item
}
req.Header.Set("Accept", "application/json,text/html,*/*")
req.Header.Set("User-Agent", proxyQualityClientUserAgent)
start := time.Now()
resp, err := client.Do(req)
if err != nil {
item.Status = "fail"
item.LatencyMs = time.Since(start).Milliseconds()
item.Message = fmt.Sprintf("请求失败: %v", err)
return item
}
defer func() { _ = resp.Body.Close() }()
item.LatencyMs = time.Since(start).Milliseconds()
item.HTTPStatus = resp.StatusCode
body, readErr := io.ReadAll(io.LimitReader(resp.Body, proxyQualityMaxBodyBytes+1))
if readErr != nil {
item.Status = "fail"
item.Message = fmt.Sprintf("读取响应失败: %v", readErr)
return item
}
if int64(len(body)) > proxyQualityMaxBodyBytes {
body = body[:proxyQualityMaxBodyBytes]
}
if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
item.Status = "challenge"
item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body)
item.Message = "Sora 命中 Cloudflare challenge"
return item
}
if _, ok := target.AllowedStatuses[resp.StatusCode]; ok {
if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices {
item.Status = "pass"
item.Message = fmt.Sprintf("HTTP %d", resp.StatusCode)
} else {
item.Status = "warn"
item.Message = fmt.Sprintf("HTTP %d(目标可达,但鉴权或方法受限)", resp.StatusCode)
}
return item
}
if resp.StatusCode == http.StatusTooManyRequests {
item.Status = "warn"
item.Message = "目标返回 429,可能存在频控"
return item
}
item.Status = "fail"
item.Message = fmt.Sprintf("非预期状态码: %d", resp.StatusCode)
return item
}
func finalizeProxyQualityResult(result *ProxyQualityCheckResult) {
if result == nil {
return
}
score := 100 - result.WarnCount*10 - result.FailedCount*22 - result.ChallengeCount*30
if score < 0 {
score = 0
}
result.Score = score
result.Grade = proxyQualityGrade(score)
result.Summary = fmt.Sprintf(
"通过 %d 项,告警 %d 项,失败 %d 项,挑战 %d 项",
result.PassedCount,
result.WarnCount,
result.FailedCount,
result.ChallengeCount,
)
}
func proxyQualityGrade(score int) string {
switch {
case score >= 90:
return "A"
case score >= 75:
return "B"
case score >= 60:
return "C"
case score >= 40:
return "D"
default:
return "F"
}
}
func proxyQualityOverallStatus(result *ProxyQualityCheckResult) string {
if result == nil {
return ""
}
if result.ChallengeCount > 0 {
return "challenge"
}
if result.FailedCount > 0 {
return "failed"
}
if result.WarnCount > 0 {
return "warn"
}
if result.PassedCount > 0 {
return "healthy"
}
return "failed"
}
func proxyQualityFirstCFRay(result *ProxyQualityCheckResult) string {
if result == nil {
return ""
}
for _, item := range result.Items {
if item.CFRay != "" {
return item.CFRay
}
}
return ""
}
func proxyQualityBaseConnectivityPass(result *ProxyQualityCheckResult) bool {
if result == nil {
return false
}
for _, item := range result.Items {
if item.Target == "base_connectivity" {
return item.Status == "pass"
}
}
return false
}
func (s *adminServiceImpl) saveProxyQualitySnapshot(ctx context.Context, proxyID int64, result *ProxyQualityCheckResult, exitInfo *ProxyExitInfo) {
if result == nil {
return
}
score := result.Score
checkedAt := result.CheckedAt
info := &ProxyLatencyInfo{
Success: proxyQualityBaseConnectivityPass(result),
Message: result.Summary,
QualityStatus: proxyQualityOverallStatus(result),
QualityScore: &score,
QualityGrade: result.Grade,
QualitySummary: result.Summary,
QualityCheckedAt: &checkedAt,
QualityCFRay: proxyQualityFirstCFRay(result),
UpdatedAt: time.Now(),
}
if result.BaseLatencyMs > 0 {
latency := result.BaseLatencyMs
info.LatencyMs = &latency
}
if exitInfo != nil {
info.IPAddress = exitInfo.IP
info.Country = exitInfo.Country
info.CountryCode = exitInfo.CountryCode
info.Region = exitInfo.Region
info.City = exitInfo.City
}
s.saveProxyLatency(ctx, proxyID, info)
}
func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
if s.proxyProber == nil || proxy == nil {
return
......@@ -1718,7 +2126,7 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro
latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids)
if err != nil {
log.Printf("Warning: load proxy latency cache failed: %v", err)
logger.LegacyPrintf("service.admin", "Warning: load proxy latency cache failed: %v", err)
return
}
......@@ -1739,6 +2147,11 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro
proxies[i].CountryCode = info.CountryCode
proxies[i].Region = info.Region
proxies[i].City = info.City
proxies[i].QualityStatus = info.QualityStatus
proxies[i].QualityScore = info.QualityScore
proxies[i].QualityGrade = info.QualityGrade
proxies[i].QualitySummary = info.QualitySummary
proxies[i].QualityChecked = info.QualityCheckedAt
}
}
......@@ -1746,8 +2159,28 @@ func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64,
if s.proxyLatencyCache == nil || info == nil {
return
}
if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil {
log.Printf("Warning: store proxy latency cache failed: %v", err)
merged := *info
if latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, []int64{proxyID}); err == nil {
if existing := latencies[proxyID]; existing != nil {
if merged.QualityCheckedAt == nil &&
merged.QualityScore == nil &&
merged.QualityGrade == "" &&
merged.QualityStatus == "" &&
merged.QualitySummary == "" &&
merged.QualityCFRay == "" {
merged.QualityStatus = existing.QualityStatus
merged.QualityScore = existing.QualityScore
merged.QualityGrade = existing.QualityGrade
merged.QualitySummary = existing.QualitySummary
merged.QualityCheckedAt = existing.QualityCheckedAt
merged.QualityCFRay = existing.QualityCFRay
}
}
}
if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, &merged); err != nil {
logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err)
}
}
......
......@@ -15,6 +15,13 @@ type accountRepoStubForBulkUpdate struct {
bulkUpdateErr error
bulkUpdateIDs []int64
bindGroupErrByID map[int64]error
getByIDsAccounts []*Account
getByIDsErr error
getByIDsCalled bool
getByIDsIDs []int64
getByIDAccounts map[int64]*Account
getByIDErrByID map[int64]error
getByIDCalled []int64
}
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
......@@ -32,6 +39,26 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i
return nil
}
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
s.getByIDsCalled = true
s.getByIDsIDs = append([]int64{}, ids...)
if s.getByIDsErr != nil {
return nil, s.getByIDsErr
}
return s.getByIDsAccounts, nil
}
func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Account, error) {
s.getByIDCalled = append(s.getByIDCalled, id)
if err, ok := s.getByIDErrByID[id]; ok {
return nil, err
}
if account, ok := s.getByIDAccounts[id]; ok {
return account, nil
}
return nil, errors.New("account not found")
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}
......
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) {
result := &ProxyQualityCheckResult{
PassedCount: 2,
WarnCount: 1,
FailedCount: 1,
ChallengeCount: 1,
}
finalizeProxyQualityResult(result)
require.Equal(t, 38, result.Score)
require.Equal(t, "F", result.Grade)
require.Contains(t, result.Summary, "通过 2 项")
require.Contains(t, result.Summary, "告警 1 项")
require.Contains(t, result.Summary, "失败 1 项")
require.Contains(t, result.Summary, "挑战 1 项")
}
func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.Header().Set("cf-ray", "test-ray-123")
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte("<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>"))
}))
defer server.Close()
target := proxyQualityTarget{
Target: "sora",
URL: server.URL,
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
http.StatusUnauthorized: {},
},
}
item := runProxyQualityTarget(context.Background(), server.Client(), target)
require.Equal(t, "challenge", item.Status)
require.Equal(t, http.StatusForbidden, item.HTTPStatus)
require.Equal(t, "test-ray-123", item.CFRay)
}
func TestRunProxyQualityTarget_AllowedStatusPass(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"models":[]}`))
}))
defer server.Close()
target := proxyQualityTarget{
Target: "gemini",
URL: server.URL,
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
http.StatusOK: {},
},
}
item := runProxyQualityTarget(context.Background(), server.Client(), target)
require.Equal(t, "pass", item.Status)
require.Equal(t, http.StatusOK, item.HTTPStatus)
}
func TestRunProxyQualityTarget_AllowedStatusWarnForUnauthorized(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
}))
defer server.Close()
target := proxyQualityTarget{
Target: "openai",
URL: server.URL,
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
http.StatusUnauthorized: {},
},
}
item := runProxyQualityTarget(context.Background(), server.Client(), target)
require.Equal(t, "warn", item.Status)
require.Equal(t, http.StatusUnauthorized, item.HTTPStatus)
require.Contains(t, item.Message, "目标可达")
}
......@@ -22,8 +22,10 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/tidwall/gjson"
)
const (
......@@ -184,7 +186,7 @@ type smartRetryResult struct {
func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParams, resp *http.Response, respBody []byte, baseURL string, urlIdx int, availableURLs []string) *smartRetryResult {
// "Resource has been exhausted" 是 URL 级别限流,切换 URL(仅 429)
if resp.StatusCode == http.StatusTooManyRequests && isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 {
log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
logger.LegacyPrintf("service.antigravity_gateway", "%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
return &smartRetryResult{action: smartRetryActionContinueURL}
}
......@@ -204,13 +206,13 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
if rateLimitDuration <= 0 {
rateLimitDuration = antigravityDefaultRateLimitDuration
}
log.Printf("%s status=%d oauth_long_delay model=%s account=%d upstream_retry_delay=%v body=%s (model rate limit, switch account)",
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d oauth_long_delay model=%s account=%d upstream_retry_delay=%v body=%s (model rate limit, switch account)",
p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration, truncateForLog(respBody, 200))
resetAt := time.Now().Add(rateLimitDuration)
if !setModelRateLimitByModelName(p.ctx, p.accountRepo, p.account.ID, modelName, p.prefix, resp.StatusCode, resetAt, false) {
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
log.Printf("%s status=%d rate_limited account=%d (no model mapping)", p.prefix, resp.StatusCode, p.account.ID)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d rate_limited account=%d (no model mapping)", p.prefix, resp.StatusCode, p.account.ID)
} else {
s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt)
}
......@@ -273,7 +275,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
// 智能重试:创建新请求
retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
if err != nil {
log.Printf("%s status=smart_retry_request_build_failed error=%v", p.prefix, err)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=smart_retry_request_build_failed error=%v", p.prefix, err)
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
......@@ -356,7 +358,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
// 单账号 503 退避重试模式:智能重试耗尽后不设限流、不切换账号,
// 直接返回 503 让 Handler 层的单账号退避循环做最终处理。
if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) {
log.Printf("%s status=%d smart_retry_exhausted_single_account attempts=%d model=%s account=%d body=%s (return 503 directly)",
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d smart_retry_exhausted_single_account attempts=%d model=%s account=%d body=%s (return 503 directly)",
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200))
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
......@@ -374,9 +376,9 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
resetAt := time.Now().Add(rateLimitDuration)
if p.accountRepo != nil && modelName != "" {
if err := p.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, modelName, resetAt); err != nil {
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err)
} else {
log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v",
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v",
p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration)
s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt)
}
......@@ -431,7 +433,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
waitDuration = antigravitySmartRetryMinWait
}
log.Printf("%s status=%d single_account_503_retry_in_place model=%s account=%d upstream_retry_delay=%v (retrying in-place instead of rate-limiting)",
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_in_place model=%s account=%d upstream_retry_delay=%v (retrying in-place instead of rate-limiting)",
p.prefix, resp.StatusCode, modelName, p.account.ID, waitDuration)
var lastRetryResp *http.Response
......@@ -443,21 +445,21 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
if totalWaited+waitDuration > antigravitySingleAccountSmartRetryTotalMaxWait {
remaining := antigravitySingleAccountSmartRetryTotalMaxWait - totalWaited
if remaining <= 0 {
log.Printf("%s single_account_503_retry: total_wait_exceeded total=%v max=%v, giving up",
logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: total_wait_exceeded total=%v max=%v, giving up",
p.prefix, totalWaited, antigravitySingleAccountSmartRetryTotalMaxWait)
break
}
waitDuration = remaining
}
log.Printf("%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d",
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d",
p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID)
timer := time.NewTimer(waitDuration)
select {
case <-p.ctx.Done():
timer.Stop()
log.Printf("%s status=context_canceled_during_single_account_retry", p.prefix)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_single_account_retry", p.prefix)
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
case <-timer.C:
}
......@@ -466,13 +468,13 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
// 创建新请求
retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
if err != nil {
log.Printf("%s single_account_503_retry: request_build_failed error=%v", p.prefix, err)
logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: request_build_failed error=%v", p.prefix, err)
break
}
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
log.Printf("%s status=%d single_account_503_retry_success attempt=%d/%d total_waited=%v",
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_success attempt=%d/%d total_waited=%v",
p.prefix, retryResp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited)
// 关闭之前的响应
if lastRetryResp != nil {
......@@ -483,7 +485,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
// 网络错误时继续重试
if retryErr != nil || retryResp == nil {
log.Printf("%s single_account_503_retry: network_error attempt=%d/%d error=%v",
logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: network_error attempt=%d/%d error=%v",
p.prefix, attempt, antigravitySingleAccountSmartRetryMaxAttempts, retryErr)
continue
}
......@@ -517,7 +519,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
if retryBody == nil {
retryBody = respBody
}
log.Printf("%s status=%d single_account_503_retry_exhausted attempts=%d total_waited=%v model=%s account=%d body=%s (return 503 directly)",
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_exhausted attempts=%d total_waited=%v model=%s account=%d body=%s (return 503 directly)",
p.prefix, resp.StatusCode, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited, modelName, p.account.ID, truncateForLog(retryBody, 200))
return &smartRetryResult{
......@@ -540,10 +542,10 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
// 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace
// 会在 Service 层原地等待+重试,不需要在预检查这里等。
if isSingleAccountRetry(p.ctx) {
log.Printf("%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)",
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)",
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
} else {
log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
return nil, &AntigravityAccountSwitchError{
OriginalAccountID: p.account.ID,
......@@ -580,7 +582,7 @@ urlFallbackLoop:
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
select {
case <-p.ctx.Done():
log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
return nil, p.ctx.Err()
default:
}
......@@ -610,18 +612,18 @@ urlFallbackLoop:
Message: safeErr,
})
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
logger.LegacyPrintf("service.antigravity_gateway", "%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
continue urlFallbackLoop
}
if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
}
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=request_failed retries_exhausted error=%v", p.prefix, err)
setOpsUpstreamError(p.c, 0, safeErr, "")
return nil, fmt.Errorf("upstream request failed after retries: %w", err)
}
......@@ -678,9 +680,9 @@ urlFallbackLoop:
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),
})
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
}
continue
......@@ -688,7 +690,7 @@ urlFallbackLoop:
// 重试用尽,标记账户限流
p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200))
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200))
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
......@@ -712,9 +714,9 @@ urlFallbackLoop:
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),
})
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
}
continue
......@@ -1012,14 +1014,14 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
}
// 调试日志:Test 请求信息
log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
lastErr = fmt.Errorf("请求失败: %w", err)
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
return nil, lastErr
......@@ -1034,7 +1036,7 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
// 检查是否需要 URL 降级
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
......@@ -1243,16 +1245,12 @@ func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model strin
}
// unwrapV1InternalResponse 解包 v1internal 响应
// 使用 gjson 零拷贝提取 response 字段,避免 Unmarshal+Marshal 双重开销
func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) {
var outer map[string]any
if err := json.Unmarshal(body, &outer); err != nil {
return nil, err
}
if resp, ok := outer["response"]; ok {
return json.Marshal(resp)
result := gjson.GetBytes(body, "response")
if result.Exists() {
return []byte(result.Raw), nil
}
return body, nil
}
......@@ -1420,7 +1418,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
continue
}
log.Printf("Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name)
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name)
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
if txErr != nil {
......@@ -1453,7 +1451,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Kind: "signature_retry_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
})
log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
continue
}
......@@ -1472,7 +1470,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if retryResp.Request != nil && retryResp.Request.URL != nil {
retryBaseURL = retryResp.Request.URL.Scheme + "://" + retryResp.Request.URL.Host
}
log.Printf("%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200))
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200))
}
kind := "signature_retry"
if strings.TrimSpace(stage.name) != "" {
......@@ -1525,7 +1523,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
upstreamDetail := s.getUpstreamErrorDetail(respBody)
logBody, maxBytes := s.getLogConfig()
if logBody {
log.Printf("%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes))
logger.LegacyPrintf("service.antigravity_gateway", "%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes))
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
......@@ -1600,7 +1598,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 客户端要求流式,直接透传转换
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
if err != nil {
log.Printf("%s status=stream_error error=%v", prefix, err)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
return nil, err
}
usage = streamRes.usage
......@@ -1610,7 +1608,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 客户端要求非流式,收集流式响应后转换返回
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
if err != nil {
log.Printf("%s status=stream_collect_error error=%v", prefix, err)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
return nil, err
}
usage = streamRes.usage
......@@ -1963,7 +1961,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(time.Now()),
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
default:
......@@ -2002,9 +2000,9 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 清理 Schema
if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil {
injectedBody = cleanedBody
log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
} else {
log.Printf("[Antigravity] Failed to clean schema: %v", err)
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Failed to clean schema: %v", err)
}
// 包装请求
......@@ -2066,7 +2064,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
isModelNotFoundError(resp.StatusCode, respBody) {
fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity)
if fallbackModel != "" && fallbackModel != mappedModel {
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody)
if err == nil {
......@@ -2149,7 +2147,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
Message: upstreamMsg,
Detail: upstreamDetail,
})
log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500))
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500))
c.Data(resp.StatusCode, contentType, unwrappedForOps)
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
}
......@@ -2168,7 +2166,7 @@ handleSuccess:
// 客户端要求流式,直接透传
streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime)
if err != nil {
log.Printf("%s status=stream_error error=%v", prefix, err)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
return nil, err
}
usage = streamRes.usage
......@@ -2178,7 +2176,7 @@ handleSuccess:
// 客户端要求非流式,收集流式响应后返回
streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime)
if err != nil {
log.Printf("%s status=stream_collect_error error=%v", prefix, err)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
return nil, err
}
usage = streamRes.usage
......@@ -2297,13 +2295,13 @@ func setModelRateLimitByModelName(ctx context.Context, repo AccountRepository, a
}
// 直接使用官方模型 ID 作为 key,不再转换为 scope
if err := repo.SetModelRateLimit(ctx, accountID, modelName, resetAt); err != nil {
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err)
return false
}
if afterSmartRetry {
log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second))
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second))
} else {
log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second))
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second))
}
return true
}
......@@ -2411,7 +2409,7 @@ func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo {
// 例如: "0.5s", "10s", "4m50s", "1h30m", "200ms" 等
dur, err := time.ParseDuration(delay)
if err != nil {
log.Printf("[Antigravity] failed to parse retryDelay: %s error=%v", delay, err)
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] failed to parse retryDelay: %s error=%v", delay, err)
continue
}
retryDelay = dur
......@@ -2532,7 +2530,7 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit
// RATE_LIMIT_EXCEEDED: < antigravityRateLimitThreshold: 等待后重试
if info.RetryDelay < antigravityRateLimitThreshold {
log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v",
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_wait model=%s wait=%v",
p.prefix, p.statusCode, info.ModelName, info.RetryDelay)
return &handleModelRateLimitResult{
Handled: true,
......@@ -2557,12 +2555,12 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit
// setModelRateLimitAndClearSession 设置模型限流并清除粘性会话
func (s *AntigravityGatewayService) setModelRateLimitAndClearSession(p *handleModelRateLimitParams, info *antigravitySmartRetryInfo) {
resetAt := time.Now().Add(info.RetryDelay)
log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v",
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited model=%s account=%d reset_in=%v",
p.prefix, p.statusCode, info.ModelName, p.account.ID, info.RetryDelay)
// 设置模型限流状态(数据库)
if err := s.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, info.ModelName, resetAt); err != nil {
log.Printf("%s model_rate_limit_failed model=%s error=%v", p.prefix, info.ModelName, err)
logger.LegacyPrintf("service.antigravity_gateway", "%s model_rate_limit_failed model=%s error=%v", p.prefix, info.ModelName, err)
}
// 立即更新 Redis 快照中账号的限流状态,避免并发请求重复选中
......@@ -2598,7 +2596,7 @@ func (s *AntigravityGatewayService) updateAccountModelRateLimitInCache(ctx conte
// 更新 Redis 快照
if err := s.schedulerSnapshot.UpdateAccountInCache(ctx, account); err != nil {
log.Printf("[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err)
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err)
}
}
......@@ -2637,7 +2635,7 @@ func (s *AntigravityGatewayService) handleUpstreamError(
// 429:尝试解析模型级限流,解析失败时兜底为账号级限流
if statusCode == 429 {
if logBody, maxBytes := s.getLogConfig(); logBody {
log.Printf("[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes))
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes))
}
resetAt := ParseGeminiRateLimitResetTime(body)
......@@ -2648,9 +2646,9 @@ func (s *AntigravityGatewayService) handleUpstreamError(
if modelKey != "" {
ra := s.resolveResetTime(resetAt, defaultDur)
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil {
log.Printf("%s status=429 model_rate_limit_set_failed model=%s error=%v", prefix, modelKey, err)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 model_rate_limit_set_failed model=%s error=%v", prefix, modelKey, err)
} else {
log.Printf("%s status=429 model_rate_limited model=%s account=%d reset_at=%v reset_in=%v",
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 model_rate_limited model=%s account=%d reset_at=%v reset_in=%v",
prefix, modelKey, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second))
s.updateAccountModelRateLimitInCache(ctx, account, modelKey, ra)
}
......@@ -2659,10 +2657,10 @@ func (s *AntigravityGatewayService) handleUpstreamError(
// 无法解析模型 key,兜底为账号级限流
ra := s.resolveResetTime(resetAt, defaultDur)
log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v (fallback)",
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limited account=%d reset_at=%v reset_in=%v (fallback)",
prefix, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second))
if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
}
return nil
}
......@@ -2672,7 +2670,7 @@ func (s *AntigravityGatewayService) handleUpstreamError(
}
shouldDisable := s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
if shouldDisable {
log.Printf("%s status=%d marked_error", prefix, statusCode)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d marked_error", prefix, statusCode)
}
return nil
}
......@@ -2746,18 +2744,18 @@ func (cw *antigravityClientWriter) Disconnected() bool { return cw.disconnected
func (cw *antigravityClientWriter) markDisconnected() {
cw.disconnected = true
log.Printf("Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix)
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix)
}
// handleStreamReadError 处理上游读取错误的通用逻辑。
// 返回 (clientDisconnect, handled):handled=true 表示错误已处理,调用方应返回已收集的 usage。
func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming (%s), returning collected usage", prefix)
logger.LegacyPrintf("service.antigravity_gateway", "Context canceled during streaming (%s), returning collected usage", prefix)
return true, true
}
if clientDisconnected {
log.Printf("Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err)
logger.LegacyPrintf("service.antigravity_gateway", "Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err)
return true, true
}
return false, false
......@@ -2786,7 +2784,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
......@@ -2807,7 +2806,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -2818,7 +2818,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
......@@ -2860,7 +2860,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
......@@ -2884,19 +2884,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
}
// 解析 usage
if u := extractGeminiUsage(inner); u != nil {
usage = u
}
var parsed map[string]any
if json.Unmarshal(inner, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
// Check for MALFORMED_FUNCTION_CALL
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
if content, ok := cand["content"]; ok {
if b, err := json.Marshal(content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Malformed content: %s", string(b))
}
}
}
......@@ -2921,10 +2921,10 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
continue
}
if cw.Disconnected() {
log.Printf("Upstream timeout after client disconnect (antigravity gemini), returning collected usage")
logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity gemini), returning collected usage")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
log.Printf("Stream data interval timeout (antigravity)")
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
......@@ -2939,7 +2939,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
......@@ -2967,7 +2968,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -2978,7 +2980,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
......@@ -3005,7 +3007,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err)
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err)
}
return nil, ev.err
}
......@@ -3042,7 +3044,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
last = parsed
// 提取 usage
if u := extractGeminiUsage(parsed); u != nil {
if u := extractGeminiUsage(inner); u != nil {
usage = u
}
......@@ -3050,10 +3052,10 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
if content, ok := cand["content"]; ok {
if b, err := json.Marshal(content); err == nil {
log.Printf("[Antigravity] Malformed content: %s", string(b))
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Malformed content: %s", string(b))
}
}
}
......@@ -3080,7 +3082,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity non-stream)")
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity non-stream)")
return nil, fmt.Errorf("stream data interval timeout")
}
}
......@@ -3091,7 +3093,7 @@ returnResponse:
// 处理空响应情况 — 触发同账号重试 + failover 切换账号
if last == nil && lastWithParts == nil {
log.Printf("[antigravity-Forward] warning: empty stream response (gemini non-stream), triggering failover")
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] warning: empty stream response (gemini non-stream), triggering failover")
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
......@@ -3311,7 +3313,7 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
// 记录上游错误详情便于排障(可选:由配置控制;不回显到客户端)
if logBody {
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
}
// 检查错误透传规则
......@@ -3402,7 +3404,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
var firstTokenMs *int
var last map[string]any
......@@ -3428,7 +3431,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -3439,7 +3443,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
......@@ -3466,7 +3470,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err)
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err)
}
return nil, ev.err
}
......@@ -3515,7 +3519,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity claude non-stream)")
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity claude non-stream)")
return nil, fmt.Errorf("stream data interval timeout")
}
}
......@@ -3526,7 +3530,7 @@ returnResponse:
// 处理空响应情况 — 触发同账号重试 + failover 切换账号
if last == nil && lastWithParts == nil {
log.Printf("[antigravity-Forward] warning: empty stream response (claude non-stream), triggering failover")
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] warning: empty stream response (claude non-stream), triggering failover")
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
......@@ -3548,7 +3552,7 @@ returnResponse:
// 转换 Gemini 响应为 Claude 格式
claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(geminiBody, originalModel)
if err != nil {
log.Printf("[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody))
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody))
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
......@@ -3586,7 +3590,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
......@@ -3618,7 +3623,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
......@@ -3629,7 +3635,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
streamInterval := time.Duration(0)
......@@ -3681,7 +3687,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
}
......@@ -3705,10 +3711,10 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
continue
}
if cw.Disconnected() {
log.Printf("Upstream timeout after client disconnect (antigravity claude), returning collected usage")
logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity claude), returning collected usage")
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
log.Printf("Stream data interval timeout (antigravity)")
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
......@@ -3908,7 +3914,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
log.Printf("%s upstream request failed: %v", prefix, err)
logger.LegacyPrintf("service.antigravity_gateway", "%s upstream request failed: %v", prefix, err)
return nil, fmt.Errorf("upstream request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
......@@ -3966,7 +3972,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
// 构建计费结果
duration := time.Since(startTime)
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds())
return &ForwardResult{
Model: billingModel,
......@@ -4052,7 +4058,7 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity upstream"); handled {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}
}
log.Printf("Stream read error (antigravity upstream): %v", ev.err)
logger.LegacyPrintf("service.antigravity_gateway", "Stream read error (antigravity upstream): %v", ev.err)
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
}
......@@ -4076,10 +4082,10 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
continue
}
if cw.Disconnected() {
log.Printf("Upstream timeout after client disconnect (antigravity upstream), returning collected usage")
logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity upstream), returning collected usage")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}
}
log.Printf("Stream data interval timeout (antigravity upstream)")
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity upstream)")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
}
}
......
......@@ -9,6 +9,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
......@@ -417,6 +418,44 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// TestStreamUpstreamResponse_UsageAndFirstToken
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `data: {"usage":{"input_tokens":1,"output_tokens":2,"cache_read_input_tokens":3,"cache_creation_input_tokens":4}}`)
fmt.Fprintln(pw, `data: {"usage":{"output_tokens":5}}`)
}()
start := time.Now().Add(-10 * time.Millisecond)
result := svc.streamUpstreamResponse(c, resp, start)
_ = pr.Close()
require.NotNil(t, result)
require.NotNil(t, result.usage)
require.Equal(t, 1, result.usage.InputTokens)
// 第二次事件覆盖 output_tokens
require.Equal(t, 5, result.usage.OutputTokens)
require.Equal(t, 3, result.usage.CacheReadInputTokens)
require.Equal(t, 4, result.usage.CacheCreationInputTokens)
require.NotNil(t, result.firstTokenMs)
// 确保有透传输出
require.Contains(t, rec.Body.String(), "data:")
}
// --- 流式 happy path 测试 ---
// TestStreamUpstreamResponse_NormalComplete
......@@ -920,3 +959,144 @@ func TestAntigravityClientWriter(t *testing.T) {
require.True(t, cw.Disconnected())
})
}
// TestUnwrapV1InternalResponse 测试 unwrapV1InternalResponse 的各种输入场景
func TestUnwrapV1InternalResponse(t *testing.T) {
svc := &AntigravityGatewayService{}
// 构造 >50KB 的大型 JSON
largePadding := strings.Repeat("x", 50*1024)
largeInput := []byte(fmt.Sprintf(`{"response":{"id":"big","pad":"%s"}}`, largePadding))
largeExpected := fmt.Sprintf(`{"id":"big","pad":"%s"}`, largePadding)
tests := []struct {
name string
input []byte
expected string
wantErr bool
}{
{
name: "正常 response 包装",
input: []byte(`{"response":{"id":"123","content":"hello"}}`),
expected: `{"id":"123","content":"hello"}`,
},
{
name: "无 response 透传",
input: []byte(`{"id":"456"}`),
expected: `{"id":"456"}`,
},
{
name: "空 JSON",
input: []byte(`{}`),
expected: `{}`,
},
{
name: "response 为 null",
input: []byte(`{"response":null}`),
expected: `null`,
},
{
name: "response 为基础类型 string",
input: []byte(`{"response":"hello"}`),
expected: `"hello"`,
},
{
name: "非法 JSON",
input: []byte(`not json`),
expected: `not json`,
},
{
name: "嵌套 response 只解一层",
input: []byte(`{"response":{"response":{"inner":true}}}`),
expected: `{"response":{"inner":true}}`,
},
{
name: "大型 JSON >50KB",
input: largeInput,
expected: largeExpected,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := svc.unwrapV1InternalResponse(tt.input)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tt.expected, strings.TrimSpace(string(got)))
})
}
}
// --- unwrapV1InternalResponse benchmark 对照组 ---
// unwrapV1InternalResponseOld 旧实现:Unmarshal+Marshal 双重开销(仅用于 benchmark 对照)
func unwrapV1InternalResponseOld(body []byte) ([]byte, error) {
var outer map[string]any
if err := json.Unmarshal(body, &outer); err != nil {
return nil, err
}
if resp, ok := outer["response"]; ok {
return json.Marshal(resp)
}
return body, nil
}
func BenchmarkUnwrapV1Internal_Old_Small(b *testing.B) {
body := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"hello world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}}`)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = unwrapV1InternalResponseOld(body)
}
}
func BenchmarkUnwrapV1Internal_New_Small(b *testing.B) {
body := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"hello world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}}`)
svc := &AntigravityGatewayService{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = svc.unwrapV1InternalResponse(body)
}
}
func BenchmarkUnwrapV1Internal_Old_Large(b *testing.B) {
body := generateLargeUnwrapJSON(10 * 1024) // ~10KB
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = unwrapV1InternalResponseOld(body)
}
}
func BenchmarkUnwrapV1Internal_New_Large(b *testing.B) {
body := generateLargeUnwrapJSON(10 * 1024) // ~10KB
svc := &AntigravityGatewayService{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = svc.unwrapV1InternalResponse(body)
}
}
// generateLargeUnwrapJSON 生成指定最小大小的包含 response 包装的 JSON
func generateLargeUnwrapJSON(minSize int) []byte {
parts := make([]map[string]string, 0)
current := 0
for current < minSize {
text := fmt.Sprintf("这是第 %d 段内容,用于填充 JSON 到目标大小。", len(parts)+1)
parts = append(parts, map[string]string{"text": text})
current += len(text) + 20 // 估算 JSON 编码开销
}
inner := map[string]any{
"candidates": []map[string]any{
{"content": map[string]any{"parts": parts}},
},
"usageMetadata": map[string]any{
"promptTokenCount": 100,
"candidatesTokenCount": 50,
},
}
outer := map[string]any{"response": inner}
b, _ := json.Marshal(outer)
return b
}
......@@ -15,6 +15,12 @@ import (
"github.com/stretchr/testify/require"
)
// 编译期接口断言
var _ HTTPUpstream = (*stubAntigravityUpstream)(nil)
var _ HTTPUpstream = (*recordingOKUpstream)(nil)
var _ AccountRepository = (*stubAntigravityAccountRepo)(nil)
var _ SchedulerCache = (*stubSchedulerCache)(nil)
type stubAntigravityUpstream struct {
firstBase string
secondBase string
......
......@@ -19,6 +19,7 @@ type APIKey struct {
Status string
IPWhitelist []string
IPBlacklist []string
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
User *User
......
......@@ -44,6 +44,10 @@ type APIKeyAuthGroupSnapshot struct {
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
......
......@@ -6,8 +6,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"math/rand"
"sync"
"math/rand/v2"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
......@@ -23,12 +22,6 @@ type apiKeyAuthCacheConfig struct {
singleflight bool
}
var (
jitterRandMu sync.Mutex
// 认证缓存抖动使用独立随机源,避免全局 Seed
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
)
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
if cfg == nil {
return apiKeyAuthCacheConfig{}
......@@ -56,6 +49,8 @@ func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
return c.negativeTTL > 0
}
// jitterTTL 为缓存 TTL 添加抖动,避免多个请求在同一时刻同时过期触发集中回源。
// 这里直接使用 rand/v2 的顶层函数:并发安全,无需全局互斥锁。
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
if ttl <= 0 {
return ttl
......@@ -68,9 +63,7 @@ func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
percent = 100
}
delta := float64(percent) / 100
jitterRandMu.Lock()
randVal := jitterRand.Float64()
jitterRandMu.Unlock()
randVal := rand.Float64()
factor := 1 - delta + randVal*(2*delta)
if factor <= 0 {
return ttl
......@@ -238,6 +231,10 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K,
SoraImagePrice360: apiKey.Group.SoraImagePrice360,
SoraImagePrice540: apiKey.Group.SoraImagePrice540,
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
......@@ -288,6 +285,10 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K,
SoraImagePrice360: snapshot.Group.SoraImagePrice360,
SoraImagePrice540: snapshot.Group.SoraImagePrice540,
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
......
......@@ -5,6 +5,8 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"strconv"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
......@@ -32,6 +34,9 @@ var (
const (
apiKeyMaxErrorsPerHour = 20
apiKeyLastUsedMinTouch = 30 * time.Second
// DB 写失败后的短退避,避免请求路径持续同步重试造成写风暴与高延迟。
apiKeyLastUsedFailBackoff = 5 * time.Second
)
type APIKeyRepository interface {
......@@ -58,6 +63,7 @@ type APIKeyRepository interface {
// Quota methods
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error)
UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error
}
// APIKeyCache defines cache operations for API key service
......@@ -125,6 +131,8 @@ type APIKeyService struct {
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time)
lastUsedTouchSF singleflight.Group
}
// NewAPIKeyService 创建API Key服务实例
......@@ -527,6 +535,7 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete api key: %w", err)
}
s.lastUsedTouchL1.Delete(id)
return nil
}
......@@ -558,6 +567,38 @@ func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *
return apiKey, user, nil
}
// TouchLastUsed 通过防抖更新 api_keys.last_used_at,减少高频写放大。
// 该操作为尽力而为,不应阻塞主请求链路。
func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error {
if keyID <= 0 {
return nil
}
now := time.Now()
if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) {
return nil
}
}
_, err, _ := s.lastUsedTouchSF.Do(strconv.FormatInt(keyID, 10), func() (any, error) {
latest := time.Now()
if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) {
return nil, nil
}
}
if err := s.apiKeyRepo.UpdateLastUsed(ctx, keyID, latest); err != nil {
s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedFailBackoff))
return nil, fmt.Errorf("touch api key last used: %w", err)
}
s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedMinTouch))
return nil, nil
})
return err
}
// IncrementUsage 增加API Key使用次数(可选:用于统计)
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器
......
......@@ -103,6 +103,10 @@ func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount
panic("unexpected IncrementQuotaUsed call")
}
func (s *authRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
panic("unexpected UpdateLastUsed call")
}
type authCacheStub struct {
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
setAuthKeys []string
......
......@@ -24,10 +24,13 @@ import (
// - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
type apiKeyRepoStub struct {
apiKey *APIKey // GetKeyAndOwnerID 的返回值
getByIDErr error // GetKeyAndOwnerID 的错误返回值
deleteErr error // Delete 的错误返回值
deletedIDs []int64 // 记录已删除的 API Key ID 列表
apiKey *APIKey // GetKeyAndOwnerID 的返回值
getByIDErr error // GetKeyAndOwnerID 的错误返回值
deleteErr error // Delete 的错误返回值
deletedIDs []int64 // 记录已删除的 API Key ID 列表
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
touchedIDs []int64
touchedUsedAts []time.Time
}
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
......@@ -122,6 +125,15 @@ func (s *apiKeyRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amoun
panic("unexpected IncrementQuotaUsed call")
}
func (s *apiKeyRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
s.touchedIDs = append(s.touchedIDs, id)
s.touchedUsedAts = append(s.touchedUsedAts, usedAt)
if s.updateLastUsed != nil {
return s.updateLastUsed(ctx, id, usedAt)
}
return nil
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
......@@ -214,12 +226,15 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
svc.lastUsedTouchL1.Store(int64(42), time.Now())
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
require.NoError(t, err)
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
_, exists := svc.lastUsedTouchL1.Load(int64(42))
require.False(t, exists, "delete should clear touch debounce cache")
}
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment