Commit 900cce20 authored by yangjianbo's avatar yangjianbo
Browse files

feat(sora): 对齐 Sora OAuth 流程并隔离网关请求路径



- 新增并接通 Sora 专用 OAuth 接口与 ST/RT 换取能力
- 完成前端 Sora 授权、RT/ST 手动导入与账号创建流程
- 强化 Sora token 恢复、转发日志与网关路由隔离行为
- 补充后端服务层与路由层相关测试覆盖
Co-Authored-By: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent 36bb3270
...@@ -162,6 +162,8 @@ type TokenRefreshConfig struct { ...@@ -162,6 +162,8 @@ type TokenRefreshConfig struct {
MaxRetries int `mapstructure:"max_retries"` MaxRetries int `mapstructure:"max_retries"`
// 重试退避基础时间(秒) // 重试退避基础时间(秒)
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
// 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭)
SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
} }
type PricingConfig struct { type PricingConfig struct {
...@@ -269,17 +271,18 @@ type SoraConfig struct { ...@@ -269,17 +271,18 @@ type SoraConfig struct {
// SoraClientConfig 直连 Sora 客户端配置 // SoraClientConfig 直连 Sora 客户端配置
type SoraClientConfig struct { type SoraClientConfig struct {
BaseURL string `mapstructure:"base_url"` BaseURL string `mapstructure:"base_url"`
TimeoutSeconds int `mapstructure:"timeout_seconds"` TimeoutSeconds int `mapstructure:"timeout_seconds"`
MaxRetries int `mapstructure:"max_retries"` MaxRetries int `mapstructure:"max_retries"`
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
MaxPollAttempts int `mapstructure:"max_poll_attempts"` MaxPollAttempts int `mapstructure:"max_poll_attempts"`
RecentTaskLimit int `mapstructure:"recent_task_limit"` RecentTaskLimit int `mapstructure:"recent_task_limit"`
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
Debug bool `mapstructure:"debug"` Debug bool `mapstructure:"debug"`
Headers map[string]string `mapstructure:"headers"` UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
UserAgent string `mapstructure:"user_agent"` Headers map[string]string `mapstructure:"headers"`
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` UserAgent string `mapstructure:"user_agent"`
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
} }
// SoraStorageConfig 媒体存储配置 // SoraStorageConfig 媒体存储配置
...@@ -1116,6 +1119,7 @@ func setDefaults() { ...@@ -1116,6 +1119,7 @@ func setDefaults() {
viper.SetDefault("sora.client.recent_task_limit", 50) viper.SetDefault("sora.client.recent_task_limit", 50)
viper.SetDefault("sora.client.recent_task_limit_max", 200) viper.SetDefault("sora.client.recent_task_limit_max", 200)
viper.SetDefault("sora.client.debug", false) viper.SetDefault("sora.client.debug", false)
viper.SetDefault("sora.client.use_openai_token_provider", false)
viper.SetDefault("sora.client.headers", map[string]string{}) viper.SetDefault("sora.client.headers", map[string]string{})
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
viper.SetDefault("sora.client.disable_tls_fingerprint", false) viper.SetDefault("sora.client.disable_tls_fingerprint", false)
...@@ -1137,6 +1141,7 @@ func setDefaults() { ...@@ -1137,6 +1141,7 @@ func setDefaults() {
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token) viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
// Gemini OAuth - configure via environment variables or config file // Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
......
...@@ -1333,6 +1333,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { ...@@ -1333,6 +1333,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return return
} }
// Handle Sora accounts
if account.Platform == service.PlatformSora {
response.Success(c, service.DefaultSoraModels(nil))
return
}
// Handle Claude/Anthropic accounts // Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models // For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() { if account.IsOAuth() {
......
...@@ -2,6 +2,7 @@ package admin ...@@ -2,6 +2,7 @@ package admin
import ( import (
"strconv" "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
...@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct { ...@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct {
adminService service.AdminService adminService service.AdminService
} }
func oauthPlatformFromPath(c *gin.Context) string {
if strings.Contains(c.FullPath(), "/admin/sora/") {
return service.PlatformSora
}
return service.PlatformOpenAI
}
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler // NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler { func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{ return &OpenAIOAuthHandler{
...@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { ...@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
type OpenAIExchangeCodeRequest struct { type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"` SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"` Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
} }
...@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { ...@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID, SessionID: req.SessionID,
Code: req.Code, Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
...@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { ...@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token // OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct { type OpenAIRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"` RefreshToken string `json:"refresh_token"`
RT string `json:"rt"`
ClientID string `json:"client_id"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
} }
// RefreshToken refreshes an OpenAI OAuth token // RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token // POST /api/v1/admin/openai/refresh-token
// POST /api/v1/admin/sora/rt2at
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error()) response.BadRequest(c, "Invalid request: "+err.Error())
return return
} }
refreshToken := strings.TrimSpace(req.RefreshToken)
if refreshToken == "" {
refreshToken = strings.TrimSpace(req.RT)
}
if refreshToken == "" {
response.BadRequest(c, "refresh_token is required")
return
}
var proxyURL string var proxyURL string
if req.ProxyID != nil { if req.ProxyID != nil {
...@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { ...@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
} }
} }
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL) tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { ...@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
response.Success(c, tokenInfo) response.Success(c, tokenInfo)
} }
// RefreshAccountToken refreshes token for a specific OpenAI account // ExchangeSoraSessionToken exchanges Sora session token to access token
// POST /api/v1/admin/sora/st2at
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
var req struct {
SessionToken string `json:"session_token"`
ST string `json:"st"`
ProxyID *int64 `json:"proxy_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
sessionToken := strings.TrimSpace(req.SessionToken)
if sessionToken == "" {
sessionToken = strings.TrimSpace(req.ST)
}
if sessionToken == "" {
response.BadRequest(c, "session_token is required")
return
}
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
// POST /api/v1/admin/openai/accounts/:id/refresh // POST /api/v1/admin/openai/accounts/:id/refresh
// POST /api/v1/admin/sora/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil { if err != nil {
...@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { ...@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
return return
} }
// Ensure account is OpenAI platform platform := oauthPlatformFromPath(c)
if !account.IsOpenAI() { if account.Platform != platform {
response.BadRequest(c, "Account is not an OpenAI account") response.BadRequest(c, "Account platform does not match OAuth endpoint")
return return
} }
...@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { ...@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
response.Success(c, dto.AccountFromService(updatedAccount)) response.Success(c, dto.AccountFromService(updatedAccount))
} }
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info // CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth // POST /api/v1/admin/openai/create-from-oauth
// POST /api/v1/admin/sora/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct { var req struct {
SessionID string `json:"session_id" binding:"required"` SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"` Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"` Name string `json:"name"`
...@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { ...@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID, SessionID: req.SessionID,
Code: req.Code, Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
...@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { ...@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
// Build credentials from token info // Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo) credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
platform := oauthPlatformFromPath(c)
// Use email as default name if not provided // Use email as default name if not provided
name := req.Name name := req.Name
if name == "" && tokenInfo.Email != "" { if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email name = tokenInfo.Email
} }
if name == "" { if name == "" {
name = "OpenAI OAuth Account" if platform == service.PlatformSora {
name = "Sora OAuth Account"
} else {
name = "OpenAI OAuth Account"
}
} }
// Create account // Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name, Name: name,
Platform: "openai", Platform: platform,
Type: "oauth", Type: "oauth",
Credentials: credentials, Credentials: credentials,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
......
...@@ -212,6 +212,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -212,6 +212,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0
var lastFailoverBody []byte
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
...@@ -224,7 +225,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -224,7 +225,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return return
} }
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted)
return return
} }
account := selection.Account account := selection.Account
...@@ -287,14 +288,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -287,14 +288,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
failedAccountIDs[account.ID] = struct{}{} failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches { if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) lastFailoverBody = failoverErr.ResponseBody
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted)
return return
} }
lastFailoverStatus = failoverErr.StatusCode lastFailoverStatus = failoverErr.StatusCode
lastFailoverBody = failoverErr.ResponseBody
switchCount++ switchCount++
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
reqLog.Warn("sora.upstream_failover_switching", reqLog.Warn("sora.upstream_failover_switching",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode), zap.Int("upstream_status", failoverErr.StatusCode),
zap.String("upstream_error_code", upstreamErrCode),
zap.String("upstream_error_message", upstreamErrMsg),
zap.Int("switch_count", switchCount), zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches), zap.Int("max_switches", maxAccountSwitches),
) )
...@@ -360,17 +366,32 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s ...@@ -360,17 +366,32 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
} }
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseBody []byte, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode) status, errType, errMsg := h.mapUpstreamError(statusCode, responseBody)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
} }
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) { func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseBody []byte) (int, string, string) {
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
if upstreamMessage != "" {
switch statusCode {
case 401, 403, 404, 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", upstreamMessage
case 429:
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
}
}
switch statusCode { switch statusCode {
case 401: case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403: case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 404:
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
}
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
case 429: case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529: case 529:
...@@ -382,6 +403,41 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri ...@@ -382,6 +403,41 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri
} }
} }
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
trimmed := strings.TrimSpace(string(body))
if trimmed == "" {
return "", ""
}
if !gjson.Valid(trimmed) {
return "", truncateSoraErrorMessage(trimmed, 256)
}
code := strings.TrimSpace(gjson.Get(trimmed, "error.code").String())
if code == "" {
code = strings.TrimSpace(gjson.Get(trimmed, "code").String())
}
message := strings.TrimSpace(gjson.Get(trimmed, "error.message").String())
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "message").String())
}
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "error.detail").String())
}
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "detail").String())
}
return code, truncateSoraErrorMessage(message, 512)
}
func truncateSoraErrorMessage(s string, maxLen int) string {
if maxLen <= 0 {
return ""
}
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "...(truncated)"
}
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted { if streamStarted {
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
......
...@@ -43,6 +43,9 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A ...@@ -43,6 +43,9 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) { func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
return "task-video", nil return "task-video", nil
} }
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
return "enhanced prompt", nil
}
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) { func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
} }
......
...@@ -17,6 +17,8 @@ import ( ...@@ -17,6 +17,8 @@ import (
const ( const (
// OAuth Client ID for OpenAI (Codex CLI official) // OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints // OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize" AuthorizeURL = "https://auth.openai.com/oauth/authorize"
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
...@@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie ...@@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
} }
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
}
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
if strings.TrimSpace(clientID) != "" {
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID))
}
clientIDs := []string{
openai.ClientID,
openai.SoraClientID,
}
seen := make(map[string]struct{}, len(clientIDs))
var lastErr error
for _, clientID := range clientIDs {
clientID = strings.TrimSpace(clientID)
if clientID == "" {
continue
}
if _, ok := seen[clientID]; ok {
continue
}
seen[clientID] = struct{}{}
tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
if err == nil {
return tokenResp, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed")
}
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL) client := createOpenAIReqClient(proxyURL)
formData := url.Values{} formData := url.Values{}
formData.Set("grant_type", "refresh_token") formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken) formData.Set("refresh_token", refreshToken)
formData.Set("client_id", openai.ClientID) formData.Set("client_id", clientID)
formData.Set("scope", openai.RefreshScopes) formData.Set("scope", openai.RefreshScopes)
var tokenResp openai.TokenResponse var tokenResp openai.TokenResponse
......
...@@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() { ...@@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
require.Equal(s.T(), "rt2", resp.RefreshToken) require.Equal(s.T(), "rt2", resp.RefreshToken)
} }
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID == openai.ClientID {
w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, "invalid_grant")
return
}
if clientID == openai.SoraClientID {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
return
}
w.WriteHeader(http.StatusBadRequest)
}))
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.NoError(s.T(), err, "RefreshToken")
require.Equal(s.T(), "at-sora", resp.AccessToken)
require.Equal(s.T(), "rt-sora", resp.RefreshToken)
require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
const customClientID = "custom-client-id"
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID != customClientID {
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID)
require.NoError(s.T(), err, "RefreshTokenWithClientID")
require.Equal(s.T(), "at-custom", resp.AccessToken)
require.Equal(s.T(), "rt-custom", resp.RefreshToken)
require.Equal(s.T(), []string{customClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() { func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
......
...@@ -34,6 +34,8 @@ func RegisterAdminRoutes( ...@@ -34,6 +34,8 @@ func RegisterAdminRoutes(
// OpenAI OAuth // OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h) registerOpenAIOAuthRoutes(admin, h)
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
registerSoraOAuthRoutes(admin, h)
// Gemini OAuth // Gemini OAuth
registerGeminiOAuthRoutes(admin, h) registerGeminiOAuthRoutes(admin, h)
...@@ -276,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -276,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
} }
} }
func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
sora := admin.Group("/sora")
{
sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
}
}
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
gemini := admin.Group("/gemini") gemini := admin.Group("/gemini")
{ {
......
package routes package routes
import ( import (
"net/http"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
...@@ -41,16 +43,15 @@ func RegisterGatewayRoutes( ...@@ -41,16 +43,15 @@ func RegisterGatewayRoutes(
gateway.GET("/usage", h.Gateway.Usage) gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API // OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses) gateway.POST("/responses", h.OpenAIGateway.Responses)
} // 明确阻止旧入口误用到 Sora,避免客户端把 OpenAI Chat Completions 当作 Sora 入口
gateway.POST("/chat/completions", func(c *gin.Context) {
// Sora Chat Completions c.JSON(http.StatusBadRequest, gin.H{
soraGateway := r.Group("/v1") "error": gin.H{
soraGateway.Use(soraBodyLimit) "type": "invalid_request_error",
soraGateway.Use(clientRequestID) "message": "For Sora, use /sora/v1/chat/completions. OpenAI should use /v1/responses.",
soraGateway.Use(opsErrorLogger) },
soraGateway.Use(gin.HandlerFunc(apiKeyAuth)) })
{ })
soraGateway.POST("/chat/completions", h.SoraGateway.ChatCompletions)
} }
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
......
...@@ -27,11 +27,13 @@ import ( ...@@ -27,11 +27,13 @@ import (
// sseDataPrefix matches SSE data lines with optional whitespace after colon. // sseDataPrefix matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: "). // Some upstream APIs return non-standard "data:" without space (should be "data: ").
var sseDataPrefix = regexp.MustCompile(`^data:\s*`) var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
var cloudflareRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
const ( const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages" testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接 soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
) )
// TestEvent represents a SSE event for account testing // TestEvent represents a SSE event for account testing
...@@ -502,8 +504,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * ...@@ -502,8 +504,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint()
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
} }
...@@ -512,7 +515,10 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * ...@@ -512,7 +515,10 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, string(body))) if isCloudflareChallengeResponse(resp.StatusCode, body) {
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage("Sora request blocked by Cloudflare challenge (HTTP 403). Please switch to a clean proxy/network and retry.", resp.Header, body))
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
} }
// 解析 /me 响应,提取用户信息 // 解析 /me 响应,提取用户信息
...@@ -531,10 +537,129 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * ...@@ -531,10 +537,129 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
s.sendEvent(c, TestEvent{Type: "content", Text: info}) s.sendEvent(c, TestEvent{Type: "content", Text: info})
} }
// 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
if err == nil {
subReq.Header.Set("Authorization", "Bearer "+authToken)
subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
subReq.Header.Set("Accept", "application/json")
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
if subErr != nil {
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
} else {
subBody, _ := io.ReadAll(subResp.Body)
_ = subResp.Body.Close()
if subResp.StatusCode == http.StatusOK {
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
}
} else {
if isCloudflareChallengeResponse(subResp.StatusCode, subBody) {
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Subscription check blocked by Cloudflare challenge (HTTP 403)", subResp.Header, subBody)})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
}
}
}
}
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil return nil
} }
func parseSoraSubscriptionSummary(body []byte) string {
var subResp struct {
Data []struct {
Plan struct {
ID string `json:"id"`
Title string `json:"title"`
} `json:"plan"`
EndTS string `json:"end_ts"`
} `json:"data"`
}
if err := json.Unmarshal(body, &subResp); err != nil {
return ""
}
if len(subResp.Data) == 0 {
return ""
}
first := subResp.Data[0]
parts := make([]string, 0, 3)
if first.Plan.Title != "" {
parts = append(parts, first.Plan.Title)
}
if first.Plan.ID != "" {
parts = append(parts, first.Plan.ID)
}
if first.EndTS != "" {
parts = append(parts, "end="+first.EndTS)
}
if len(parts) == 0 {
return ""
}
return "Subscription: " + strings.Join(parts, " | ")
}
func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool {
if s == nil || s.cfg == nil {
return false
}
return s.cfg.Gateway.TLSFingerprint.Enabled && !s.cfg.Sora.Client.DisableTLSFingerprint
}
func isCloudflareChallengeResponse(statusCode int, body []byte) bool {
if statusCode != http.StatusForbidden {
return false
}
preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
return strings.Contains(preview, "window._cf_chl_opt") ||
strings.Contains(preview, "just a moment") ||
strings.Contains(preview, "enable javascript and cookies to continue")
}
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
rayID := extractCloudflareRayID(headers, body)
if rayID == "" {
return base
}
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
}
func extractCloudflareRayID(headers http.Header, body []byte) string {
if headers != nil {
rayID := strings.TrimSpace(headers.Get("cf-ray"))
if rayID != "" {
return rayID
}
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
if rayID != "" {
return rayID
}
}
preview := truncateSoraErrorBody(body, 8192)
matches := cloudflareRayPattern.FindStringSubmatch(preview)
if len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
return ""
}
func truncateSoraErrorBody(body []byte, max int) string {
if max <= 0 {
max = 512
}
raw := strings.TrimSpace(string(body))
if len(raw) <= max {
return raw
}
return raw[:max] + "...(truncated)"
}
// testAntigravityAccountConnection tests an Antigravity account's connection // testAntigravityAccountConnection tests an Antigravity account's connection
// 支持 Claude 和 Gemini 两种协议,使用非流式请求 // 支持 Claude 和 Gemini 两种协议,使用非流式请求
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error { func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
......
package service
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type queuedHTTPUpstream struct {
responses []*http.Response
requests []*http.Request
tlsFlags []bool
}
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return nil, fmt.Errorf("unexpected Do call")
}
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) {
u.requests = append(u.requests, req)
u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint)
if len(u.responses) == 0 {
return nil, fmt.Errorf("no mocked response")
}
resp := u.responses[0]
u.responses = u.responses[1:]
return resp, nil
}
func newJSONResponse(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
resp := newJSONResponse(status, body)
resp.Header.Set(key, value)
return resp
}
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
return c, rec
}
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
},
}
svc := &AccountTestService{
httpUpstream: upstream,
cfg: &config.Config{
Gateway: config.GatewayConfig{
TLSFingerprint: config.TLSFingerprintConfig{
Enabled: true,
},
},
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
DisableTLSFingerprint: false,
},
},
},
}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
require.Len(t, upstream.requests, 2)
require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
require.Equal(t, []bool{true, true}, upstream.tlsFlags)
body := rec.Body.String()
require.Contains(t, body, `"type":"test_start"`)
require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
require.Len(t, upstream.requests, 2)
body := rec.Body.String()
require.Contains(t, body, "Sora connection OK - User: demo-user")
require.Contains(t, body, "Subscription check returned 403")
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponseWithHeader(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`, "cf-ray", "9cff2d62d83bb98d"),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "Cloudflare challenge")
require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
body := rec.Body.String()
require.Contains(t, body, `"type":"error"`)
require.Contains(t, body, "Cloudflare challenge")
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
body := rec.Body.String()
require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
require.Contains(t, body, `"type":"test_complete","success":true`)
}
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
type OpenAIOAuthClient interface { type OpenAIOAuthClient interface {
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error)
} }
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows // ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
......
...@@ -2,13 +2,20 @@ package service ...@@ -2,13 +2,20 @@ package service
import ( import (
"context" "context"
"crypto/subtle"
"encoding/json"
"io"
"net/http" "net/http"
"net/url"
"strings"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
) )
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
// OpenAIOAuthService handles OpenAI OAuth authentication flows // OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct { type OpenAIOAuthService struct {
sessionStore *openai.SessionStore sessionStore *openai.SessionStore
...@@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 ...@@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
type OpenAIExchangeCodeInput struct { type OpenAIExchangeCodeInput struct {
SessionID string SessionID string
Code string Code string
State string
RedirectURI string RedirectURI string
ProxyID *int64 ProxyID *int64
} }
...@@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch ...@@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
if !ok { if !ok {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired") return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
} }
if input.State == "" {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required")
}
if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state")
}
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL // Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
proxyURL := session.ProxyURL proxyURL := session.ProxyURL
...@@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch ...@@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// RefreshToken refreshes an OpenAI OAuth token // RefreshToken refreshes an OpenAI OAuth token
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) { func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL) return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
}
// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id.
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri ...@@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
return tokenInfo, nil return tokenInfo, nil
} }
// RefreshAccountToken refreshes token for an OpenAI account // ExchangeSoraSessionToken exchanges Sora session_token to access_token.
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
if strings.TrimSpace(sessionToken) == "" {
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
}
proxyURL, err := s.resolveProxyURL(ctx, proxyID)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil)
if err != nil {
return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err)
}
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken))
req.Header.Set("Accept", "application/json")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
client := newOpenAIOAuthHTTPClient(proxyURL)
resp, err := client.Do(req)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if resp.StatusCode != http.StatusOK {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var sessionResp struct {
AccessToken string `json:"accessToken"`
Expires string `json:"expires"`
User struct {
Email string `json:"email"`
Name string `json:"name"`
} `json:"user"`
}
if err := json.Unmarshal(body, &sessionResp); err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err)
}
if strings.TrimSpace(sessionResp.AccessToken) == "" {
return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token")
}
expiresAt := time.Now().Add(time.Hour).Unix()
if strings.TrimSpace(sessionResp.Expires) != "" {
if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil {
expiresAt = parsed.Unix()
}
}
expiresIn := expiresAt - time.Now().Unix()
if expiresIn < 0 {
expiresIn = 0
}
return &OpenAITokenInfo{
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
ExpiresIn: expiresIn,
ExpiresAt: expiresAt,
Email: strings.TrimSpace(sessionResp.User.Email),
}, nil
}
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if !account.IsOpenAI() { if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account") return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account")
}
if account.Type != AccountTypeOAuth {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
} }
refreshToken := account.GetOpenAIRefreshToken() refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" { if refreshToken == "" {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available") return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
} }
...@@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A ...@@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
} }
} }
return s.RefreshToken(ctx, refreshToken, proxyURL) clientID := account.GetCredential("client_id")
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
} }
// BuildAccountCredentials builds credentials map from token info // BuildAccountCredentials builds credentials map from token info
...@@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) ...@@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
func (s *OpenAIOAuthService) Stop() { func (s *OpenAIOAuthService) Stop() {
s.sessionStore.Stop() s.sessionStore.Stop()
} }
func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
if proxyID == nil {
return "", nil
}
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err != nil {
return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
}
if proxy == nil {
return "", nil
}
return proxy.URL(), nil
}
func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
transport := &http.Transport{}
if strings.TrimSpace(proxyURL) != "" {
if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" {
transport.Proxy = http.ProxyURL(parsed)
}
}
return &http.Client{
Timeout: 120 * time.Second,
Transport: transport,
}
}
package service
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
type openaiOAuthClientNoopStub struct{}
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
require.NoError(t, err)
require.NotNil(t, info)
require.Equal(t, "at-token", info.AccessToken)
require.Equal(t, "demo@example.com", info.Email)
require.Greater(t, info.ExpiresAt, int64(0))
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
_, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "missing access token")
}
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
type openaiOAuthClientStateStub struct {
exchangeCalled int32
}
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
atomic.AddInt32(&s.exchangeCalled, 1)
return &openai.TokenResponse{
AccessToken: "at",
RefreshToken: "rt",
ExpiresIn: 3600,
}, nil
}
func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
return s.RefreshToken(ctx, refreshToken, proxyURL)
}
func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) {
client := &openaiOAuthClientStateStub{}
svc := NewOpenAIOAuthService(nil, client)
defer svc.Stop()
svc.sessionStore.Set("sid", &openai.OAuthSession{
State: "expected-state",
CodeVerifier: "verifier",
RedirectURI: openai.DefaultRedirectURI,
CreatedAt: time.Now(),
})
_, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
SessionID: "sid",
Code: "auth-code",
})
require.Error(t, err)
require.Contains(t, err.Error(), "oauth state is required")
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
}
func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) {
client := &openaiOAuthClientStateStub{}
svc := NewOpenAIOAuthService(nil, client)
defer svc.Stop()
svc.sessionStore.Set("sid", &openai.OAuthSession{
State: "expected-state",
CodeVerifier: "verifier",
RedirectURI: openai.DefaultRedirectURI,
CreatedAt: time.Now(),
})
_, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
SessionID: "sid",
Code: "auth-code",
State: "wrong-state",
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid oauth state")
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
}
func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
client := &openaiOAuthClientStateStub{}
svc := NewOpenAIOAuthService(nil, client)
defer svc.Stop()
svc.sessionStore.Set("sid", &openai.OAuthSession{
State: "expected-state",
CodeVerifier: "verifier",
RedirectURI: openai.DefaultRedirectURI,
CreatedAt: time.Now(),
})
info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
SessionID: "sid",
Code: "auth-code",
State: "expected-state",
})
require.NoError(t, err)
require.NotNil(t, info)
require.Equal(t, "at", info.AccessToken)
require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled))
_, ok := svc.sessionStore.Get("sid")
require.False(t, ok)
}
...@@ -157,7 +157,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou ...@@ -157,7 +157,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
} }
expiresAt = account.GetCredentialAsTime("expires_at") expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil { if account.Platform == PlatformSora {
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
refreshFailed = true
} else if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1) p.metrics.refreshFailure.Add(1)
refreshFailed = true // 无法刷新,标记失败 refreshFailed = true // 无法刷新,标记失败
...@@ -206,7 +210,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou ...@@ -206,7 +210,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新 // 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil { if account.Platform == PlatformSora {
slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID)
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
refreshFailed = true
} else if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1) p.metrics.refreshFailure.Add(1)
refreshFailed = true refreshFailed = true
......
This diff is collapsed.
...@@ -4,9 +4,13 @@ package service ...@@ -4,9 +4,13 @@ package service
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"sync/atomic"
"testing" "testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -85,3 +89,273 @@ func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) { ...@@ -85,3 +89,273 @@ func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) {
require.Equal(t, "completed", status.Status) require.Equal(t, "completed", status.Status)
require.Equal(t, []string{"https://example.com/a.png"}, status.URLs) require.Equal(t, []string{"https://example.com/a.png"}, status.URLs)
} }
func TestNormalizeSoraBaseURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
raw string
want string
}{
{
name: "empty",
raw: "",
want: "",
},
{
name: "append_backend_for_sora_host",
raw: "https://sora.chatgpt.com",
want: "https://sora.chatgpt.com/backend",
},
{
name: "convert_backend_api_to_backend",
raw: "https://sora.chatgpt.com/backend-api",
want: "https://sora.chatgpt.com/backend",
},
{
name: "keep_backend",
raw: "https://sora.chatgpt.com/backend",
want: "https://sora.chatgpt.com/backend",
},
{
name: "keep_custom_host",
raw: "https://example.com/custom-path",
want: "https://example.com/custom-path",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeSoraBaseURL(tt.raw)
require.Equal(t, tt.want, got)
})
}
}
func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) {
t.Parallel()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com",
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", client.buildURL("/video_gen"))
}
func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) {
t.Parallel()
client := NewSoraDirectClient(&config.Config{}, nil, nil)
err := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/video_gen")
var upstreamErr *SoraUpstreamError
require.ErrorAs(t, err, &upstreamErr)
require.Contains(t, upstreamErr.Message, "请检查 sora.client.base_url")
errNoHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/backend/video_gen")
require.ErrorAs(t, errNoHint, &upstreamErr)
require.NotContains(t, upstreamErr.Message, "请检查 sora.client.base_url")
}
func TestFormatSoraHeaders_RedactsSensitive(t *testing.T) {
t.Parallel()
headers := http.Header{}
headers.Set("Authorization", "Bearer secret-token")
headers.Set("openai-sentinel-token", "sentinel-secret")
headers.Set("X-Test", "ok")
out := formatSoraHeaders(headers)
require.Contains(t, out, `"Authorization":"***"`)
require.Contains(t, out, `Sentinel-Token":"***"`)
require.Contains(t, out, `"X-Test":"ok"`)
require.NotContains(t, out, "secret-token")
require.NotContains(t, out, "sentinel-secret")
}
func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) {
t.Parallel()
body := []byte(`{"error":{"message":"bad"},"access_token":"abc123"}`)
out := summarizeSoraResponseBody(body, 512)
require.Contains(t, out, `"access_token":"***"`)
require.NotContains(t, out, "abc123")
}
func TestSummarizeSoraResponseBody_Truncates(t *testing.T) {
t.Parallel()
body := []byte(strings.Repeat("x", 100))
out := summarizeSoraResponseBody(body, 10)
require.Contains(t, out, "(truncated)")
}
func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) {
t.Parallel()
cache := newOpenAITokenCacheStub()
provider := NewOpenAITokenProvider(nil, cache, nil)
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, nil, provider)
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "sora-credential-token",
},
}
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "sora-credential-token", token)
require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalled))
}
func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) {
t.Parallel()
cache := newOpenAITokenCacheStub()
account := &Account{
ID: 2,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "sora-credential-token",
},
}
cache.tokens[OpenAITokenCacheKey(account)] = "provider-token"
provider := NewOpenAITokenProvider(nil, cache, nil)
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
UseOpenAITokenProvider: true,
},
},
}
client := NewSoraDirectClient(cfg, nil, provider)
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "provider-token", token)
require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0))
}
func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"accessToken": "session-access-token",
"expires": "2099-01-01T00:00:00Z",
})
}))
defer server.Close()
origin := soraSessionAuthURL
soraSessionAuthURL = server.URL
defer func() { soraSessionAuthURL = origin }()
client := NewSoraDirectClient(&config.Config{}, nil, nil)
account := &Account{
ID: 10,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"session_token": "session-token",
},
}
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "session-access-token", token)
require.Equal(t, "session-access-token", account.GetCredential("access_token"))
}
func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.Equal(t, "/oauth/token", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "refresh-access-token",
"refresh_token": "refresh-token-new",
"expires_in": 3600,
})
}))
defer server.Close()
origin := soraOAuthTokenURL
soraOAuthTokenURL = server.URL + "/oauth/token"
defer func() { soraOAuthTokenURL = origin }()
client := NewSoraDirectClient(&config.Config{}, nil, nil)
account := &Account{
ID: 11,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"refresh_token": "refresh-token-old",
},
}
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refresh-access-token", token)
require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token"))
require.NotNil(t, account.GetCredentialAsTime("expires_at"))
}
func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Equal(t, "/nf/check", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"rate_limit_and_credit_balance": map[string]any{
"estimated_num_videos_remaining": 0,
"rate_limit_reached": true,
},
})
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: server.URL,
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{
ID: 12,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "ok",
"expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339),
},
}
err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", SoraModelConfig{Type: "video"})
require.Error(t, err)
var upstreamErr *SoraUpstreamError
require.ErrorAs(t, err, &upstreamErr)
require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode)
}
func TestShouldAttemptSoraTokenRecover(t *testing.T) {
t.Parallel()
require.True(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen"))
require.True(t, shouldAttemptSoraTokenRecover(http.StatusForbidden, "https://chatgpt.com/backend/video_gen"))
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session"))
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token"))
require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen"))
}
...@@ -61,6 +61,10 @@ type SoraGatewayService struct { ...@@ -61,6 +61,10 @@ type SoraGatewayService struct {
cfg *config.Config cfg *config.Config
} }
type soraPreflightChecker interface {
PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
}
func NewSoraGatewayService( func NewSoraGatewayService(
soraClient SoraClient, soraClient SoraClient,
mediaStorage *SoraMediaStorage, mediaStorage *SoraMediaStorage,
...@@ -112,11 +116,6 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun ...@@ -112,11 +116,6 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream) s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
return nil, fmt.Errorf("unsupported model: %s", reqModel) return nil, fmt.Errorf("unsupported model: %s", reqModel)
} }
if modelCfg.Type == "prompt_enhance" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream)
return nil, fmt.Errorf("prompt-enhance not supported")
}
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody) prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
if strings.TrimSpace(prompt) == "" { if strings.TrimSpace(prompt) == "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
...@@ -131,6 +130,41 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun ...@@ -131,6 +130,41 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
if cancel != nil { if cancel != nil {
defer cancel() defer cancel()
} }
if checker, ok := s.soraClient.(soraPreflightChecker); ok {
if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
}
if modelCfg.Type == "prompt_enhance" {
enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS)
if err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
content := strings.TrimSpace(enhancedPrompt)
if content == "" {
content = prompt
}
var firstTokenMs *int
if clientStream {
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
if streamErr != nil {
return nil, streamErr
}
firstTokenMs = ms
} else if c != nil {
c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
}
return &ForwardResult{
RequestID: "",
Model: reqModel,
Stream: clientStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{},
MediaType: "prompt",
}, nil
}
var imageData []byte var imageData []byte
imageFilename := "" imageFilename := ""
...@@ -267,7 +301,7 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) ( ...@@ -267,7 +301,7 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode { switch statusCode {
case 401, 402, 403, 429, 529: case 401, 402, 403, 404, 429, 529:
return true return true
default: default:
return statusCode >= 500 return statusCode >= 500
...@@ -460,7 +494,7 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account ...@@ -460,7 +494,7 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
} }
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) { if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode} return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode, ResponseBody: upstreamErr.Body}
} }
msg := upstreamErr.Message msg := upstreamErr.Message
if override := soraProErrorMessage(model, msg); override != "" { if override := soraProErrorMessage(model, msg); override != "" {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment