Commit 987589ea authored by yangjianbo's avatar yangjianbo
Browse files

Merge branch 'test' into release

parents 372e04f6 03f69dd3
...@@ -236,6 +236,24 @@ func (h *ProxyHandler) Test(c *gin.Context) { ...@@ -236,6 +236,24 @@ func (h *ProxyHandler) Test(c *gin.Context) {
response.Success(c, result) response.Success(c, result)
} }
// CheckQuality handles checking proxy quality across common AI targets.
// POST /api/v1/admin/proxies/:id/quality-check
func (h *ProxyHandler) CheckQuality(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// GetStats handles getting proxy statistics // GetStats handles getting proxy statistics
// GET /api/v1/admin/proxies/:id/stats // GET /api/v1/admin/proxies/:id/stats
func (h *ProxyHandler) GetStats(c *gin.Context) { func (h *ProxyHandler) GetStats(c *gin.Context) {
......
...@@ -214,6 +214,13 @@ func AccountFromServiceShallow(a *service.Account) *Account { ...@@ -214,6 +214,13 @@ func AccountFromServiceShallow(a *service.Account) *Account {
enabled := true enabled := true
out.EnableSessionIDMasking = &enabled out.EnableSessionIDMasking = &enabled
} }
// 缓存 TTL 强制替换
if a.IsCacheTTLOverrideEnabled() {
enabled := true
out.CacheTTLOverrideEnabled = &enabled
target := a.GetCacheTTLOverrideTarget()
out.CacheTTLOverrideTarget = &target
}
} }
return out return out
...@@ -296,6 +303,11 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi ...@@ -296,6 +303,11 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
CountryCode: p.CountryCode, CountryCode: p.CountryCode,
Region: p.Region, Region: p.Region,
City: p.City, City: p.City,
QualityStatus: p.QualityStatus,
QualityScore: p.QualityScore,
QualityGrade: p.QualityGrade,
QualitySummary: p.QualitySummary,
QualityChecked: p.QualityChecked,
} }
} }
...@@ -402,6 +414,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { ...@@ -402,6 +414,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
ImageSize: l.ImageSize, ImageSize: l.ImageSize,
MediaType: l.MediaType, MediaType: l.MediaType,
UserAgent: l.UserAgent, UserAgent: l.UserAgent,
CacheTTLOverridden: l.CacheTTLOverridden,
CreatedAt: l.CreatedAt, CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User), User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey), APIKey: APIKeyFromService(l.APIKey),
......
...@@ -156,6 +156,11 @@ type Account struct { ...@@ -156,6 +156,11 @@ type Account struct {
// 从 extra 字段提取,方便前端显示和编辑 // 从 extra 字段提取,方便前端显示和编辑
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"` EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
// 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效)
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型计费
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"` Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"`
...@@ -197,6 +202,11 @@ type ProxyWithAccountCount struct { ...@@ -197,6 +202,11 @@ type ProxyWithAccountCount struct {
CountryCode string `json:"country_code,omitempty"` CountryCode string `json:"country_code,omitempty"`
Region string `json:"region,omitempty"` Region string `json:"region,omitempty"`
City string `json:"city,omitempty"` City string `json:"city,omitempty"`
QualityStatus string `json:"quality_status,omitempty"`
QualityScore *int `json:"quality_score,omitempty"`
QualityGrade string `json:"quality_grade,omitempty"`
QualitySummary string `json:"quality_summary,omitempty"`
QualityChecked *int64 `json:"quality_checked,omitempty"`
} }
type ProxyAccountSummary struct { type ProxyAccountSummary struct {
...@@ -280,6 +290,9 @@ type UsageLog struct { ...@@ -280,6 +290,9 @@ type UsageLog struct {
// User-Agent // User-Agent
UserAgent *string `json:"user_agent"` UserAgent *string `json:"user_agent"`
// Cache TTL Override 标记
CacheTTLOverridden bool `json:"cache_ttl_overridden"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"` User *User `json:"user,omitempty"`
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
...@@ -20,6 +21,7 @@ import ( ...@@ -20,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
...@@ -35,6 +37,7 @@ type SoraGatewayHandler struct { ...@@ -35,6 +37,7 @@ type SoraGatewayHandler struct {
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
streamMode string streamMode string
soraTLSEnabled bool
soraMediaSigningKey string soraMediaSigningKey string
soraMediaRoot string soraMediaRoot string
} }
...@@ -50,6 +53,7 @@ func NewSoraGatewayHandler( ...@@ -50,6 +53,7 @@ func NewSoraGatewayHandler(
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
maxAccountSwitches := 3 maxAccountSwitches := 3
streamMode := "force" streamMode := "force"
soraTLSEnabled := true
signKey := "" signKey := ""
mediaRoot := "/app/data/sora" mediaRoot := "/app/data/sora"
if cfg != nil { if cfg != nil {
...@@ -60,6 +64,7 @@ func NewSoraGatewayHandler( ...@@ -60,6 +64,7 @@ func NewSoraGatewayHandler(
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" { if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
streamMode = mode streamMode = mode
} }
soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey) signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" { if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
mediaRoot = root mediaRoot = root
...@@ -72,6 +77,7 @@ func NewSoraGatewayHandler( ...@@ -72,6 +77,7 @@ func NewSoraGatewayHandler(
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
streamMode: strings.ToLower(streamMode), streamMode: strings.ToLower(streamMode),
soraTLSEnabled: soraTLSEnabled,
soraMediaSigningKey: signKey, soraMediaSigningKey: signKey,
soraMediaRoot: mediaRoot, soraMediaRoot: mediaRoot,
} }
...@@ -212,6 +218,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -212,6 +218,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0
var lastFailoverBody []byte
var lastFailoverHeaders http.Header
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
...@@ -224,11 +232,31 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -224,11 +232,31 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return return
} }
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
fields := []zap.Field{
zap.Int("last_upstream_status", lastFailoverStatus),
}
if rayID != "" {
fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
}
if mitigated != "" {
fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
}
if contentType != "" {
fields = append(fields, zap.String("last_upstream_content_type", contentType))
}
reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
return return
} }
account := selection.Account account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform) setOpsSelectedAccount(c, account.ID, account.Platform)
proxyBound := account.ProxyID != nil
proxyID := int64(0)
if account.ProxyID != nil {
proxyID = *account.ProxyID
}
tlsFingerprintEnabled := h.soraTLSEnabled
accountReleaseFunc := selection.ReleaseFunc accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired { if !selection.Acquired {
...@@ -239,10 +267,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -239,10 +267,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
accountWaitCounted := false accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil { if err != nil {
reqLog.Warn("sora.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) reqLog.Warn("sora.account_wait_counter_increment_failed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Error(err),
)
} else if !canWait { } else if !canWait {
reqLog.Info("sora.account_wait_queue_full", reqLog.Info("sora.account_wait_queue_full",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
) )
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
...@@ -266,7 +303,13 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -266,7 +303,13 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
&streamStarted, &streamStarted,
) )
if err != nil { if err != nil {
reqLog.Warn("sora.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) reqLog.Warn("sora.account_slot_acquire_failed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Error(err),
)
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
...@@ -287,20 +330,67 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -287,20 +330,67 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
failedAccountIDs[account.ID] = struct{}{} failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches { if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
lastFailoverBody = failoverErr.ResponseBody
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
}
if rayID != "" {
fields = append(fields, zap.String("upstream_cf_ray", rayID))
}
if mitigated != "" {
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
}
if contentType != "" {
fields = append(fields, zap.String("upstream_content_type", contentType))
}
reqLog.Warn("sora.upstream_failover_exhausted", fields...)
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
return return
} }
lastFailoverStatus = failoverErr.StatusCode lastFailoverStatus = failoverErr.StatusCode
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
lastFailoverBody = failoverErr.ResponseBody
switchCount++ switchCount++
reqLog.Warn("sora.upstream_failover_switching", upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
fields := []zap.Field{
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("upstream_status", failoverErr.StatusCode), zap.Int("upstream_status", failoverErr.StatusCode),
zap.String("upstream_error_code", upstreamErrCode),
zap.String("upstream_error_message", upstreamErrMsg),
zap.Int("switch_count", switchCount), zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches), zap.Int("max_switches", maxAccountSwitches),
) }
if rayID != "" {
fields = append(fields, zap.String("upstream_cf_ray", rayID))
}
if mitigated != "" {
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
}
if contentType != "" {
fields = append(fields, zap.String("upstream_content_type", contentType))
}
reqLog.Warn("sora.upstream_failover_switching", fields...)
continue continue
} }
reqLog.Error("sora.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) reqLog.Error("sora.forward_failed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Error(err),
)
return return
} }
...@@ -331,6 +421,9 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -331,6 +421,9 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
}(result, account, userAgent, clientIP) }(result, account, userAgent, clientIP)
reqLog.Debug("sora.request_completed", reqLog.Debug("sora.request_completed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
zap.Bool("proxy_bound", proxyBound),
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("switch_count", switchCount), zap.Int("switch_count", switchCount),
) )
return return
...@@ -360,17 +453,41 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s ...@@ -360,17 +453,41 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
} }
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode) status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
} }
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) { func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
}
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
if strings.EqualFold(upstreamCode, "cf_shield_429") {
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
}
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
switch statusCode {
case 401, 403, 404, 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", upstreamMessage
case 429:
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
}
}
switch statusCode { switch statusCode {
case 401: case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403: case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 404:
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
}
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
case 429: case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529: case 529:
...@@ -382,11 +499,67 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri ...@@ -382,11 +499,67 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri
} }
} }
func cloneHTTPHeaders(headers http.Header) http.Header {
if headers == nil {
return nil
}
return headers.Clone()
}
func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
if headers != nil {
mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
contentType = strings.TrimSpace(headers.Get("content-type"))
if contentType == "" {
contentType = strings.TrimSpace(headers.Get("Content-Type"))
}
}
rayID = soraerror.ExtractCloudflareRayID(headers, body)
return rayID, mitigated, contentType
}
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
message = strings.TrimSpace(message)
if message == "" {
return false
}
if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
lower := strings.ToLower(message)
if strings.Contains(lower, "<html") || strings.Contains(lower, "<!doctype html") || strings.Contains(lower, "window._cf_chl_opt") {
return false
}
}
return true
}
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
}
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted { if streamStarted {
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if ok { if ok {
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) errorData := map[string]any{
"error": map[string]string{
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
if err != nil {
_ = c.Error(err)
return
}
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err) _ = c.Error(err)
} }
......
...@@ -43,6 +43,48 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A ...@@ -43,6 +43,48 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) { func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
return "task-video", nil return "task-video", nil
} }
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "cameo-1", nil
}
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
return &service.SoraCameoStatus{
Status: "finalized",
StatusMessage: "Completed",
DisplayNameHint: "Character",
UsernameHint: "user.character",
ProfileAssetURL: "https://example.com/avatar.webp",
}, nil
}
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
return []byte("avatar"), nil
}
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "asset-pointer", nil
}
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
return "character-1", nil
}
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
return nil
}
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
return nil
}
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
return "s_post", nil
}
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
return nil
}
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
return "https://example.com/no-watermark.mp4", nil
}
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
return "enhanced prompt", nil
}
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) { func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
} }
...@@ -88,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error ...@@ -88,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
...@@ -495,3 +537,152 @@ func TestGenerateOpenAISessionHash_WithBody(t *testing.T) { ...@@ -495,3 +537,152 @@ func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
require.NotEmpty(t, hash3) require.NotEmpty(t, hash3)
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
} }
func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) {
tests := []struct {
name string
errType string
message string
}{
{
name: "包含双引号",
errType: "upstream_error",
message: `upstream returned "invalid" payload`,
},
{
name: "包含换行和制表符",
errType: "rate_limit_error",
message: "line1\nline2\ttab",
},
{
name: "包含反斜杠",
errType: "upstream_error",
message: `path C:\Users\test\file.txt not found`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &SoraGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
body := w.Body.String()
require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头")
require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾")
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行")
require.Equal(t, "event: error", lines[0])
require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀")
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON")
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok, "JSON 中应包含 error 对象")
require.Equal(t, tt.errType, errorObj["type"])
require.Equal(t, tt.message, errorObj["message"])
})
}
}
func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &SoraGatewayHandler{}
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
h.handleFailoverExhausted(c, http.StatusBadGateway, nil, resp, true)
body := w.Body.String()
require.True(t, strings.HasPrefix(body, "event: error\n"))
require.True(t, strings.HasSuffix(body, "\n\n"))
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "upstream_error", errorObj["type"])
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
}
func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
headers := http.Header{}
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
body := []byte(`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`)
h := &SoraGatewayHandler{}
h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "upstream_error", errorObj["type"])
msg, _ := errorObj["message"].(string)
require.Contains(t, msg, "Cloudflare challenge")
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
}
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
headers := http.Header{}
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
h := &SoraGatewayHandler{}
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "rate_limit_error", errorObj["type"])
msg, _ := errorObj["message"].(string)
require.Contains(t, msg, "Cloudflare shield")
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
}
func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
headers := http.Header{}
headers.Set("cf-mitigated", "challenge")
headers.Set("content-type", "text/html")
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
require.Equal(t, "9cff2d62d83bb98d", rayID)
require.Equal(t, "challenge", mitigated)
require.Equal(t, "text/html", contentType)
}
...@@ -10,6 +10,7 @@ const ( ...@@ -10,6 +10,7 @@ const (
BetaInterleavedThinking = "interleaved-thinking-2025-05-14" BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14" BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting = "token-counting-2024-11-01" BetaTokenCounting = "token-counting-2024-11-01"
BetaContext1M = "context-1m-2025-08-07"
) )
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
...@@ -77,6 +78,12 @@ var DefaultModels = []Model{ ...@@ -77,6 +78,12 @@ var DefaultModels = []Model{
DisplayName: "Claude Opus 4.6", DisplayName: "Claude Opus 4.6",
CreatedAt: "2026-02-06T00:00:00Z", CreatedAt: "2026-02-06T00:00:00Z",
}, },
{
ID: "claude-sonnet-4-6",
Type: "model",
DisplayName: "Claude Sonnet 4.6",
CreatedAt: "2026-02-18T00:00:00Z",
},
{ {
ID: "claude-sonnet-4-5-20250929", ID: "claude-sonnet-4-5-20250929",
Type: "model", Type: "model",
......
...@@ -17,6 +17,8 @@ import ( ...@@ -17,6 +17,8 @@ import (
const ( const (
// OAuth Client ID for OpenAI (Codex CLI official) // OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints // OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize" AuthorizeURL = "https://auth.openai.com/oauth/authorize"
......
...@@ -435,10 +435,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { ...@@ -435,10 +435,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
} }
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "") return r.ListWithFilters(ctx, params, "", "", "", "", 0)
} }
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
q := r.client.Account.Query() q := r.client.Account.Query()
if platform != "" { if platform != "" {
...@@ -458,6 +458,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati ...@@ -458,6 +458,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
if search != "" { if search != "" {
q = q.Where(dbaccount.NameContainsFold(search)) q = q.Where(dbaccount.NameContainsFold(search))
} }
if groupID > 0 {
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
}
total, err := q.Count(ctx) total, err := q.Count(ctx)
if err != nil { if err != nil {
......
...@@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
tt.setup(client) tt.setup(client)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search) accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(accounts, tt.wantCount) s.Require().Len(accounts, tt.wantCount)
if tt.validate != nil { if tt.validate != nil {
...@@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { ...@@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
s.Require().Len(got.Groups, 1, "expected Groups to be populated") s.Require().Len(got.Groups, 1, "expected Groups to be populated")
s.Require().Equal(group.ID, got.Groups[0].ID) s.Require().Equal(group.ID, got.Groups[0].ID)
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc") accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0)
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
s.Require().Len(accounts, 1) s.Require().Len(accounts, 1)
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
...@@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie ...@@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
} }
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
}
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
if strings.TrimSpace(clientID) != "" {
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID))
}
clientIDs := []string{
openai.ClientID,
openai.SoraClientID,
}
seen := make(map[string]struct{}, len(clientIDs))
var lastErr error
for _, clientID := range clientIDs {
clientID = strings.TrimSpace(clientID)
if clientID == "" {
continue
}
if _, ok := seen[clientID]; ok {
continue
}
seen[clientID] = struct{}{}
tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
if err == nil {
return tokenResp, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed")
}
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL) client := createOpenAIReqClient(proxyURL)
formData := url.Values{} formData := url.Values{}
formData.Set("grant_type", "refresh_token") formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken) formData.Set("refresh_token", refreshToken)
formData.Set("client_id", openai.ClientID) formData.Set("client_id", clientID)
formData.Set("scope", openai.RefreshScopes) formData.Set("scope", openai.RefreshScopes)
var tokenResp openai.TokenResponse var tokenResp openai.TokenResponse
......
...@@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() { ...@@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
require.Equal(s.T(), "rt2", resp.RefreshToken) require.Equal(s.T(), "rt2", resp.RefreshToken)
} }
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID == openai.ClientID {
w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, "invalid_grant")
return
}
if clientID == openai.SoraClientID {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
return
}
w.WriteHeader(http.StatusBadRequest)
}))
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.NoError(s.T(), err, "RefreshToken")
require.Equal(s.T(), "at-sora", resp.AccessToken)
require.Equal(s.T(), "rt-sora", resp.RefreshToken)
require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
const customClientID = "custom-client-id"
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID != customClientID {
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID)
require.NoError(s.T(), err, "RefreshTokenWithClientID")
require.Equal(s.T(), "at-custom", resp.AccessToken)
require.Equal(s.T(), "rt-custom", resp.RefreshToken)
require.Equal(s.T(), []string{customClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() { func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
......
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{ var dateFormatWhitelist = map[string]string{
...@@ -132,6 +132,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -132,6 +132,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
image_size, image_size,
media_type, media_type,
reasoning_effort, reasoning_effort,
cache_ttl_overridden,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $1, $2, $3, $4, $5,
...@@ -139,7 +140,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -139,7 +140,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11, $8, $9, $10, $11,
$12, $13, $12, $13,
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32 $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -192,6 +193,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -192,6 +193,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
imageSize, imageSize,
mediaType, mediaType,
reasoningEffort, reasoningEffort,
log.CacheTTLOverridden,
createdAt, createdAt,
} }
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
...@@ -2221,6 +2223,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2221,6 +2223,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
imageSize sql.NullString imageSize sql.NullString
mediaType sql.NullString mediaType sql.NullString
reasoningEffort sql.NullString reasoningEffort sql.NullString
cacheTTLOverridden bool
createdAt time.Time createdAt time.Time
) )
...@@ -2257,6 +2260,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2257,6 +2260,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&imageSize, &imageSize,
&mediaType, &mediaType,
&reasoningEffort, &reasoningEffort,
&cacheTTLOverridden,
&createdAt, &createdAt,
); err != nil { ); err != nil {
return nil, err return nil, err
...@@ -2285,6 +2289,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2285,6 +2289,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
BillingType: int8(billingType), BillingType: int8(billingType),
Stream: stream, Stream: stream,
ImageCount: imageCount, ImageCount: imageCount,
CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: createdAt, CreatedAt: createdAt,
} }
......
...@@ -406,6 +406,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -406,6 +406,7 @@ func TestAPIContracts(t *testing.T) {
"image_count": 0, "image_count": 0,
"image_size": null, "image_size": null,
"media_type": null, "media_type": null,
"cache_ttl_overridden": false,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
"user_agent": null "user_agent": null
} }
...@@ -945,7 +946,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination ...@@ -945,7 +946,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
......
...@@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { ...@@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
} }
allowedSet[origin] = struct{}{} allowedSet[origin] = struct{}{}
} }
allowHeaders := []string{
"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization",
"accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key",
}
// OpenAI Node SDK 会发送 x-stainless-* 请求头,需在 CORS 中显式放行。
openAIProperties := []string{
"lang", "package-version", "os", "arch", "retry-count", "runtime",
"runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout",
}
for _, prop := range openAIProperties {
allowHeaders = append(allowHeaders, "x-stainless-"+prop)
}
allowHeadersValue := strings.Join(allowHeaders, ", ")
return func(c *gin.Context) { return func(c *gin.Context) {
origin := strings.TrimSpace(c.GetHeader("Origin")) origin := strings.TrimSpace(c.GetHeader("Origin"))
...@@ -68,12 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { ...@@ -68,12 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
if allowCredentials { if allowCredentials {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
} }
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key") c.Writer.Header().Set("Access-Control-Allow-Headers", allowHeadersValue)
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
c.Writer.Header().Set("Access-Control-Expose-Headers", "ETag") c.Writer.Header().Set("Access-Control-Expose-Headers", "ETag")
c.Writer.Header().Set("Access-Control-Max-Age", "86400") c.Writer.Header().Set("Access-Control-Max-Age", "86400")
} }
// 处理预检请求 // 处理预检请求
if c.Request.Method == http.MethodOptions { if c.Request.Method == http.MethodOptions {
if originAllowed { if originAllowed {
......
...@@ -34,6 +34,8 @@ func RegisterAdminRoutes( ...@@ -34,6 +34,8 @@ func RegisterAdminRoutes(
// OpenAI OAuth // OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h) registerOpenAIOAuthRoutes(admin, h)
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
registerSoraOAuthRoutes(admin, h)
// Gemini OAuth // Gemini OAuth
registerGeminiOAuthRoutes(admin, h) registerGeminiOAuthRoutes(admin, h)
...@@ -276,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -276,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
} }
} }
func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
sora := admin.Group("/sora")
{
sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
}
}
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
gemini := admin.Group("/gemini") gemini := admin.Group("/gemini")
{ {
...@@ -306,6 +321,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -306,6 +321,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies.PUT("/:id", h.Admin.Proxy.Update) proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete) proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test) proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.POST("/:id/quality-check", h.Admin.Proxy.CheckQuality)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats) proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts) proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete) proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete)
......
package routes package routes
import ( import (
"net/http"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
...@@ -41,16 +43,15 @@ func RegisterGatewayRoutes( ...@@ -41,16 +43,15 @@ func RegisterGatewayRoutes(
gateway.GET("/usage", h.Gateway.Usage) gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API // OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses) gateway.POST("/responses", h.OpenAIGateway.Responses)
} // 明确阻止旧入口误用到 Sora,避免客户端把 OpenAI Chat Completions 当作 Sora 入口
gateway.POST("/chat/completions", func(c *gin.Context) {
// Sora Chat Completions c.JSON(http.StatusBadRequest, gin.H{
soraGateway := r.Group("/v1") "error": gin.H{
soraGateway.Use(soraBodyLimit) "type": "invalid_request_error",
soraGateway.Use(clientRequestID) "message": "For Sora, use /sora/v1/chat/completions. OpenAI should use /v1/responses.",
soraGateway.Use(opsErrorLogger) },
soraGateway.Use(gin.HandlerFunc(apiKeyAuth)) })
{ })
soraGateway.POST("/chat/completions", h.SoraGateway.ChatCompletions)
} }
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
......
...@@ -786,6 +786,38 @@ func (a *Account) IsSessionIDMaskingEnabled() bool { ...@@ -786,6 +786,38 @@ func (a *Account) IsSessionIDMaskingEnabled() bool {
return false return false
} }
// IsCacheTTLOverrideEnabled 检查是否启用缓存 TTL 强制替换
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型(5m 或 1h)
func (a *Account) IsCacheTTLOverrideEnabled() bool {
if !a.IsAnthropicOAuthOrSetupToken() {
return false
}
if a.Extra == nil {
return false
}
if v, ok := a.Extra["cache_ttl_override_enabled"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// GetCacheTTLOverrideTarget 获取缓存 TTL 强制替换的目标类型
// 返回 "5m" 或 "1h",默认 "5m"
func (a *Account) GetCacheTTLOverrideTarget() string {
if a.Extra == nil {
return "5m"
}
if v, ok := a.Extra["cache_ttl_override_target"]; ok {
if target, ok := v.(string); ok && (target == "5m" || target == "1h") {
return target
}
}
return "5m"
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元) // GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用 // 返回 0 表示未启用
func (a *Account) GetWindowCostLimit() float64 { func (a *Account) GetWindowCostLimit() float64 {
......
...@@ -35,7 +35,7 @@ type AccountRepository interface { ...@@ -35,7 +35,7 @@ type AccountRepository interface {
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]Account, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]Account, error) ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error)
......
...@@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination ...@@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination
panic("unexpected List call") panic("unexpected List call")
} }
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call") panic("unexpected ListWithFilters call")
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment