Commit 987589ea authored by yangjianbo's avatar yangjianbo
Browse files

Merge branch 'test' into release

parents 372e04f6 03f69dd3
package service
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// queuedHTTPUpstream is a test double for the HTTP upstream used by the
// account test service. It serves pre-built responses in FIFO order and
// records what DoWithTLS was called with.
type queuedHTTPUpstream struct {
	responses []*http.Response // canned responses, handed out front to back
	requests  []*http.Request  // every request passed to DoWithTLS, in call order
	tlsFlags  []bool           // enableTLSFingerprint value of each DoWithTLS call
}

// Do always fails: the tests in this file expect only DoWithTLS to be used.
func (q *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
	return nil, fmt.Errorf("unexpected Do call")
}

// DoWithTLS records the request and the TLS fingerprint flag, then pops and
// returns the next queued response. The call is recorded even when the queue
// is already empty, in which case an error is returned instead of a response.
func (q *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) {
	q.requests = append(q.requests, req)
	q.tlsFlags = append(q.tlsFlags, enableTLSFingerprint)

	if len(q.responses) > 0 {
		next := q.responses[0]
		q.responses = q.responses[1:]
		return next, nil
	}
	return nil, fmt.Errorf("no mocked response")
}
// newJSONResponse builds a minimal *http.Response carrying the given status
// code, an empty (non-nil) header map and the given string as its body.
// No Content-Type or other header is set.
func newJSONResponse(status int, body string) *http.Response {
	resp := new(http.Response)
	resp.StatusCode = status
	resp.Header = http.Header{}
	resp.Body = io.NopCloser(strings.NewReader(body))
	return resp
}
// newJSONResponseWithHeader is newJSONResponse plus a single response header
// (key/value) set on the result.
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
	withHeader := newJSONResponse(status, body)
	withHeader.Header.Set(key, value)
	return withHeader
}
// newSoraTestContext switches gin to test mode and returns a gin test context
// bound to a fresh response recorder. The context carries a body-less POST
// request to /api/v1/admin/accounts/1/test.
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)

	return ginCtx, recorder
}
// TestAccountTestService_testSoraAccountConnection_WithSubscription covers the
// case where all four mocked upstream calls answer 200. It checks the request
// order, the bearer token, the TLS fingerprint flag on every call, and the
// events written to the response recorder.
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
	// Replies are served in queue order; the URL assertions below pin that
	// order to: me, billing, invite, remaining.
	mock := &queuedHTTPUpstream{
		responses: []*http.Response{
			newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
			newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
			newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`),
			newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`),
		},
	}
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			TLSFingerprint: config.TLSFingerprintConfig{
				Enabled: true,
			},
		},
		Sora: config.SoraConfig{
			Client: config.SoraClientConfig{
				DisableTLSFingerprint: false,
			},
		},
	}
	svc := &AccountTestService{
		httpUpstream: mock,
		cfg:          cfg,
	}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "test_token",
		},
	}

	ginCtx, recorder := newSoraTestContext()
	require.NoError(t, svc.testSoraAccountConnection(ginCtx, account))

	// Request sequence and headers.
	require.Len(t, mock.requests, 4)
	require.Equal(t, soraMeAPIURL, mock.requests[0].URL.String())
	require.Equal(t, soraBillingAPIURL, mock.requests[1].URL.String())
	require.Equal(t, soraInviteMineURL, mock.requests[2].URL.String())
	require.Equal(t, soraRemainingURL, mock.requests[3].URL.String())
	for _, sent := range mock.requests[:2] {
		require.Equal(t, "Bearer test_token", sent.Header.Get("Authorization"))
	}
	require.Equal(t, []bool{true, true, true, true}, mock.tlsFlags)

	// Events written to the client.
	output := recorder.Body.String()
	for _, fragment := range []string{
		`"type":"test_start"`,
		"Sora connection OK - Email: demo@example.com",
		"Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z",
		"Sora2: supported | invite=inv_abc | used=3/50",
		"Sora2 remaining: 27 | reset_in=46833s",
		`"type":"sora_test_result"`,
		`"status":"success"`,
		`"type":"test_complete","success":true`,
	} {
		require.Contains(t, output, fragment)
	}
}
// TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess
// verifies that when the first call succeeds and the three follow-up calls
// return 403/401/403, the connection test still returns no error and reports
// a partial success.
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
	forbidden := `{"error":{"message":"forbidden"}}`
	mock := &queuedHTTPUpstream{
		responses: []*http.Response{
			newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
			newJSONResponse(http.StatusForbidden, forbidden),
			newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`),
			newJSONResponse(http.StatusForbidden, forbidden),
		},
	}
	svc := &AccountTestService{httpUpstream: mock}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "test_token",
		},
	}

	ginCtx, recorder := newSoraTestContext()
	require.NoError(t, svc.testSoraAccountConnection(ginCtx, account))
	require.Len(t, mock.requests, 4)

	output := recorder.Body.String()
	for _, fragment := range []string{
		"Sora connection OK - User: demo-user",
		"Subscription check returned 403",
		"Sora2 invite check returned 401",
		`"type":"sora_test_result"`,
		`"status":"partial_success"`,
		`"type":"test_complete","success":true`,
	} {
		require.Contains(t, output, fragment)
	}
}
// TestAccountTestService_testSoraAccountConnection_CloudflareChallenge feeds a
// 403 Cloudflare challenge page (with a cf-ray header) as the first upstream
// reply and expects an error that names the challenge and the ray ID, both in
// the returned error and in the recorded output.
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
	const challengePage = `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`

	mock := &queuedHTTPUpstream{
		responses: []*http.Response{
			newJSONResponseWithHeader(http.StatusForbidden, challengePage, "cf-ray", "9cff2d62d83bb98d"),
		},
	}
	svc := &AccountTestService{httpUpstream: mock}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "test_token",
		},
	}

	ginCtx, recorder := newSoraTestContext()
	testErr := svc.testSoraAccountConnection(ginCtx, account)

	require.Error(t, testErr)
	require.Contains(t, testErr.Error(), "Cloudflare challenge")
	require.Contains(t, testErr.Error(), "cf-ray: 9cff2d62d83bb98d")

	output := recorder.Body.String()
	require.Contains(t, output, `"type":"error"`)
	require.Contains(t, output, "Cloudflare challenge")
	require.Contains(t, output, "cf-ray: 9cff2d62d83bb98d")
}
// TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader
// feeds a 429 reply carrying "cf-mitigated: challenge" and expects it to be
// reported as a Cloudflare challenge that mentions HTTP 429.
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
	const challengePage = `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body></body></html>`

	mock := &queuedHTTPUpstream{
		responses: []*http.Response{
			newJSONResponseWithHeader(http.StatusTooManyRequests, challengePage, "cf-mitigated", "challenge"),
		},
	}
	svc := &AccountTestService{httpUpstream: mock}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "test_token",
		},
	}

	ginCtx, recorder := newSoraTestContext()
	testErr := svc.testSoraAccountConnection(ginCtx, account)

	require.Error(t, testErr)
	require.Contains(t, testErr.Error(), "Cloudflare challenge")
	require.Contains(t, testErr.Error(), "HTTP 429")

	output := recorder.Body.String()
	require.Contains(t, output, "Cloudflare challenge")
}
// TestAccountTestService_testSoraAccountConnection_TokenInvalidated feeds a 401
// reply with error code "token_invalidated" and expects a failed result event,
// an error mentioning that code, and no successful completion event.
func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
	mock := &queuedHTTPUpstream{
		responses: []*http.Response{
			newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
		},
	}
	svc := &AccountTestService{httpUpstream: mock}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "test_token",
		},
	}

	ginCtx, recorder := newSoraTestContext()
	testErr := svc.testSoraAccountConnection(ginCtx, account)

	require.Error(t, testErr)
	require.Contains(t, testErr.Error(), "token_invalidated")

	output := recorder.Body.String()
	require.Contains(t, output, `"type":"sora_test_result"`)
	require.Contains(t, output, `"status":"failed"`)
	require.Contains(t, output, "token_invalidated")
	require.NotContains(t, output, `"type":"test_complete","success":true`)
}
// TestAccountTestService_testSoraAccountConnection_RateLimited runs the
// connection test twice on the same service with a one-hour soraTestCooldown.
// The first run must succeed; the second must fail with the
// "test_rate_limited" code and must not emit a successful completion event.
func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
	mock := &queuedHTTPUpstream{
		responses: []*http.Response{
			newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
		},
	}
	svc := &AccountTestService{
		httpUpstream:     mock,
		soraTestCooldown: time.Hour,
	}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "test_token",
		},
	}

	// First attempt: expected to pass.
	firstCtx, _ := newSoraTestContext()
	require.NoError(t, svc.testSoraAccountConnection(firstCtx, account))

	// Second attempt right after the first: expected to be rejected.
	secondCtx, secondRecorder := newSoraTestContext()
	secondErr := svc.testSoraAccountConnection(secondCtx, account)
	require.Error(t, secondErr)
	require.Contains(t, secondErr.Error(), "测试过于频繁")

	output := secondRecorder.Body.String()
	require.Contains(t, output, `"type":"sora_test_result"`)
	require.Contains(t, output, `"code":"test_rate_limited"`)
	require.Contains(t, output, `"status":"failed"`)
	require.NotContains(t, output, `"type":"test_complete","success":true`)
}
// TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay
// lets the first call succeed and answers the next two with a 403 Cloudflare
// challenge page whose ray ID appears only in the HTML body (no cf-ray
// header). The test expects no error, challenge messages for the subscription
// and invite checks, and the ray ID in the output.
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
	const challengePage = `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`

	mock := &queuedHTTPUpstream{
		responses: []*http.Response{
			newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
			newJSONResponse(http.StatusForbidden, challengePage),
			newJSONResponse(http.StatusForbidden, challengePage),
		},
	}
	svc := &AccountTestService{httpUpstream: mock}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token": "test_token",
		},
	}

	ginCtx, recorder := newSoraTestContext()
	require.NoError(t, svc.testSoraAccountConnection(ginCtx, account))

	output := recorder.Body.String()
	for _, fragment := range []string{
		"Subscription check blocked by Cloudflare challenge (HTTP 403)",
		"Sora2 invite check blocked by Cloudflare challenge (HTTP 403)",
		"cf-ray: 9cff2d62d83bb98d",
		`"type":"test_complete","success":true`,
	} {
		require.Contains(t, output, fragment)
	}
}
// TestSanitizeProxyURLForLog checks that credentials are removed from a proxy
// URL, that an empty input stays empty, and that an unparsable URL is replaced
// by a placeholder.
func TestSanitizeProxyURLForLog(t *testing.T) {
	cases := []struct {
		raw  string
		want string
	}{
		{raw: "http://user:pass@proxy.example.com:8080", want: "http://proxy.example.com:8080"},
		{raw: "", want: ""},
		{raw: "://invalid", want: "<invalid_proxy_url>"},
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, sanitizeProxyURLForLog(tc.raw))
	}
}
// TestExtractSoraEgressIPHint checks that the egress IP hint is read from
// either the x-openai-public-ip or the x-envoy-external-address header, and
// that a nil or empty header set yields "unknown".
func TestExtractSoraEgressIPHint(t *testing.T) {
	openAIHeader := http.Header{}
	openAIHeader.Set("x-openai-public-ip", "203.0.113.10")
	require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(openAIHeader))

	envoyHeader := http.Header{}
	envoyHeader.Set("x-envoy-external-address", "198.51.100.9")
	require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(envoyHeader))

	// No usable header at all.
	require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
	require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
}
...@@ -4,11 +4,15 @@ import ( ...@@ -4,11 +4,15 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
) )
// AdminService interface defines admin management operations // AdminService interface defines admin management operations
...@@ -39,7 +43,7 @@ type AdminService interface { ...@@ -39,7 +43,7 @@ type AdminService interface {
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error) GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
...@@ -65,6 +69,7 @@ type AdminService interface { ...@@ -65,6 +69,7 @@ type AdminService interface {
GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error)
// Redeem code management // Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
...@@ -288,6 +293,32 @@ type ProxyTestResult struct { ...@@ -288,6 +293,32 @@ type ProxyTestResult struct {
CountryCode string `json:"country_code,omitempty"` CountryCode string `json:"country_code,omitempty"`
} }
// ProxyQualityCheckResult is the aggregated outcome of a proxy quality check,
// as returned by CheckProxyQuality.
type ProxyQualityCheckResult struct {
	// ProxyID is the ID of the proxy that was checked.
	ProxyID int64 `json:"proxy_id"`
	// Score is 100 minus per-item penalties, floored at 0
	// (computed by finalizeProxyQualityResult).
	Score int `json:"score"`
	// Grade is the letter derived from Score by proxyQualityGrade ("A".."F").
	Grade string `json:"grade"`
	// Summary is a human-readable line with the four counters below.
	Summary string `json:"summary"`
	// ExitIP, Country and CountryCode are copied from the prober's exit info
	// when the base connectivity probe succeeds.
	ExitIP string `json:"exit_ip,omitempty"`
	Country string `json:"country,omitempty"`
	CountryCode string `json:"country_code,omitempty"`
	// BaseLatencyMs is the latency reported by the base connectivity probe.
	BaseLatencyMs int64 `json:"base_latency_ms,omitempty"`
	// PassedCount, WarnCount, FailedCount and ChallengeCount count the entries
	// of Items by status.
	PassedCount int `json:"passed_count"`
	WarnCount int `json:"warn_count"`
	FailedCount int `json:"failed_count"`
	ChallengeCount int `json:"challenge_count"`
	// CheckedAt is the Unix timestamp (seconds) taken when the check started.
	CheckedAt int64 `json:"checked_at"`
	// Items holds one entry per executed check, in execution order.
	Items []ProxyQualityCheckItem `json:"items"`
}
// ProxyQualityCheckItem is the outcome of a single check within a proxy
// quality run.
type ProxyQualityCheckItem struct {
	// Target names the check, e.g. "base_connectivity", "http_client" or the
	// Target of an entry in proxyQualityTargets.
	Target string `json:"target"`
	Status string `json:"status"` // one of: pass, warn, fail, challenge
	// HTTPStatus is the response status code, set once a response was received.
	HTTPStatus int `json:"http_status,omitempty"`
	// LatencyMs is the measured latency of the check in milliseconds.
	LatencyMs int64 `json:"latency_ms,omitempty"`
	// Message is a human-readable detail or error text.
	Message string `json:"message,omitempty"`
	// CFRay is the Cloudflare ray ID, filled in when a challenge is detected.
	CFRay string `json:"cf_ray,omitempty"`
}
// ProxyExitInfo represents proxy exit information from ip-api.com // ProxyExitInfo represents proxy exit information from ip-api.com
type ProxyExitInfo struct { type ProxyExitInfo struct {
IP string IP string
...@@ -302,6 +333,58 @@ type ProxyExitInfoProber interface { ...@@ -302,6 +333,58 @@ type ProxyExitInfoProber interface {
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error) ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
} }
// proxyQualityTarget describes one upstream endpoint probed by
// runProxyQualityTarget.
type proxyQualityTarget struct {
	// Target is the short name copied into the resulting check item.
	Target string
	// URL and Method define the probe request.
	URL string
	Method string
	// AllowedStatuses is the set of response codes that count as "reachable":
	// a 2xx member is reported as pass, any other member as warn.
	AllowedStatuses map[int]struct{}
}
// proxyQualityTargets lists the endpoints probed by CheckProxyQuality, in the
// order they are checked. The probe requests carry only Accept and User-Agent
// headers (see runProxyQualityTarget), so for most targets the accepted
// statuses are error codes such as 401; those are reported as warn, not fail.
var proxyQualityTargets = []proxyQualityTarget{
	{
		Target: "openai",
		URL: "https://api.openai.com/v1/models",
		Method: http.MethodGet,
		AllowedStatuses: map[int]struct{}{
			http.StatusUnauthorized: {},
		},
	},
	{
		Target: "anthropic",
		URL: "https://api.anthropic.com/v1/messages",
		Method: http.MethodGet,
		AllowedStatuses: map[int]struct{}{
			http.StatusUnauthorized: {},
			http.StatusMethodNotAllowed: {},
			http.StatusNotFound: {},
			http.StatusBadRequest: {},
		},
	},
	{
		// The only target whose accepted status is 2xx, i.e. the only one
		// that can be reported as pass.
		Target: "gemini",
		URL: "https://generativelanguage.googleapis.com/$discovery/rest?version=v1beta",
		Method: http.MethodGet,
		AllowedStatuses: map[int]struct{}{
			http.StatusOK: {},
		},
	},
	{
		// runProxyQualityTarget additionally runs Cloudflare challenge
		// detection for the target named "sora".
		Target: "sora",
		URL: "https://sora.chatgpt.com/backend/me",
		Method: http.MethodGet,
		AllowedStatuses: map[int]struct{}{
			http.StatusUnauthorized: {},
		},
	},
}
// Settings used by the proxy quality check.
const (
	// Timeout and ResponseHeaderTimeout passed to httpclient.GetClient when
	// building the probe client.
	proxyQualityRequestTimeout = 15 * time.Second
	proxyQualityResponseHeaderTimeout = 10 * time.Second
	// Upper bound on how many response body bytes a probe keeps.
	proxyQualityMaxBodyBytes = int64(8 * 1024)
	// User-Agent header sent with every probe request.
	proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
)
// adminServiceImpl implements AdminService // adminServiceImpl implements AdminService
type adminServiceImpl struct { type adminServiceImpl struct {
userRepo UserRepository userRepo UserRepository
...@@ -1054,9 +1137,9 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates [] ...@@ -1054,9 +1137,9 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
} }
// Account management implementations // Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) { func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
...@@ -1690,6 +1773,270 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR ...@@ -1690,6 +1773,270 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
}, nil }, nil
} }
// CheckProxyQuality runs a quality check for the proxy with the given ID.
//
// It first probes base connectivity through s.proxyProber, then sends one
// request per entry of proxyQualityTargets through the proxy and classifies
// each response. The result is scored by finalizeProxyQualityResult and
// stored via saveProxyQualitySnapshot before it is returned.
//
// An error is returned only when the proxy cannot be loaded. Every later
// problem (no prober configured, probe failure, client construction failure)
// is reported as a "fail" item inside a non-nil result with a nil error.
func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) {
	proxy, err := s.proxyRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}

	result := &ProxyQualityCheckResult{
		ProxyID:   id,
		Score:     100,
		Grade:     "A",
		CheckedAt: time.Now().Unix(),
		// One slot per target plus one for the base connectivity item.
		Items: make([]ProxyQualityCheckItem, 0, len(proxyQualityTargets)+1),
	}

	// abort records a fatal item, scores and persists what has been collected
	// so far, and produces the values to return to the caller.
	abort := func(item ProxyQualityCheckItem, exitInfo *ProxyExitInfo) (*ProxyQualityCheckResult, error) {
		result.Items = append(result.Items, item)
		result.FailedCount++
		finalizeProxyQualityResult(result)
		s.saveProxyQualitySnapshot(ctx, id, result, exitInfo)
		return result, nil
	}

	proxyURL := proxy.URL()
	if s.proxyProber == nil {
		return abort(ProxyQualityCheckItem{
			Target:  "base_connectivity",
			Status:  "fail",
			Message: "代理探测服务未配置",
		}, nil)
	}

	exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
	if err != nil {
		return abort(ProxyQualityCheckItem{
			Target:    "base_connectivity",
			Status:    "fail",
			LatencyMs: latencyMs,
			Message:   err.Error(),
		}, nil)
	}

	// A prober may report success without exit details; guard against a nil
	// pointer instead of panicking. saveProxyQualitySnapshot already accepts
	// a nil exitInfo.
	if exitInfo != nil {
		result.ExitIP = exitInfo.IP
		result.Country = exitInfo.Country
		result.CountryCode = exitInfo.CountryCode
	}
	result.BaseLatencyMs = latencyMs
	result.Items = append(result.Items, ProxyQualityCheckItem{
		Target:    "base_connectivity",
		Status:    "pass",
		LatencyMs: latencyMs,
		Message:   "代理出口连通正常",
	})
	result.PassedCount++

	client, err := httpclient.GetClient(httpclient.Options{
		ProxyURL:              proxyURL,
		Timeout:               proxyQualityRequestTimeout,
		ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout,
		ProxyStrict:           true,
	})
	if err != nil {
		return abort(ProxyQualityCheckItem{
			Target:  "http_client",
			Status:  "fail",
			Message: fmt.Sprintf("创建检测客户端失败: %v", err),
		}, exitInfo)
	}

	for _, target := range proxyQualityTargets {
		item := runProxyQualityTarget(ctx, client, target)
		result.Items = append(result.Items, item)
		switch item.Status {
		case "pass":
			result.PassedCount++
		case "warn":
			result.WarnCount++
		case "challenge":
			result.ChallengeCount++
		default:
			// Any unrecognised status is counted as a failure.
			result.FailedCount++
		}
	}

	finalizeProxyQualityResult(result)
	s.saveProxyQualitySnapshot(ctx, id, result, exitInfo)
	return result, nil
}
// runProxyQualityTarget sends one probe request for target through client and
// classifies the outcome:
//
//   - request build, transport or body read error -> "fail"
//   - target "sora" answering with a Cloudflare challenge -> "challenge" (CFRay set)
//   - status listed in target.AllowedStatuses -> "pass" for 2xx, otherwise "warn"
//   - HTTP 429 -> "warn"
//   - anything else -> "fail"
//
// LatencyMs is the time until client.Do returned; HTTPStatus is set whenever a
// response was received.
func runProxyQualityTarget(ctx context.Context, client *http.Client, target proxyQualityTarget) ProxyQualityCheckItem {
	outcome := ProxyQualityCheckItem{Target: target.Target}

	req, err := http.NewRequestWithContext(ctx, target.Method, target.URL, nil)
	if err != nil {
		outcome.Status = "fail"
		outcome.Message = fmt.Sprintf("构建请求失败: %v", err)
		return outcome
	}
	req.Header.Set("User-Agent", proxyQualityClientUserAgent)
	req.Header.Set("Accept", "application/json,text/html,*/*")

	sentAt := time.Now()
	resp, err := client.Do(req)
	outcome.LatencyMs = time.Since(sentAt).Milliseconds()
	if err != nil {
		outcome.Status = "fail"
		outcome.Message = fmt.Sprintf("请求失败: %v", err)
		return outcome
	}
	defer func() { _ = resp.Body.Close() }()
	outcome.HTTPStatus = resp.StatusCode

	// Read one byte past the cap so an oversized body can be detected, then
	// trim back to the cap.
	payload, err := io.ReadAll(io.LimitReader(resp.Body, proxyQualityMaxBodyBytes+1))
	if err != nil {
		outcome.Status = "fail"
		outcome.Message = fmt.Sprintf("读取响应失败: %v", err)
		return outcome
	}
	if int64(len(payload)) > proxyQualityMaxBodyBytes {
		payload = payload[:proxyQualityMaxBodyBytes]
	}

	code := resp.StatusCode
	_, allowed := target.AllowedStatuses[code]
	successful := code >= http.StatusOK && code < http.StatusMultipleChoices

	switch {
	case target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(code, resp.Header, payload):
		outcome.Status = "challenge"
		outcome.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, payload)
		outcome.Message = "Sora 命中 Cloudflare challenge"
	case allowed && successful:
		outcome.Status = "pass"
		outcome.Message = fmt.Sprintf("HTTP %d", code)
	case allowed:
		outcome.Status = "warn"
		outcome.Message = fmt.Sprintf("HTTP %d(目标可达,但鉴权或方法受限)", code)
	case code == http.StatusTooManyRequests:
		outcome.Status = "warn"
		outcome.Message = "目标返回 429,可能存在频控"
	default:
		outcome.Status = "fail"
		outcome.Message = fmt.Sprintf("非预期状态码: %d", code)
	}
	return outcome
}
// finalizeProxyQualityResult derives Score, Grade and Summary from the
// counters already stored in result. The score starts at 100 and loses 10 per
// warning, 22 per failure and 30 per challenge, never dropping below 0.
// A nil result is ignored.
func finalizeProxyQualityResult(result *ProxyQualityCheckResult) {
	if result == nil {
		return
	}

	const (
		warnPenalty      = 10
		failPenalty      = 22
		challengePenalty = 30
	)
	penalty := result.WarnCount*warnPenalty +
		result.FailedCount*failPenalty +
		result.ChallengeCount*challengePenalty

	result.Score = 100 - penalty
	if result.Score < 0 {
		result.Score = 0
	}
	result.Grade = proxyQualityGrade(result.Score)
	result.Summary = fmt.Sprintf(
		"通过 %d 项,告警 %d 项,失败 %d 项,挑战 %d 项",
		result.PassedCount, result.WarnCount, result.FailedCount, result.ChallengeCount,
	)
}
// proxyQualityGrade maps a numeric score to a letter grade:
// >=90 "A", >=75 "B", >=60 "C", >=40 "D", anything lower "F".
func proxyQualityGrade(score int) string {
	// Bands are ordered from highest to lowest threshold; the first match wins.
	bands := [...]struct {
		floor int
		grade string
	}{
		{90, "A"},
		{75, "B"},
		{60, "C"},
		{40, "D"},
	}
	for _, band := range bands {
		if score >= band.floor {
			return band.grade
		}
	}
	return "F"
}
// proxyQualityOverallStatus condenses the counters of result into a single
// status string. Precedence is challenge > failed > warn > healthy; a result
// with no counted items at all is "failed", and a nil result yields "".
func proxyQualityOverallStatus(result *ProxyQualityCheckResult) string {
	switch {
	case result == nil:
		return ""
	case result.ChallengeCount > 0:
		return "challenge"
	case result.FailedCount > 0:
		return "failed"
	case result.WarnCount > 0:
		return "warn"
	case result.PassedCount > 0:
		return "healthy"
	default:
		return "failed"
	}
}
// proxyQualityFirstCFRay returns the CFRay of the first item in result.Items
// that has one, or "" when result is nil or no item carries a ray ID.
func proxyQualityFirstCFRay(result *ProxyQualityCheckResult) string {
	if result == nil {
		return ""
	}
	for i := range result.Items {
		if ray := result.Items[i].CFRay; ray != "" {
			return ray
		}
	}
	return ""
}
// proxyQualityBaseConnectivityPass reports whether the first item whose Target
// is "base_connectivity" has status "pass". It returns false for a nil result
// or when no such item exists.
func proxyQualityBaseConnectivityPass(result *ProxyQualityCheckResult) bool {
	if result == nil {
		return false
	}
	for i := range result.Items {
		entry := &result.Items[i]
		if entry.Target != "base_connectivity" {
			continue
		}
		// Only the first matching item decides the outcome.
		return entry.Status == "pass"
	}
	return false
}
// saveProxyQualitySnapshot converts a quality check result into a
// ProxyLatencyInfo record and hands it to saveProxyLatency.
//
// Success reflects the base connectivity item only. LatencyMs is set only when
// result.BaseLatencyMs is positive, and the exit location fields are set only
// when exitInfo is non-nil. A nil result is ignored.
func (s *adminServiceImpl) saveProxyQualitySnapshot(ctx context.Context, proxyID int64, result *ProxyQualityCheckResult, exitInfo *ProxyExitInfo) {
	if result == nil {
		return
	}

	// Local copies, so the stored pointers do not alias fields of result.
	qualityScore, qualityCheckedAt := result.Score, result.CheckedAt

	snapshot := ProxyLatencyInfo{
		Success:          proxyQualityBaseConnectivityPass(result),
		Message:          result.Summary,
		QualityStatus:    proxyQualityOverallStatus(result),
		QualityScore:     &qualityScore,
		QualityGrade:     result.Grade,
		QualitySummary:   result.Summary,
		QualityCheckedAt: &qualityCheckedAt,
		QualityCFRay:     proxyQualityFirstCFRay(result),
		UpdatedAt:        time.Now(),
	}
	if baseLatency := result.BaseLatencyMs; baseLatency > 0 {
		snapshot.LatencyMs = &baseLatency
	}
	if exitInfo != nil {
		snapshot.IPAddress = exitInfo.IP
		snapshot.Country = exitInfo.Country
		snapshot.CountryCode = exitInfo.CountryCode
		snapshot.Region = exitInfo.Region
		snapshot.City = exitInfo.City
	}

	s.saveProxyLatency(ctx, proxyID, &snapshot)
}
func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) { func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
if s.proxyProber == nil || proxy == nil { if s.proxyProber == nil || proxy == nil {
return return
...@@ -1800,6 +2147,11 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro ...@@ -1800,6 +2147,11 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro
proxies[i].CountryCode = info.CountryCode proxies[i].CountryCode = info.CountryCode
proxies[i].Region = info.Region proxies[i].Region = info.Region
proxies[i].City = info.City proxies[i].City = info.City
proxies[i].QualityStatus = info.QualityStatus
proxies[i].QualityScore = info.QualityScore
proxies[i].QualityGrade = info.QualityGrade
proxies[i].QualitySummary = info.QualitySummary
proxies[i].QualityChecked = info.QualityCheckedAt
} }
} }
...@@ -1807,7 +2159,27 @@ func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, ...@@ -1807,7 +2159,27 @@ func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64,
if s.proxyLatencyCache == nil || info == nil { if s.proxyLatencyCache == nil || info == nil {
return return
} }
if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil {
merged := *info
if latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, []int64{proxyID}); err == nil {
if existing := latencies[proxyID]; existing != nil {
if merged.QualityCheckedAt == nil &&
merged.QualityScore == nil &&
merged.QualityGrade == "" &&
merged.QualityStatus == "" &&
merged.QualitySummary == "" &&
merged.QualityCFRay == "" {
merged.QualityStatus = existing.QualityStatus
merged.QualityScore = existing.QualityScore
merged.QualityGrade = existing.QualityGrade
merged.QualitySummary = existing.QualitySummary
merged.QualityCheckedAt = existing.QualityCheckedAt
merged.QualityCFRay = existing.QualityCFRay
}
}
}
if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, &merged); err != nil {
logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err) logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err)
} }
} }
......
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
// TestFinalizeProxyQualityResult_ScoreAndGrade checks scoring for one warning,
// one failure and one challenge: 100 - 10 - 22 - 30 = 38, which grades as "F",
// and the summary must mention all four counters.
func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) {
	res := &ProxyQualityCheckResult{
		PassedCount:    2,
		WarnCount:      1,
		FailedCount:    1,
		ChallengeCount: 1,
	}

	finalizeProxyQualityResult(res)

	require.Equal(t, 38, res.Score)
	require.Equal(t, "F", res.Grade)
	for _, part := range []string{"通过 2 项", "告警 1 项", "失败 1 项", "挑战 1 项"} {
		require.Contains(t, res.Summary, part)
	}
}
// TestRunProxyQualityTarget_SoraChallenge serves a 403 Cloudflare challenge
// page with a cf-ray header and expects the "sora" target to be classified as
// "challenge" with the ray ID captured.
func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
	challengeHandler := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/html")
		w.Header().Set("cf-ray", "test-ray-123")
		w.WriteHeader(http.StatusForbidden)
		_, _ = w.Write([]byte("<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>"))
	}
	srv := httptest.NewServer(http.HandlerFunc(challengeHandler))
	defer srv.Close()

	soraTarget := proxyQualityTarget{
		Target: "sora",
		URL:    srv.URL,
		Method: http.MethodGet,
		AllowedStatuses: map[int]struct{}{
			http.StatusUnauthorized: {},
		},
	}

	got := runProxyQualityTarget(context.Background(), srv.Client(), soraTarget)

	require.Equal(t, "challenge", got.Status)
	require.Equal(t, http.StatusForbidden, got.HTTPStatus)
	require.Equal(t, "test-ray-123", got.CFRay)
}
// TestRunProxyQualityTarget_AllowedStatusPass checks that a 200 response which
// is listed in AllowedStatuses is classified as "pass".
func TestRunProxyQualityTarget_AllowedStatusPass(t *testing.T) {
	okHandler := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"models":[]}`))
	}
	srv := httptest.NewServer(http.HandlerFunc(okHandler))
	defer srv.Close()

	geminiTarget := proxyQualityTarget{
		Target: "gemini",
		URL:    srv.URL,
		Method: http.MethodGet,
		AllowedStatuses: map[int]struct{}{
			http.StatusOK: {},
		},
	}

	got := runProxyQualityTarget(context.Background(), srv.Client(), geminiTarget)

	require.Equal(t, "pass", got.Status)
	require.Equal(t, http.StatusOK, got.HTTPStatus)
}
// TestRunProxyQualityTarget_AllowedStatusWarnForUnauthorized checks that a 401
// response listed in AllowedStatuses is classified as "warn" and that the
// message says the target is reachable.
func TestRunProxyQualityTarget_AllowedStatusWarnForUnauthorized(t *testing.T) {
	unauthorizedHandler := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
		_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
	}
	srv := httptest.NewServer(http.HandlerFunc(unauthorizedHandler))
	defer srv.Close()

	openAITarget := proxyQualityTarget{
		Target: "openai",
		URL:    srv.URL,
		Method: http.MethodGet,
		AllowedStatuses: map[int]struct{}{
			http.StatusUnauthorized: {},
		},
	}

	got := runProxyQualityTarget(context.Background(), srv.Client(), openAITarget)

	require.Equal(t, "warn", got.Status)
	require.Equal(t, http.StatusUnauthorized, got.HTTPStatus)
	require.Contains(t, got.Message, "目标可达")
}
...@@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct { ...@@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct {
listWithFiltersErr error listWithFiltersErr error
} }
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++ s.listWithFiltersCalls++
s.listWithFiltersParams = params s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform s.listWithFiltersPlatform = platform
...@@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { ...@@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc") accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(10), total) require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
......
...@@ -4117,6 +4117,15 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs ...@@ -4117,6 +4117,15 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 { if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheCreationInputTokens = int(v) usage.CacheCreationInputTokens = int(v)
} }
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
if cc, ok := u["cache_creation"].(map[string]any); ok {
if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
usage.CacheCreation5mTokens = int(v)
}
if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
usage.CacheCreation1hTokens = int(v)
}
}
} }
// extractClaudeUsage 从非流式 Claude 响应提取 usage // extractClaudeUsage 从非流式 Claude 响应提取 usage
...@@ -4139,6 +4148,15 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage ...@@ -4139,6 +4148,15 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage
if v, ok := u["cache_creation_input_tokens"].(float64); ok { if v, ok := u["cache_creation_input_tokens"].(float64); ok {
usage.CacheCreationInputTokens = int(v) usage.CacheCreationInputTokens = int(v)
} }
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
if cc, ok := u["cache_creation"].(map[string]any); ok {
if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
usage.CacheCreation5mTokens = int(v)
}
if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
usage.CacheCreation1hTokens = int(v)
}
}
} }
return usage return usage
} }
...@@ -31,8 +31,8 @@ type ModelPricing struct { ...@@ -31,8 +31,8 @@ type ModelPricing struct {
OutputPricePerToken float64 // 每token输出价格 (USD) OutputPricePerToken float64 // 每token输出价格 (USD)
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退 CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退 CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
SupportsCacheBreakdown bool // 是否支持详细的缓存分类 SupportsCacheBreakdown bool // 是否支持详细的缓存分类
} }
...@@ -172,12 +172,20 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { ...@@ -172,12 +172,20 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
if s.pricingService != nil { if s.pricingService != nil {
litellmPricing := s.pricingService.GetModelPricing(model) litellmPricing := s.pricingService.GetModelPricing(model)
if litellmPricing != nil { if litellmPricing != nil {
// 启用 5m/1h 分类计费的条件:
// 1. 存在 1h 价格
// 2. 1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费)
price5m := litellmPricing.CacheCreationInputTokenCost
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
enableBreakdown := price1h > 0 && price1h > price5m
return &ModelPricing{ return &ModelPricing{
InputPricePerToken: litellmPricing.InputCostPerToken, InputPricePerToken: litellmPricing.InputCostPerToken,
OutputPricePerToken: litellmPricing.OutputCostPerToken, OutputPricePerToken: litellmPricing.OutputCostPerToken,
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost, CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost, CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
SupportsCacheBreakdown: false, CacheCreation5mPrice: price5m,
CacheCreation1hPrice: price1h,
SupportsCacheBreakdown: enableBreakdown,
}, nil }, nil
} }
} }
...@@ -209,9 +217,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul ...@@ -209,9 +217,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
// 计算缓存费用 // 计算缓存费用
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
// 支持详细缓存分类的模型(5分钟/1小时缓存) // 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token)
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice + if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice // API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
} else {
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
}
} else { } else {
// 标准缓存创建价格(per-token) // 标准缓存创建价格(per-token)
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
...@@ -280,10 +293,12 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage ...@@ -280,10 +293,12 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
// 范围内部分:正常计费 // 范围内部分:正常计费
inRangeTokens := UsageTokens{ inRangeTokens := UsageTokens{
InputTokens: inRangeInputTokens, InputTokens: inRangeInputTokens,
OutputTokens: tokens.OutputTokens, // 输出只算一次 OutputTokens: tokens.OutputTokens, // 输出只算一次
CacheCreationTokens: tokens.CacheCreationTokens, CacheCreationTokens: tokens.CacheCreationTokens,
CacheReadTokens: inRangeCacheTokens, CacheReadTokens: inRangeCacheTokens,
CacheCreation5mTokens: tokens.CacheCreation5mTokens,
CacheCreation1hTokens: tokens.CacheCreation1hTokens,
} }
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
if err != nil { if err != nil {
......
...@@ -399,8 +399,8 @@ func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) { ...@@ -399,8 +399,8 @@ func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) {
InputPricePerToken: 3e-6, InputPricePerToken: 3e-6,
OutputPricePerToken: 15e-6, OutputPricePerToken: 15e-6,
SupportsCacheBreakdown: true, SupportsCacheBreakdown: true,
CacheCreation5mPrice: 4.0, // per million tokens CacheCreation5mPrice: 4e-6, // per token
CacheCreation1hPrice: 5.0, // per million tokens CacheCreation1hPrice: 5e-6, // per token
}, },
}, },
} }
...@@ -414,8 +414,8 @@ func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) { ...@@ -414,8 +414,8 @@ func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) {
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err) require.NoError(t, err)
expected5m := float64(100000) / 1_000_000 * 4.0 expected5m := float64(tokens.CacheCreation5mTokens) * 4e-6
expected1h := float64(50000) / 1_000_000 * 5.0 expected1h := float64(tokens.CacheCreation1hTokens) * 5e-6
require.InDelta(t, expected5m+expected1h, cost.CacheCreationCost, 1e-10) require.InDelta(t, expected5m+expected1h, cost.CacheCreationCost, 1e-10)
} }
......
...@@ -21,3 +21,72 @@ func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) { ...@@ -21,3 +21,72 @@ func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
) )
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got) require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
} }
func TestStripBetaToken(t *testing.T) {
tests := []struct {
name string
header string
token string
want string
}{
{
name: "token in middle",
header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14",
token: "context-1m-2025-08-07",
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "token at start",
header: "context-1m-2025-08-07,oauth-2025-04-20,interleaved-thinking-2025-05-14",
token: "context-1m-2025-08-07",
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "token at end",
header: "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07",
token: "context-1m-2025-08-07",
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "token not present",
header: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
token: "context-1m-2025-08-07",
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "empty header",
header: "",
token: "context-1m-2025-08-07",
want: "",
},
{
name: "with spaces",
header: "oauth-2025-04-20, context-1m-2025-08-07 , interleaved-thinking-2025-05-14",
token: "context-1m-2025-08-07",
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "only token",
header: "context-1m-2025-08-07",
token: "context-1m-2025-08-07",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := stripBetaToken(tt.header, tt.token)
require.Equal(t, tt.want, got)
})
}
}
// TestMergeAnthropicBetaDropping_Context1M checks that merging the required
// betas with an incoming header drops the context-1m token while keeping the
// required betas first and the remaining incoming beta after them.
func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) {
	const dropped = "context-1m-2025-08-07"

	merged := mergeAnthropicBetaDropping(
		[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
		"context-1m-2025-08-07,foo-beta,oauth-2025-04-20",
		map[string]struct{}{dropped: {}},
	)

	require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", merged)
	require.NotContains(t, merged, dropped)
}
...@@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error ...@@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
......
...@@ -349,6 +349,8 @@ type ClaudeUsage struct { ...@@ -349,6 +349,8 @@ type ClaudeUsage struct {
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"`
CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象)
CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象)
} }
// ForwardResult 转发结果 // ForwardResult 转发结果
...@@ -373,9 +375,10 @@ type ForwardResult struct { ...@@ -373,9 +375,10 @@ type ForwardResult struct {
// UpstreamFailoverError indicates an upstream error that should trigger account failover. // UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct { type UpstreamFailoverError struct {
StatusCode int StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配 ResponseBody []byte // 上游响应体,用于错误透传规则匹配
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true ResponseHeaders http.Header // 上游响应头,用于透传 cf-ray/cf-mitigated/content-type 等诊断信息
RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换 ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换
} }
func (e *UpstreamFailoverError) Error() string { func (e *UpstreamFailoverError) Error() string {
...@@ -3580,12 +3583,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -3580,12 +3583,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// messages requests typically use only oauth + interleaved-thinking. // messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it. // Also drop claude-code beta if a downstream client added it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
drop := map[string]struct{}{claude.BetaClaudeCode: {}} drop := map[string]struct{}{claude.BetaClaudeCode: {}, claude.BetaContext1M: {}}
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
} else { } else {
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader := req.Header.Get("anthropic-beta") clientBetaHeader := req.Header.Get("anthropic-beta")
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader)) req.Header.Set("anthropic-beta", stripBetaToken(s.getBetaHeader(modelID, clientBetaHeader), claude.BetaContext1M))
} }
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
...@@ -3739,6 +3742,23 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str ...@@ -3739,6 +3742,23 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str
return strings.Join(out, ",") return strings.Join(out, ",")
} }
// stripBetaToken returns header with every entry equal to token removed.
//
// header is treated as a comma-separated list. When token does not occur
// anywhere in header (plain substring check), header is returned untouched,
// including any whitespace it contains, and nothing is allocated. Otherwise
// the list is rebuilt: each entry is whitespace-trimmed, empty entries and
// entries equal to token are discarded, and the survivors are re-joined with
// a bare "," separator.
func stripBetaToken(header, token string) string {
	if strings.Contains(header, token) {
		kept := make([]string, 0, 8)
		for _, raw := range strings.Split(header, ",") {
			entry := strings.TrimSpace(raw)
			if entry != "" && entry != token {
				kept = append(kept, entry)
			}
		}
		return strings.Join(kept, ",")
	}
	// Fast path: nothing to strip, hand the caller's value back as-is.
	return header
}
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. // applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream // This mirrors opencode-anthropic-auth behavior: do not trust downstream
// headers when using Claude Code-scoped OAuth credentials. // headers when using Claude Code-scoped OAuth credentials.
...@@ -4305,6 +4325,23 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -4305,6 +4325,23 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
} }
} }
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
if account.IsCacheTTLOverrideEnabled() {
overrideTarget := account.GetCacheTTLOverrideTarget()
if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok {
rewriteCacheCreationJSON(u, overrideTarget)
}
}
}
if eventType == "message_delta" {
if u, ok := event["usage"].(map[string]any); ok {
rewriteCacheCreationJSON(u, overrideTarget)
}
}
}
if needModelReplace { if needModelReplace {
if msg, ok := event["message"].(map[string]any); ok { if msg, ok := event["message"].(map[string]any); ok {
if model, ok := msg["model"].(string); ok && model == mappedModel { if model, ok := msg["model"].(string); ok && model == mappedModel {
...@@ -4432,6 +4469,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { ...@@ -4432,6 +4469,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
usage.InputTokens = msgStart.Message.Usage.InputTokens usage.InputTokens = msgStart.Message.Usage.InputTokens
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
cc5m := gjson.Get(data, "message.usage.cache_creation.ephemeral_5m_input_tokens")
cc1h := gjson.Get(data, "message.usage.cache_creation.ephemeral_1h_input_tokens")
if cc5m.Exists() || cc1h.Exists() {
usage.CacheCreation5mTokens = int(cc5m.Int())
usage.CacheCreation1hTokens = int(cc1h.Int())
}
} }
// 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API) // 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API)
...@@ -4460,6 +4505,68 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { ...@@ -4460,6 +4505,68 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
if msgDelta.Usage.CacheReadInputTokens > 0 { if msgDelta.Usage.CacheReadInputTokens > 0 {
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
} }
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens")
cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens")
if cc5m.Exists() && cc5m.Int() > 0 {
usage.CacheCreation5mTokens = int(cc5m.Int())
}
if cc1h.Exists() && cc1h.Int() > 0 {
usage.CacheCreation1hTokens = int(cc1h.Int())
}
}
}
// applyCacheTTLOverride moves all cache-creation tokens in usage into a
// single TTL bucket. A target of "1h" selects the one-hour bucket; any other
// value is treated as "5m".
//
// If usage carries only the aggregate CacheCreationInputTokens and no 5m/1h
// breakdown, the aggregate is first copied into the 5m bucket. That seeding
// step is not itself reported as a change.
//
// It returns true when tokens were moved from one bucket to the other, and
// false when there are no cache-creation tokens or they already all sit in
// the requested bucket.
func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool {
	// Fallback: only the aggregate field is populated, so seed the default 5m bucket.
	noBreakdown := usage.CacheCreation5mTokens == 0 && usage.CacheCreation1hTokens == 0
	if noBreakdown && usage.CacheCreationInputTokens > 0 {
		usage.CacheCreation5mTokens = usage.CacheCreationInputTokens
	}

	sum := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
	if sum == 0 {
		return false
	}

	if target == "1h" {
		if usage.CacheCreation1hTokens == sum {
			return false // everything is already classified as 1h
		}
		usage.CacheCreation5mTokens, usage.CacheCreation1hTokens = 0, sum
		return true
	}

	// Any other target is handled as "5m".
	if usage.CacheCreation5mTokens == sum {
		return false // everything is already classified as 5m
	}
	usage.CacheCreation5mTokens, usage.CacheCreation1hTokens = sum, 0
	return true
}
// rewriteCacheCreationJSON rewrites, in place, the TTL split stored in the
// nested "cache_creation" object of a decoded usage JSON object.
//
// usageObj is the usage object as a map[string]any. If it has no
// "cache_creation" object, or the 5m and 1h counts in it (read as float64,
// with a missing or non-float64 value counting as 0) sum to zero, nothing is
// modified. Otherwise the sum is written to the bucket selected by target
// ("1h" for the one-hour bucket, anything else for the 5m bucket) and the
// other bucket is set to float64(0).
func rewriteCacheCreationJSON(usageObj map[string]any, target string) {
	nested, isObject := usageObj["cache_creation"].(map[string]any)
	if !isObject {
		return
	}

	fiveMin, _ := nested["ephemeral_5m_input_tokens"].(float64)
	oneHour, _ := nested["ephemeral_1h_input_tokens"].(float64)
	sum := fiveMin + oneHour
	if sum == 0 {
		return
	}

	// Default to the 5m bucket; swap only for an explicit "1h" target.
	receiving, cleared := "ephemeral_5m_input_tokens", "ephemeral_1h_input_tokens"
	if target == "1h" {
		receiving, cleared = cleared, receiving
	}
	nested[receiving] = sum
	nested[cleared] = float64(0)
}
...@@ -4491,6 +4598,14 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h ...@@ -4491,6 +4598,14 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
return nil, fmt.Errorf("parse response: %w", err) return nil, fmt.Errorf("parse response: %w", err)
} }
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens")
cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens")
if cc5m.Exists() || cc1h.Exists() {
response.Usage.CacheCreation5mTokens = int(cc5m.Int())
response.Usage.CacheCreation1hTokens = int(cc1h.Int())
}
// 兼容 Kimi cached_tokens → cache_read_input_tokens // 兼容 Kimi cached_tokens → cache_read_input_tokens
if response.Usage.CacheReadInputTokens == 0 { if response.Usage.CacheReadInputTokens == 0 {
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
...@@ -4502,6 +4617,20 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h ...@@ -4502,6 +4617,20 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
} }
} }
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
if account.IsCacheTTLOverrideEnabled() {
overrideTarget := account.GetCacheTTLOverrideTarget()
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
// 同步更新 body JSON 中的嵌套 cache_creation 对象
if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil {
body = newBody
}
if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens", response.Usage.CacheCreation1hTokens); err == nil {
body = newBody
}
}
}
// 如果有模型映射,替换响应中的model字段 // 如果有模型映射,替换响应中的model字段
if originalModel != mappedModel { if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel) body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
...@@ -4570,6 +4699,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -4570,6 +4699,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
result.Usage.InputTokens = 0 result.Usage.InputTokens = 0
} }
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
cacheTTLOverridden := false
if account.IsCacheTTLOverrideEnabled() {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil { if apiKey.GroupID != nil && apiKey.Group != nil {
...@@ -4617,10 +4753,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -4617,10 +4753,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} else { } else {
// Token 计费 // Token 计费
tokens := UsageTokens{ tokens := UsageTokens{
InputTokens: result.Usage.InputTokens, InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
} }
var err error var err error
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier) cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
...@@ -4658,6 +4796,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -4658,6 +4796,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
InputCost: cost.InputCost, InputCost: cost.InputCost,
OutputCost: cost.OutputCost, OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost, CacheCreationCost: cost.CacheCreationCost,
...@@ -4673,6 +4813,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -4673,6 +4813,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
ImageCount: result.ImageCount, ImageCount: result.ImageCount,
ImageSize: imageSize, ImageSize: imageSize,
MediaType: mediaType, MediaType: mediaType,
CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
...@@ -4773,6 +4914,13 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -4773,6 +4914,13 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
result.Usage.InputTokens = 0 result.Usage.InputTokens = 0
} }
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
cacheTTLOverridden := false
if account.IsCacheTTLOverrideEnabled() {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil { if apiKey.GroupID != nil && apiKey.Group != nil {
...@@ -4803,10 +4951,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -4803,10 +4951,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
} else { } else {
// Token 计费(使用长上下文计费方法) // Token 计费(使用长上下文计费方法)
tokens := UsageTokens{ tokens := UsageTokens{
InputTokens: result.Usage.InputTokens, InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
} }
var err error var err error
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
...@@ -4840,6 +4990,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -4840,6 +4990,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
InputCost: cost.InputCost, InputCost: cost.InputCost,
OutputCost: cost.OutputCost, OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost, CacheCreationCost: cost.CacheCreationCost,
...@@ -4854,6 +5006,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -4854,6 +5006,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
FirstTokenMs: result.FirstTokenMs, FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount, ImageCount: result.ImageCount,
ImageSize: imageSize, ImageSize: imageSize,
CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
...@@ -5170,7 +5323,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -5170,7 +5323,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
incomingBeta := req.Header.Get("anthropic-beta") incomingBeta := req.Header.Get("anthropic-beta")
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta)) drop := map[string]struct{}{claude.BetaContext1M: {}}
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
} else { } else {
clientBetaHeader := req.Header.Get("anthropic-beta") clientBetaHeader := req.Header.Get("anthropic-beta")
if clientBetaHeader == "" { if clientBetaHeader == "" {
...@@ -5180,7 +5334,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -5180,7 +5334,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if !strings.Contains(beta, claude.BetaTokenCounting) { if !strings.Contains(beta, claude.BetaTokenCounting) {
beta = beta + "," + claude.BetaTokenCounting beta = beta + "," + claude.BetaTokenCounting
} }
req.Header.Set("anthropic-beta", beta) req.Header.Set("anthropic-beta", stripBetaToken(beta, claude.BetaContext1M))
} }
} }
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
......
...@@ -79,6 +79,22 @@ func TestParseSSEUsage_DeltaOverwritesWithNonZero(t *testing.T) { ...@@ -79,6 +79,22 @@ func TestParseSSEUsage_DeltaOverwritesWithNonZero(t *testing.T) {
require.Equal(t, 60, usage.CacheReadInputTokens) require.Equal(t, 60, usage.CacheReadInputTokens)
} }
// TestParseSSEUsage_DeltaDoesNotResetCacheCreationBreakdown feeds parseSSEUsage
// a message_start event carrying a non-zero 5m/1h cache_creation breakdown,
// then a message_delta event whose breakdown is all zeros, and asserts that
// the earlier non-zero values survive while output_tokens is taken from the
// delta.
func TestParseSSEUsage_DeltaDoesNotResetCacheCreationBreakdown(t *testing.T) {
	const (
		startEvent = `{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation":{"ephemeral_5m_input_tokens":30,"ephemeral_1h_input_tokens":70}}}}`
		deltaEvent = `{"type":"message_delta","usage":{"output_tokens":12,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0}}}`
	)

	gateway := newMinimalGatewayService()
	var collected ClaudeUsage

	// message_start records the non-zero 5m/1h breakdown.
	gateway.parseSSEUsage(startEvent, &collected)
	require.Equal(t, 30, collected.CacheCreation5mTokens)
	require.Equal(t, 70, collected.CacheCreation1hTokens)

	// The following delta carries default zeros, which must not overwrite the values above.
	gateway.parseSSEUsage(deltaEvent, &collected)
	require.Equal(t, 30, collected.CacheCreation5mTokens, "delta 的 0 值不应重置 5m 明细")
	require.Equal(t, 70, collected.CacheCreation1hTokens, "delta 的 0 值不应重置 1h 明细")
	require.Equal(t, 12, collected.OutputTokens)
}
func TestParseSSEUsage_InvalidJSON(t *testing.T) { func TestParseSSEUsage_InvalidJSON(t *testing.T) {
svc := newMinimalGatewayService() svc := newMinimalGatewayService()
usage := &ClaudeUsage{} usage := &ClaudeUsage{}
......
...@@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error ...@@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
type OpenAIOAuthClient interface { type OpenAIOAuthClient interface {
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error)
} }
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows // ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
......
...@@ -99,13 +99,19 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran ...@@ -99,13 +99,19 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran
result.Modified = true result.Modified = true
} }
if _, ok := reqBody["max_output_tokens"]; ok { // Strip parameters unsupported by codex models via the Responses API.
delete(reqBody, "max_output_tokens") for _, key := range []string{
result.Modified = true "max_output_tokens",
} "max_completion_tokens",
if _, ok := reqBody["max_completion_tokens"]; ok { "temperature",
delete(reqBody, "max_completion_tokens") "top_p",
result.Modified = true "frequency_penalty",
"presence_penalty",
} {
if _, ok := reqBody[key]; ok {
delete(reqBody, key)
result.Modified = true
}
} }
if normalizeCodexTools(reqBody) { if normalizeCodexTools(reqBody) {
......
...@@ -2,13 +2,20 @@ package service ...@@ -2,13 +2,20 @@ package service
import ( import (
"context" "context"
"crypto/subtle"
"encoding/json"
"io"
"net/http" "net/http"
"net/url"
"strings"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
) )
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
// OpenAIOAuthService handles OpenAI OAuth authentication flows // OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct { type OpenAIOAuthService struct {
sessionStore *openai.SessionStore sessionStore *openai.SessionStore
...@@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 ...@@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
type OpenAIExchangeCodeInput struct { type OpenAIExchangeCodeInput struct {
SessionID string SessionID string
Code string Code string
State string
RedirectURI string RedirectURI string
ProxyID *int64 ProxyID *int64
} }
...@@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch ...@@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
if !ok { if !ok {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired") return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
} }
if input.State == "" {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required")
}
if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state")
}
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL // Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
proxyURL := session.ProxyURL proxyURL := session.ProxyURL
...@@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch ...@@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// RefreshToken refreshes an OpenAI OAuth token // RefreshToken refreshes an OpenAI OAuth token
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) { func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL) return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
}
// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id.
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri ...@@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
return tokenInfo, nil return tokenInfo, nil
} }
// ExchangeSoraSessionToken exchanges a Sora session_token for an access token.
//
// The trimmed token is sent as the __Secure-next-auth.session-token cookie in
// a GET request to openAISoraSessionAuthURL, through the proxy resolved from
// proxyID when one is given. The JSON reply is expected to carry accessToken,
// expires and user.email.
//
// On success the returned OpenAITokenInfo holds the trimmed access token, the
// trimmed e-mail, ExpiresAt (Unix seconds) and ExpiresIn (seconds, never
// negative). When expires is blank or is not RFC 3339, the expiry defaults to
// one hour from now.
//
// Errors are built with infraerrors: StatusBadRequest for a blank token,
// StatusInternalServerError when the request cannot be built, and
// StatusBadGateway when the call fails, the status is not 200, the body cannot
// be read or parsed, or the access token is missing. Errors from
// resolveProxyURL are returned unchanged.
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
	token := strings.TrimSpace(sessionToken)
	if token == "" {
		return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
	}

	proxyURL, err := s.resolveProxyURL(ctx, proxyID)
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil)
	if err != nil {
		return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err)
	}
	req.Header.Set("Cookie", "__Secure-next-auth.session-token="+token)
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Origin", "https://sora.chatgpt.com")
	req.Header.Set("Referer", "https://sora.chatgpt.com/")
	req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")

	client := newOpenAIOAuthHTTPClient(proxyURL)
	// The client is used for this single call only; release its keep-alive
	// connection on return instead of leaving it idle. Registered before the
	// body-close defer so it runs after the body has been closed.
	defer client.CloseIdleConnections()

	resp, err := client.Do(req)
	if err != nil {
		return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
	}
	defer func() { _ = resp.Body.Close() }()

	// Cap the amount read at 2 MiB.
	body, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
	if resp.StatusCode != http.StatusOK {
		// The body is only diagnostic here, so a partial read is acceptable.
		return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}
	if readErr != nil {
		// Report a truncated body as a read failure rather than letting it
		// surface later as a misleading JSON parse error.
		return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_READ_FAILED", "failed to read response: %v", readErr)
	}

	var sessionResp struct {
		AccessToken string `json:"accessToken"`
		Expires     string `json:"expires"`
		User        struct {
			Email string `json:"email"`
			Name  string `json:"name"`
		} `json:"user"`
	}
	if err := json.Unmarshal(body, &sessionResp); err != nil {
		return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err)
	}

	accessToken := strings.TrimSpace(sessionResp.AccessToken)
	if accessToken == "" {
		return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token")
	}

	// Read the clock once so ExpiresAt and ExpiresIn are derived from the
	// same instant.
	now := time.Now()
	expiresAt := now.Add(time.Hour).Unix()
	if strings.TrimSpace(sessionResp.Expires) != "" {
		if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil {
			expiresAt = parsed.Unix()
		}
	}
	expiresIn := expiresAt - now.Unix()
	if expiresIn < 0 {
		expiresIn = 0
	}

	return &OpenAITokenInfo{
		AccessToken: accessToken,
		ExpiresIn:   expiresIn,
		ExpiresAt:   expiresAt,
		Email:       strings.TrimSpace(sessionResp.User.Email),
	}, nil
}
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if !account.IsOpenAI() { if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account") return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account")
}
if account.Type != AccountTypeOAuth {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
} }
refreshToken := account.GetOpenAIRefreshToken() refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" { if refreshToken == "" {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available") return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
} }
...@@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A ...@@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
} }
} }
return s.RefreshToken(ctx, refreshToken, proxyURL) clientID := account.GetCredential("client_id")
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
} }
// BuildAccountCredentials builds credentials map from token info // BuildAccountCredentials builds credentials map from token info
...@@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) ...@@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
func (s *OpenAIOAuthService) Stop() { func (s *OpenAIOAuthService) Stop() {
s.sessionStore.Stop() s.sessionStore.Stop()
} }
// resolveProxyURL turns an optional proxy ID into a proxy URL string.
//
// A nil proxyID, or a lookup that yields a nil proxy without an error, gives
// an empty URL and no error. A failed lookup is reported as a
// StatusBadRequest infraerrors error that includes the underlying cause.
func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
	var resolved string
	if proxyID != nil {
		found, lookupErr := s.proxyRepo.GetByID(ctx, *proxyID)
		if lookupErr != nil {
			return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", lookupErr)
		}
		if found != nil {
			resolved = found.URL()
		}
	}
	return resolved, nil
}
// newOpenAIOAuthHTTPClient builds an HTTP client with a 120-second overall
// timeout and its own transport.
//
// When proxyURL, after trimming surrounding whitespace, parses to a URL with a
// host, requests are routed through that proxy. A blank or unparsable
// proxyURL yields a client that connects directly.
//
// Every call creates a fresh transport, so idle keep-alive connections are
// given a finite lifetime; without it an abandoned client could hold its
// connection open indefinitely.
func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
	transport := &http.Transport{
		// Same value http.DefaultTransport uses.
		IdleConnTimeout: 90 * time.Second,
	}
	// Parse the trimmed value: url.Parse rejects leading or trailing spaces,
	// which would otherwise silently drop the proxy.
	if trimmed := strings.TrimSpace(proxyURL); trimmed != "" {
		if parsed, err := url.Parse(trimmed); err == nil && parsed.Host != "" {
			transport.Proxy = http.ProxyURL(parsed)
		}
	}
	return &http.Client{
		Timeout:   120 * time.Second,
		Transport: transport,
	}
}
package service
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
// openaiOAuthClientNoopStub is an OAuth client test double whose methods all
// fail with "not implemented". The tests in this file pass it to
// NewOpenAIOAuthService when the code path under test is not expected to use
// the OAuth client.
type openaiOAuthClientNoopStub struct{}

// ExchangeCode always returns a nil response and a "not implemented" error.
func (*openaiOAuthClientNoopStub) ExchangeCode(_ context.Context, _, _, _, _ string) (*openai.TokenResponse, error) {
	return nil, errors.New("not implemented")
}

// RefreshToken always returns a nil response and a "not implemented" error.
func (*openaiOAuthClientNoopStub) RefreshToken(_ context.Context, _, _ string) (*openai.TokenResponse, error) {
	return nil, errors.New("not implemented")
}

// RefreshTokenWithClientID always returns a nil response and a
// "not implemented" error.
func (*openaiOAuthClientNoopStub) RefreshTokenWithClientID(_ context.Context, _, _, _ string) (*openai.TokenResponse, error) {
	return nil, errors.New("not implemented")
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
require.NoError(t, err)
require.NotNil(t, info)
require.Equal(t, "at-token", info.AccessToken)
require.Equal(t, "demo@example.com", info.Email)
require.Greater(t, info.ExpiresAt, int64(0))
}
// TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken checks
// that a 200 reply without an accessToken field is rejected with a
// "missing access token" error.
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) {
	// Reply with a session payload that carries an expiry but no token.
	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`))
	}
	fake := httptest.NewServer(http.HandlerFunc(handler))
	defer fake.Close()

	// Redirect the exchange endpoint to the test server for this test only.
	previousURL := openAISoraSessionAuthURL
	openAISoraSessionAuthURL = fake.URL
	defer func() { openAISoraSessionAuthURL = previousURL }()

	oauthSvc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
	defer oauthSvc.Stop()

	_, exchangeErr := oauthSvc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
	require.Error(t, exchangeErr)
	require.Contains(t, exchangeErr.Error(), "missing access token")
}
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
// openaiOAuthClientStateStub is an OAuth client test double that counts
// ExchangeCode calls, so tests can tell whether the code exchange was reached.
type openaiOAuthClientStateStub struct {
	// exchangeCalled is the number of ExchangeCode invocations; it is read
	// and written with sync/atomic.
	exchangeCalled int32
}

// ExchangeCode records the call and returns a fixed token response
// (access token "at", refresh token "rt", 3600-second lifetime).
func (s *openaiOAuthClientStateStub) ExchangeCode(_ context.Context, _, _, _, _ string) (*openai.TokenResponse, error) {
	atomic.AddInt32(&s.exchangeCalled, 1)

	var canned openai.TokenResponse
	canned.AccessToken = "at"
	canned.RefreshToken = "rt"
	canned.ExpiresIn = 3600
	return &canned, nil
}

// RefreshToken always returns a nil response and a "not implemented" error.
func (*openaiOAuthClientStateStub) RefreshToken(_ context.Context, _, _ string) (*openai.TokenResponse, error) {
	return nil, errors.New("not implemented")
}

// RefreshTokenWithClientID ignores the client ID and delegates to RefreshToken.
func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, _ string) (*openai.TokenResponse, error) {
	return s.RefreshToken(ctx, refreshToken, proxyURL)
}
// TestOpenAIOAuthService_ExchangeCode_StateRequired checks that ExchangeCode
// rejects input without a State value and never calls the OAuth client.
func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) {
	stub := new(openaiOAuthClientStateStub)
	oauthSvc := NewOpenAIOAuthService(nil, stub)
	defer oauthSvc.Stop()

	// Seed a stored session for the "sid" session ID.
	stored := &openai.OAuthSession{
		State:        "expected-state",
		CodeVerifier: "verifier",
		RedirectURI:  openai.DefaultRedirectURI,
		CreatedAt:    time.Now(),
	}
	oauthSvc.sessionStore.Set("sid", stored)

	// State is deliberately left empty.
	input := &OpenAIExchangeCodeInput{
		SessionID: "sid",
		Code:      "auth-code",
	}
	_, exchangeErr := oauthSvc.ExchangeCode(context.Background(), input)

	require.Error(t, exchangeErr)
	require.Contains(t, exchangeErr.Error(), "oauth state is required")
	require.Equal(t, int32(0), atomic.LoadInt32(&stub.exchangeCalled))
}
// TestOpenAIOAuthService_ExchangeCode_StateMismatch checks that ExchangeCode
// rejects a State value different from the one stored in the session and
// never calls the OAuth client.
func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) {
	stub := new(openaiOAuthClientStateStub)
	oauthSvc := NewOpenAIOAuthService(nil, stub)
	defer oauthSvc.Stop()

	// Seed a stored session whose state is "expected-state".
	stored := &openai.OAuthSession{
		State:        "expected-state",
		CodeVerifier: "verifier",
		RedirectURI:  openai.DefaultRedirectURI,
		CreatedAt:    time.Now(),
	}
	oauthSvc.sessionStore.Set("sid", stored)

	input := &OpenAIExchangeCodeInput{
		SessionID: "sid",
		Code:      "auth-code",
		State:     "wrong-state",
	}
	_, exchangeErr := oauthSvc.ExchangeCode(context.Background(), input)

	require.Error(t, exchangeErr)
	require.Contains(t, exchangeErr.Error(), "invalid oauth state")
	require.Equal(t, int32(0), atomic.LoadInt32(&stub.exchangeCalled))
}
// TestOpenAIOAuthService_ExchangeCode_StateMatch checks the happy path: with
// a State equal to the stored one, ExchangeCode calls the OAuth client exactly
// once, returns its access token, and the session can no longer be fetched
// from the store afterwards.
func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
	stub := new(openaiOAuthClientStateStub)
	oauthSvc := NewOpenAIOAuthService(nil, stub)
	defer oauthSvc.Stop()

	// Seed a stored session whose state is "expected-state".
	stored := &openai.OAuthSession{
		State:        "expected-state",
		CodeVerifier: "verifier",
		RedirectURI:  openai.DefaultRedirectURI,
		CreatedAt:    time.Now(),
	}
	oauthSvc.sessionStore.Set("sid", stored)

	input := &OpenAIExchangeCodeInput{
		SessionID: "sid",
		Code:      "auth-code",
		State:     "expected-state",
	}
	tokenInfo, exchangeErr := oauthSvc.ExchangeCode(context.Background(), input)

	require.NoError(t, exchangeErr)
	require.NotNil(t, tokenInfo)
	require.Equal(t, "at", tokenInfo.AccessToken)
	require.Equal(t, int32(1), atomic.LoadInt32(&stub.exchangeCalled))

	// The session must be gone once the exchange has succeeded.
	_, stillPresent := oauthSvc.sessionStore.Get("sid")
	require.False(t, stillPresent)
}
...@@ -157,7 +157,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou ...@@ -157,7 +157,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
} }
expiresAt = account.GetCredentialAsTime("expires_at") expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil { if account.Platform == PlatformSora {
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
refreshFailed = true
} else if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1) p.metrics.refreshFailure.Add(1)
refreshFailed = true // 无法刷新,标记失败 refreshFailed = true // 无法刷新,标记失败
...@@ -206,7 +210,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou ...@@ -206,7 +210,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新 // 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil { if account.Platform == PlatformSora {
slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID)
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
refreshFailed = true
} else if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1) p.metrics.refreshFailure.Add(1)
refreshFailed = true refreshFailed = true
......
...@@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s ...@@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{ accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
Page: page, Page: page,
PageSize: opsAccountsPageSize, PageSize: opsAccountsPageSize,
}, platformFilter, "", "", "") }, platformFilter, "", "", "", 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -28,14 +28,15 @@ var ( ...@@ -28,14 +28,15 @@ var (
// LiteLLMModelPricing LiteLLM价格数据结构 // LiteLLMModelPricing LiteLLM价格数据结构
// 只保留我们需要的字段,使用指针来处理可能缺失的值 // 只保留我们需要的字段,使用指针来处理可能缺失的值
type LiteLLMModelPricing struct { type LiteLLMModelPricing struct {
InputCostPerToken float64 `json:"input_cost_per_token"` InputCostPerToken float64 `json:"input_cost_per_token"`
OutputCostPerToken float64 `json:"output_cost_per_token"` OutputCostPerToken float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"`
LiteLLMProvider string `json:"litellm_provider"` CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
Mode string `json:"mode"` LiteLLMProvider string `json:"litellm_provider"`
SupportsPromptCaching bool `json:"supports_prompt_caching"` Mode string `json:"mode"`
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
} }
// PricingRemoteClient 远程价格数据获取接口 // PricingRemoteClient 远程价格数据获取接口
...@@ -46,14 +47,15 @@ type PricingRemoteClient interface { ...@@ -46,14 +47,15 @@ type PricingRemoteClient interface {
// LiteLLMRawEntry 用于解析原始JSON数据 // LiteLLMRawEntry 用于解析原始JSON数据
type LiteLLMRawEntry struct { type LiteLLMRawEntry struct {
InputCostPerToken *float64 `json:"input_cost_per_token"` InputCostPerToken *float64 `json:"input_cost_per_token"`
OutputCostPerToken *float64 `json:"output_cost_per_token"` OutputCostPerToken *float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"` CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"` CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"`
LiteLLMProvider string `json:"litellm_provider"` CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
Mode string `json:"mode"` LiteLLMProvider string `json:"litellm_provider"`
SupportsPromptCaching bool `json:"supports_prompt_caching"` Mode string `json:"mode"`
OutputCostPerImage *float64 `json:"output_cost_per_image"` SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage *float64 `json:"output_cost_per_image"`
} }
// PricingService 动态价格服务 // PricingService 动态价格服务
...@@ -319,6 +321,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel ...@@ -319,6 +321,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
if entry.CacheCreationInputTokenCost != nil { if entry.CacheCreationInputTokenCost != nil {
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
} }
if entry.CacheCreationInputTokenCostAbove1hr != nil {
pricing.CacheCreationInputTokenCostAbove1hr = *entry.CacheCreationInputTokenCostAbove1hr
}
if entry.CacheReadInputTokenCost != nil { if entry.CacheReadInputTokenCost != nil {
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment