Commit 987589ea authored by yangjianbo's avatar yangjianbo
Browse files

Merge branch 'test' into release

parents 372e04f6 03f69dd3
......@@ -40,6 +40,11 @@ type ProxyWithAccountCount struct {
CountryCode string
Region string
City string
QualityStatus string
QualityScore *int
QualityGrade string
QualitySummary string
QualityChecked *int64
}
type ProxyAccountSummary struct {
......
......@@ -6,15 +6,21 @@ import (
)
type ProxyLatencyInfo struct {
Success bool `json:"success"`
LatencyMs *int64 `json:"latency_ms,omitempty"`
Message string `json:"message,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
Country string `json:"country,omitempty"`
CountryCode string `json:"country_code,omitempty"`
Region string `json:"region,omitempty"`
City string `json:"city,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
Success bool `json:"success"`
LatencyMs *int64 `json:"latency_ms,omitempty"`
Message string `json:"message,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
Country string `json:"country,omitempty"`
CountryCode string `json:"country_code,omitempty"`
Region string `json:"region,omitempty"`
City string `json:"city,omitempty"`
QualityStatus string `json:"quality_status,omitempty"`
QualityScore *int `json:"quality_score,omitempty"`
QualityGrade string `json:"quality_grade,omitempty"`
QualitySummary string `json:"quality_summary,omitempty"`
QualityCheckedAt *int64 `json:"quality_checked_at,omitempty"`
QualityCFRay string `json:"quality_cf_ray,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
}
type ProxyLatencyCache interface {
......
......@@ -381,10 +381,31 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
}
}
// 2. 尝试从响应头解析重置时间(Anthropic)
// 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口
if result := calculateAnthropic429ResetTime(headers); result != nil {
if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
// 更新 session window:优先使用 5h-reset 头精确计算,否则从 resetAt 反推
windowEnd := result.resetAt
if result.fiveHourReset != nil {
windowEnd = *result.fiveHourReset
}
windowStart := windowEnd.Add(-5 * time.Hour)
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
}
slog.Info("anthropic_account_rate_limited", "account_id", account.ID, "reset_at", result.resetAt, "reset_in", time.Until(result.resetAt).Truncate(time.Second))
return
}
// 3. 尝试从响应头解析重置时间(Anthropic 聚合头,向后兼容)
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
// 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
// 4. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
if resetTimestamp == "" {
switch account.Platform {
case PlatformOpenAI:
......@@ -497,6 +518,112 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim
return nil
}
// anthropic429Result holds the parsed Anthropic 429 rate-limit information.
type anthropic429Result struct {
resetAt time.Time // The correct reset time to use for SetRateLimited
fiveHourReset *time.Time // 5h window reset timestamp (for session window calculation), nil if not available
}
// calculateAnthropic429ResetTime parses Anthropic's per-window rate-limit headers
// to determine which window (5h or 7d) actually triggered the 429.
//
// Headers used:
// - anthropic-ratelimit-unified-5h-utilization / anthropic-ratelimit-unified-5h-surpassed-threshold
// - anthropic-ratelimit-unified-5h-reset
// - anthropic-ratelimit-unified-7d-utilization / anthropic-ratelimit-unified-7d-surpassed-threshold
// - anthropic-ratelimit-unified-7d-reset
//
// Returns nil when the per-window headers are absent (caller should fall back to
// the aggregated anthropic-ratelimit-unified-reset header).
func calculateAnthropic429ResetTime(headers http.Header) *anthropic429Result {
reset5hStr := headers.Get("anthropic-ratelimit-unified-5h-reset")
reset7dStr := headers.Get("anthropic-ratelimit-unified-7d-reset")
if reset5hStr == "" && reset7dStr == "" {
return nil
}
var reset5h, reset7d *time.Time
if ts, err := strconv.ParseInt(reset5hStr, 10, 64); err == nil {
t := time.Unix(ts, 0)
reset5h = &t
}
if ts, err := strconv.ParseInt(reset7dStr, 10, 64); err == nil {
t := time.Unix(ts, 0)
reset7d = &t
}
is5hExceeded := isAnthropicWindowExceeded(headers, "5h")
is7dExceeded := isAnthropicWindowExceeded(headers, "7d")
slog.Info("anthropic_429_window_analysis",
"is_5h_exceeded", is5hExceeded,
"is_7d_exceeded", is7dExceeded,
"reset_5h", reset5hStr,
"reset_7d", reset7dStr,
)
// Select the correct reset time based on which window(s) are exceeded.
var chosen *time.Time
switch {
case is5hExceeded && is7dExceeded:
// Both exceeded → prefer 7d (longer cooldown), fall back to 5h
chosen = reset7d
if chosen == nil {
chosen = reset5h
}
case is5hExceeded:
chosen = reset5h
case is7dExceeded:
chosen = reset7d
default:
// Neither flag clearly exceeded — pick the sooner reset as best guess
chosen = pickSooner(reset5h, reset7d)
}
if chosen == nil {
return nil
}
return &anthropic429Result{resetAt: *chosen, fiveHourReset: reset5h}
}
// isAnthropicWindowExceeded checks whether a given Anthropic rate-limit window
// (e.g. "5h" or "7d") has been exceeded, using utilization and surpassed-threshold headers.
func isAnthropicWindowExceeded(headers http.Header, window string) bool {
prefix := "anthropic-ratelimit-unified-" + window + "-"
// Check surpassed-threshold first (most explicit signal)
if st := headers.Get(prefix + "surpassed-threshold"); strings.EqualFold(st, "true") {
return true
}
// Fall back to utilization >= 1.0
if utilStr := headers.Get(prefix + "utilization"); utilStr != "" {
if util, err := strconv.ParseFloat(utilStr, 64); err == nil && util >= 1.0-1e-9 {
// Use a small epsilon to handle floating point: treat 0.9999999... as >= 1.0
return true
}
}
return false
}
// pickSooner returns whichever of the two time pointers is earlier.
// If only one is non-nil, it is returned. If both are nil, returns nil.
func pickSooner(a, b *time.Time) *time.Time {
switch {
case a != nil && b != nil:
if a.Before(*b) {
return a
}
return b
case a != nil:
return a
default:
return b
}
}
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
// OpenAI 的 usage_limit_reached 错误格式:
//
......
package service
import (
"net/http"
"testing"
"time"
)
func TestCalculateAnthropic429ResetTime_Only5hExceeded(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.02")
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.32")
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1770998400)
if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) {
t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset)
}
}
func TestCalculateAnthropic429ResetTime_Only7dExceeded(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.50")
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.05")
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1771549200)
// fiveHourReset should still be populated for session window calculation
if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) {
t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset)
}
}
func TestCalculateAnthropic429ResetTime_BothExceeded(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.10")
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.02")
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1771549200)
}
func TestCalculateAnthropic429ResetTime_NoPerWindowHeaders(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-reset", "1771549200")
result := calculateAnthropic429ResetTime(headers)
if result != nil {
t.Errorf("expected nil result when no per-window headers, got resetAt=%v", result.resetAt)
}
}
func TestCalculateAnthropic429ResetTime_NoHeaders(t *testing.T) {
result := calculateAnthropic429ResetTime(http.Header{})
if result != nil {
t.Errorf("expected nil result for empty headers, got resetAt=%v", result.resetAt)
}
}
func TestCalculateAnthropic429ResetTime_SurpassedThreshold(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true")
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "false")
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1770998400)
}
func TestCalculateAnthropic429ResetTime_UtilizationExactlyOne(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.0")
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.5")
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1770998400)
}
func TestCalculateAnthropic429ResetTime_NeitherExceeded_UsesShorter(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.95")
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") // sooner
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.80")
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") // later
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1770998400)
}
func TestCalculateAnthropic429ResetTime_Only5hResetHeader(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.05")
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1770998400)
}
func TestCalculateAnthropic429ResetTime_Only7dResetHeader(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.03")
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1771549200)
if result.fiveHourReset != nil {
t.Errorf("expected fiveHourReset=nil when no 5h headers, got %v", result.fiveHourReset)
}
}
func TestIsAnthropicWindowExceeded(t *testing.T) {
tests := []struct {
name string
headers http.Header
window string
expected bool
}{
{
name: "utilization above 1.0",
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.02"),
window: "5h",
expected: true,
},
{
name: "utilization exactly 1.0",
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.0"),
window: "5h",
expected: true,
},
{
name: "utilization below 1.0",
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "0.99"),
window: "5h",
expected: false,
},
{
name: "surpassed-threshold true",
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "true"),
window: "7d",
expected: true,
},
{
name: "surpassed-threshold True (case insensitive)",
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "True"),
window: "7d",
expected: true,
},
{
name: "surpassed-threshold false",
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "false"),
window: "7d",
expected: false,
},
{
name: "no headers",
headers: http.Header{},
window: "5h",
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := isAnthropicWindowExceeded(tc.headers, tc.window)
if got != tc.expected {
t.Errorf("expected %v, got %v", tc.expected, got)
}
})
}
}
// assertAnthropicResult is a test helper that verifies the result is non-nil and
// has the expected resetAt unix timestamp.
func assertAnthropicResult(t *testing.T, result *anthropic429Result, wantUnix int64) {
t.Helper()
if result == nil {
t.Fatal("expected non-nil result")
return // unreachable, but satisfies staticcheck SA5011
}
want := time.Unix(wantUnix, 0)
if !result.resetAt.Equal(want) {
t.Errorf("expected resetAt=%v, got %v", want, result.resetAt)
}
}
func makeHeader(key, value string) http.Header {
h := http.Header{}
h.Set(key, value)
return h
}
......@@ -8,6 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
"hash/fnv"
"io"
"log"
"math/rand"
......@@ -17,12 +18,16 @@ import (
"net/textproto"
"net/url"
"path"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"golang.org/x/crypto/sha3"
......@@ -34,6 +39,11 @@ const (
soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
)
var (
soraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
soraOAuthTokenURL = "https://auth.openai.com/oauth/token"
)
const (
soraPowMaxIteration = 500000
)
......@@ -86,9 +96,20 @@ var soraDesktopUserAgents = []string{
"Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
}
var soraMobileUserAgents = []string{
"Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)",
"Sora/1.2026.007 (Android 14; SM-G998B; build 2600700)",
"Sora/1.2026.007 (Android 15; Pixel 8 Pro; build 2600700)",
"Sora/1.2026.007 (Android 14; Pixel 7; build 2600700)",
"Sora/1.2026.007 (Android 15; 2211133C; build 2600700)",
"Sora/1.2026.007 (Android 14; SM-S918B; build 2600700)",
"Sora/1.2026.007 (Android 15; OnePlus 12; build 2600700)",
}
var soraRand = rand.New(rand.NewSource(time.Now().UnixNano()))
var soraRandMu sync.Mutex
var soraPerfStart = time.Now()
var soraPowTokenGenerator = soraGetPowToken
// SoraClient 定义直连 Sora 的任务操作接口。
type SoraClient interface {
......@@ -96,6 +117,18 @@ type SoraClient interface {
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error)
UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error)
GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error)
DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error)
UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error)
FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error)
SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error
DeleteCharacter(ctx context.Context, account *Account, characterID string) error
PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error)
DeletePost(ctx context.Context, account *Account, postID string) error
GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error)
EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
}
......@@ -117,6 +150,17 @@ type SoraVideoRequest struct {
Size string
MediaID string
RemixTargetID string
CameoIDs []string
}
// SoraStoryboardRequest 分镜视频生成请求参数
type SoraStoryboardRequest struct {
Prompt string
Orientation string
Frames int
Model string
Size string
MediaID string
}
// SoraImageTaskStatus 图片任务状态
......@@ -130,11 +174,32 @@ type SoraImageTaskStatus struct {
// SoraVideoTaskStatus 视频任务状态
type SoraVideoTaskStatus struct {
ID string
Status string
ProgressPct int
URLs []string
ErrorMsg string
ID string
Status string
ProgressPct int
URLs []string
GenerationID string
ErrorMsg string
}
// SoraCameoStatus 角色处理中间态
type SoraCameoStatus struct {
Status string
StatusMessage string
DisplayNameHint string
UsernameHint string
ProfileAssetURL string
InstructionSetHint any
InstructionSet any
}
// SoraCharacterFinalizeRequest 角色定稿请求参数
type SoraCharacterFinalizeRequest struct {
CameoID string
Username string
DisplayName string
ProfileAssetPointer string
InstructionSet any
}
// SoraUpstreamError 上游错误
......@@ -157,26 +222,110 @@ func (e *SoraUpstreamError) Error() string {
// SoraDirectClient 直连 Sora 实现
type SoraDirectClient struct {
cfg *config.Config
httpUpstream HTTPUpstream
tokenProvider *OpenAITokenProvider
cfg *config.Config
httpUpstream HTTPUpstream
tokenProvider *OpenAITokenProvider
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository
baseURL string
challengeCooldownMu sync.RWMutex
challengeCooldowns map[string]soraChallengeCooldownEntry
sidecarSessionMu sync.RWMutex
sidecarSessions map[string]soraSidecarSessionEntry
}
type soraRequestTraceContextKey struct{}
type soraRequestTrace struct {
ID string
ProxyKey string
UAHash string
}
// NewSoraDirectClient 创建 Sora 直连客户端
func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient {
baseURL := ""
if cfg != nil {
rawBaseURL := strings.TrimRight(strings.TrimSpace(cfg.Sora.Client.BaseURL), "/")
baseURL = normalizeSoraBaseURL(rawBaseURL)
if rawBaseURL != "" && baseURL != rawBaseURL {
log.Printf("[SoraClient] normalized base_url from %s to %s", sanitizeSoraLogURL(rawBaseURL), sanitizeSoraLogURL(baseURL))
}
}
return &SoraDirectClient{
cfg: cfg,
httpUpstream: httpUpstream,
tokenProvider: tokenProvider,
cfg: cfg,
httpUpstream: httpUpstream,
tokenProvider: tokenProvider,
baseURL: baseURL,
challengeCooldowns: make(map[string]soraChallengeCooldownEntry),
sidecarSessions: make(map[string]soraSidecarSessionEntry),
}
}
func (c *SoraDirectClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) {
if c == nil {
return
}
c.accountRepo = accountRepo
c.soraAccountRepo = soraAccountRepo
}
// Enabled 判断是否启用 Sora 直连
func (c *SoraDirectClient) Enabled() bool {
if c == nil || c.cfg == nil {
if c == nil {
return false
}
return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != ""
if strings.TrimSpace(c.baseURL) != "" {
return true
}
if c.cfg == nil {
return false
}
return strings.TrimSpace(normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)) != ""
}
// PreflightCheck 在创建任务前执行账号能力预检。
// 当前仅对视频模型执行 /nf/check 预检,用于提前识别额度耗尽或能力缺失。
func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error {
if modelCfg.Type != "video" {
return nil
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Accept", "application/json")
body, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false)
if err != nil {
var upstreamErr *SoraUpstreamError
if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound {
return &SoraUpstreamError{
StatusCode: http.StatusForbidden,
Message: "当前账号未开通 Sora2 能力或无可用配额",
Headers: upstreamErr.Headers,
Body: upstreamErr.Body,
}
}
return err
}
rateLimitReached := gjson.GetBytes(body, "rate_limit_and_credit_balance.rate_limit_reached").Bool()
remaining := gjson.GetBytes(body, "rate_limit_and_credit_balance.estimated_num_videos_remaining")
if rateLimitReached || (remaining.Exists() && remaining.Int() <= 0) {
msg := "当前账号 Sora2 可用配额不足"
if requestedModel != "" {
msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel)
}
return &SoraUpstreamError{
StatusCode: http.StatusTooManyRequests,
Message: msg,
Headers: http.Header{},
}
}
return nil
}
func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
......@@ -187,6 +336,8 @@ func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, da
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
if filename == "" {
filename = "image.png"
}
......@@ -213,10 +364,10 @@ func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, da
return "", err
}
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", writer.FormDataContentType())
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/uploads"), headers, &body, false)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/uploads"), headers, &body, false)
if err != nil {
return "", err
}
......@@ -232,6 +383,9 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
operation := "simple_compose"
inpaintItems := []map[string]any{}
if strings.TrimSpace(req.MediaID) != "" {
......@@ -252,7 +406,7 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account
"n_frames": 1,
"inpaint_items": inpaintItems,
}
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
......@@ -261,13 +415,13 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account
if err != nil {
return "", err
}
sentinel, err := c.generateSentinelToken(ctx, account, token)
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
......@@ -283,6 +437,9 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
orientation := req.Orientation
if orientation == "" {
orientation = "landscape"
......@@ -320,9 +477,12 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
payload["remix_target_id"] = req.RemixTargetID
payload["cameo_ids"] = []string{}
payload["cameo_replacements"] = map[string]any{}
} else if len(req.CameoIDs) > 0 {
payload["cameo_ids"] = req.CameoIDs
payload["cameo_replacements"] = map[string]any{}
}
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
......@@ -330,13 +490,13 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
if err != nil {
return "", err
}
sentinel, err := c.generateSentinelToken(ctx, account, token)
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
......@@ -347,6 +507,469 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
return taskID, nil
}
func (c *SoraDirectClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
orientation := req.Orientation
if orientation == "" {
orientation = "landscape"
}
nFrames := req.Frames
if nFrames <= 0 {
nFrames = 450
}
model := req.Model
if model == "" {
model = "sy_8"
}
size := req.Size
if size == "" {
size = "small"
}
inpaintItems := []map[string]any{}
if strings.TrimSpace(req.MediaID) != "" {
inpaintItems = append(inpaintItems, map[string]any{
"kind": "upload",
"upload_id": req.MediaID,
})
}
payload := map[string]any{
"kind": "video",
"prompt": req.Prompt,
"title": "Draft your video",
"orientation": orientation,
"size": size,
"n_frames": nFrames,
"storyboard_id": nil,
"inpaint_items": inpaintItems,
"remix_target_id": nil,
"model": model,
"metadata": nil,
"style_id": nil,
"cameo_ids": nil,
"cameo_replacements": nil,
"audio_caption": nil,
"audio_transcript": nil,
"video_caption": nil,
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create/storyboard"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if taskID == "" {
return "", errors.New("storyboard task response missing id")
}
return taskID, nil
}
func (c *SoraDirectClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
if len(data) == 0 {
return "", errors.New("empty video data")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
var body bytes.Buffer
writer := multipart.NewWriter(&body)
partHeader := make(textproto.MIMEHeader)
partHeader.Set("Content-Disposition", `form-data; name="file"; filename="video.mp4"`)
partHeader.Set("Content-Type", "video/mp4")
part, err := writer.CreatePart(partHeader)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
if err := writer.WriteField("timestamps", "0,3"); err != nil {
return "", err
}
if err := writer.Close(); err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", writer.FormDataContentType())
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/upload"), headers, &body, false)
if err != nil {
return "", err
}
cameoID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if cameoID == "" {
return "", errors.New("character upload response missing id")
}
return cameoID, nil
}
func (c *SoraDirectClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
respBody, _, err := c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodGet,
c.buildURL("/project_y/cameos/in_progress/"+strings.TrimSpace(cameoID)),
headers,
nil,
false,
)
if err != nil {
return nil, err
}
return &SoraCameoStatus{
Status: strings.TrimSpace(gjson.GetBytes(respBody, "status").String()),
StatusMessage: strings.TrimSpace(gjson.GetBytes(respBody, "status_message").String()),
DisplayNameHint: strings.TrimSpace(gjson.GetBytes(respBody, "display_name_hint").String()),
UsernameHint: strings.TrimSpace(gjson.GetBytes(respBody, "username_hint").String()),
ProfileAssetURL: strings.TrimSpace(gjson.GetBytes(respBody, "profile_asset_url").String()),
InstructionSetHint: gjson.GetBytes(respBody, "instruction_set_hint").Value(),
InstructionSet: gjson.GetBytes(respBody, "instruction_set").Value(),
}, nil
}
func (c *SoraDirectClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Accept", "image/*,*/*;q=0.8")
respBody, _, err := c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodGet,
strings.TrimSpace(imageURL),
headers,
nil,
false,
)
if err != nil {
return nil, err
}
return respBody, nil
}
func (c *SoraDirectClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
if len(data) == 0 {
return "", errors.New("empty character image")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
var body bytes.Buffer
writer := multipart.NewWriter(&body)
partHeader := make(textproto.MIMEHeader)
partHeader.Set("Content-Disposition", `form-data; name="file"; filename="profile.webp"`)
partHeader.Set("Content-Type", "image/webp")
part, err := writer.CreatePart(partHeader)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
if err := writer.WriteField("use_case", "profile"); err != nil {
return "", err
}
if err := writer.Close(); err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", writer.FormDataContentType())
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/file/upload"), headers, &body, false)
if err != nil {
return "", err
}
assetPointer := strings.TrimSpace(gjson.GetBytes(respBody, "asset_pointer").String())
if assetPointer == "" {
return "", errors.New("character image upload response missing asset_pointer")
}
return assetPointer, nil
}
func (c *SoraDirectClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
payload := map[string]any{
"cameo_id": req.CameoID,
"username": req.Username,
"display_name": req.DisplayName,
"profile_asset_pointer": req.ProfileAssetPointer,
"instruction_set": nil,
"safety_instruction_set": nil,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/finalize"), headers, bytes.NewReader(body), false)
if err != nil {
return "", err
}
characterID := strings.TrimSpace(gjson.GetBytes(respBody, "character.character_id").String())
if characterID == "" {
return "", errors.New("character finalize response missing character_id")
}
return characterID, nil
}
func (c *SoraDirectClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
payload := map[string]any{"visibility": "public"}
body, err := json.Marshal(payload)
if err != nil {
return err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
_, _, err = c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodPost,
c.buildURL("/project_y/cameos/by_id/"+strings.TrimSpace(cameoID)+"/update_v2"),
headers,
bytes.NewReader(body),
false,
)
return err
}
func (c *SoraDirectClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
_, _, err = c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodDelete,
c.buildURL("/project_y/characters/"+strings.TrimSpace(characterID)),
headers,
nil,
false,
)
return err
}
func (c *SoraDirectClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
payload := map[string]any{
"attachments_to_create": []map[string]any{
{
"generation_id": generationID,
"kind": "sora",
},
},
"post_text": "",
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/post"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
postID := strings.TrimSpace(gjson.GetBytes(respBody, "post.id").String())
if postID == "" {
return "", errors.New("watermark-free publish response missing post.id")
}
return postID, nil
}
func (c *SoraDirectClient) DeletePost(ctx context.Context, account *Account, postID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
_, _, err = c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodDelete,
c.buildURL("/project_y/post/"+strings.TrimSpace(postID)),
headers,
nil,
false,
)
return err
}
func (c *SoraDirectClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/")
if parseURL == "" {
return "", errors.New("custom parse url is required")
}
if strings.TrimSpace(parseToken) == "" {
return "", errors.New("custom parse token is required")
}
shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID)
payload := map[string]any{
"url": shareURL,
"token": strings.TrimSpace(parseToken),
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
proxyURL := c.resolveProxyURL(account)
accountID := int64(0)
accountConcurrency := 0
if account != nil {
accountID = account.ID
accountConcurrency = account.Concurrency
}
var resp *http.Response
if c.httpUpstream != nil {
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
} else {
resp, err = http.DefaultClient.Do(req)
}
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
if err != nil {
return "", err
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256))
}
downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String())
if downloadLink == "" {
return "", errors.New("custom parse response missing download_link")
}
return downloadLink, nil
}
func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
if strings.TrimSpace(expansionLevel) == "" {
expansionLevel = "medium"
}
if durationS <= 0 {
durationS = 10
}
payload := map[string]any{
"prompt": prompt,
"expansion_level": expansionLevel,
"duration_s": durationS,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Accept", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false)
if err != nil {
return "", err
}
enhancedPrompt := strings.TrimSpace(gjson.GetBytes(respBody, "enhanced_prompt").String())
if enhancedPrompt == "" {
return "", errors.New("enhance_prompt response missing enhanced_prompt")
}
return enhancedPrompt, nil
}
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit())
if err != nil {
......@@ -373,12 +996,14 @@ func (c *SoraDirectClient) fetchRecentImageTask(ctx context.Context, account *Ac
if err != nil {
return nil, false, err
}
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
if limit <= 0 {
limit = 20
}
endpoint := fmt.Sprintf("/v2/recent_tasks?limit=%d", limit)
respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL(endpoint), headers, nil, false)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL(endpoint), headers, nil, false)
if err != nil {
return nil, false, err
}
......@@ -435,9 +1060,11 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
if err != nil {
return nil, err
}
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false)
if err != nil {
return nil, err
}
......@@ -466,7 +1093,7 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
}
}
respBody, _, err = c.doRequest(ctx, account, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false)
respBody, _, err = c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false)
if err != nil {
return nil, err
}
......@@ -475,6 +1102,7 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
if draft.Get("task_id").String() != taskID {
return true
}
generationID := strings.TrimSpace(draft.Get("id").String())
kind := strings.TrimSpace(draft.Get("kind").String())
reason := strings.TrimSpace(draft.Get("reason_str").String())
if reason == "" {
......@@ -491,15 +1119,17 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
msg = "Content violates guardrails"
}
draftFound = &SoraVideoTaskStatus{
ID: taskID,
Status: "failed",
ErrorMsg: msg,
ID: taskID,
Status: "failed",
GenerationID: generationID,
ErrorMsg: msg,
}
} else {
draftFound = &SoraVideoTaskStatus{
ID: taskID,
Status: "completed",
URLs: []string{urlStr},
ID: taskID,
Status: "completed",
GenerationID: generationID,
URLs: []string{urlStr},
}
}
return false
......@@ -512,9 +1142,10 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
}
func (c *SoraDirectClient) buildURL(endpoint string) string {
base := ""
if c != nil && c.cfg != nil {
base = strings.TrimRight(strings.TrimSpace(c.cfg.Sora.Client.BaseURL), "/")
base := strings.TrimRight(strings.TrimSpace(c.baseURL), "/")
if base == "" && c != nil && c.cfg != nil {
base = normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)
c.baseURL = base
}
if base == "" {
return endpoint
......@@ -536,18 +1167,278 @@ func (c *SoraDirectClient) defaultUserAgent() string {
return ua
}
func (c *SoraDirectClient) taskUserAgent() string {
if c != nil && c.cfg != nil {
if ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent); ua != "" {
return ua
}
}
if len(soraMobileUserAgents) > 0 {
return soraMobileUserAgents[soraRandInt(len(soraMobileUserAgents))]
}
if len(soraDesktopUserAgents) > 0 {
return soraDesktopUserAgents[soraRandInt(len(soraDesktopUserAgents))]
}
return soraDefaultUserAgent
}
func (c *SoraDirectClient) resolveProxyURL(account *Account) string {
if account == nil || account.ProxyID == nil || account.Proxy == nil {
return ""
}
return strings.TrimSpace(account.Proxy.URL())
}
func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if c.tokenProvider != nil {
return c.tokenProvider.GetAccessToken(ctx, account)
allowProvider := c.allowOpenAITokenProvider(account)
var providerErr error
if allowProvider && c.tokenProvider != nil {
token, err := c.tokenProvider.GetAccessToken(ctx, account)
if err == nil && strings.TrimSpace(token) != "" {
c.logTokenSource(account, "openai_token_provider")
return token, nil
}
providerErr = err
if err != nil && c.debugEnabled() {
c.debugLogf(
"token_provider_failed account_id=%d platform=%s err=%s",
account.ID,
account.Platform,
logredact.RedactText(err.Error()),
)
}
}
token := strings.TrimSpace(account.GetCredential("access_token"))
if token == "" {
return "", errors.New("access_token not found")
if token != "" {
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute {
refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring")
if refreshErr == nil && strings.TrimSpace(refreshed) != "" {
c.logTokenSource(account, "refresh_token_recovered")
return refreshed, nil
}
if refreshErr != nil && c.debugEnabled() {
c.debugLogf("token_refresh_before_use_failed account_id=%d err=%s", account.ID, logredact.RedactText(refreshErr.Error()))
}
}
c.logTokenSource(account, "account_credentials")
return token, nil
}
recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing")
if recoverErr == nil && strings.TrimSpace(recovered) != "" {
c.logTokenSource(account, "session_or_refresh_recovered")
return recovered, nil
}
if recoverErr != nil && c.debugEnabled() {
c.debugLogf("token_recover_failed account_id=%d platform=%s err=%s", account.ID, account.Platform, logredact.RedactText(recoverErr.Error()))
}
if providerErr != nil {
return "", providerErr
}
if c.tokenProvider != nil && !allowProvider {
c.logTokenSource(account, "account_credentials(provider_disabled)")
}
return "", errors.New("access_token not found")
}
func (c *SoraDirectClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" {
accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken)
if err == nil && strings.TrimSpace(accessToken) != "" {
c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken)
c.logTokenRecover(account, "session_token", reason, true, nil)
return accessToken, nil
}
c.logTokenRecover(account, "session_token", reason, false, err)
}
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
if refreshToken == "" {
return "", errors.New("session_token/refresh_token not found")
}
return token, nil
accessToken, newRefreshToken, expiresAt, err := c.exchangeRefreshToken(ctx, account, refreshToken)
if err != nil {
c.logTokenRecover(account, "refresh_token", reason, false, err)
return "", err
}
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("refreshed access_token is empty")
}
c.applyRecoveredToken(ctx, account, accessToken, newRefreshToken, expiresAt, "")
c.logTokenRecover(account, "refresh_token", reason, true, nil)
return accessToken, nil
}
func (c *SoraDirectClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) {
headers := http.Header{}
headers.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken)
headers.Set("Accept", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
headers.Set("User-Agent", c.defaultUserAgent())
body, _, err := c.doRequest(ctx, account, http.MethodGet, soraSessionAuthURL, headers, nil, false)
if err != nil {
return "", "", err
}
accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String())
if accessToken == "" {
return "", "", errors.New("session exchange missing accessToken")
}
expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String())
return accessToken, expiresAt, nil
}
func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Account, refreshToken string) (string, string, string, error) {
clientIDs := []string{
strings.TrimSpace(account.GetCredential("client_id")),
openaioauth.SoraClientID,
openaioauth.ClientID,
}
tried := make(map[string]struct{}, len(clientIDs))
var lastErr error
for _, clientID := range clientIDs {
if clientID == "" {
continue
}
if _, ok := tried[clientID]; ok {
continue
}
tried[clientID] = struct{}{}
formData := url.Values{}
formData.Set("client_id", clientID)
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("redirect_uri", "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback")
headers := http.Header{}
headers.Set("Accept", "application/json")
headers.Set("Content-Type", "application/x-www-form-urlencoded")
headers.Set("User-Agent", c.defaultUserAgent())
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, strings.NewReader(formData.Encode()), false)
if err != nil {
lastErr = err
if c.debugEnabled() {
c.debugLogf("refresh_token_exchange_failed account_id=%d client_id=%s err=%s", account.ID, clientID, logredact.RedactText(err.Error()))
}
continue
}
accessToken := strings.TrimSpace(gjson.GetBytes(respBody, "access_token").String())
if accessToken == "" {
lastErr = errors.New("oauth refresh response missing access_token")
continue
}
newRefreshToken := strings.TrimSpace(gjson.GetBytes(respBody, "refresh_token").String())
expiresIn := gjson.GetBytes(respBody, "expires_in").Int()
expiresAt := ""
if expiresIn > 0 {
expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339)
}
return accessToken, newRefreshToken, expiresAt, nil
}
if lastErr != nil {
return "", "", "", lastErr
}
return "", "", "", errors.New("no available client_id for refresh_token exchange")
}
func (c *SoraDirectClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) {
if account == nil {
return
}
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
if strings.TrimSpace(accessToken) != "" {
account.Credentials["access_token"] = accessToken
}
if strings.TrimSpace(refreshToken) != "" {
account.Credentials["refresh_token"] = refreshToken
}
if strings.TrimSpace(expiresAt) != "" {
account.Credentials["expires_at"] = expiresAt
}
if strings.TrimSpace(sessionToken) != "" {
account.Credentials["session_token"] = sessionToken
}
if c.accountRepo != nil {
if err := c.accountRepo.Update(ctx, account); err != nil {
if c.debugEnabled() {
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
}
}
}
c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken)
}
func (c *SoraDirectClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) {
if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 {
return
}
updates := make(map[string]any)
if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" {
updates["access_token"] = accessToken
updates["refresh_token"] = refreshToken
}
if strings.TrimSpace(sessionToken) != "" {
updates["session_token"] = sessionToken
}
if len(updates) == 0 {
return
}
if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() {
c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
}
}
func (c *SoraDirectClient) logTokenRecover(account *Account, source, reason string, success bool, err error) {
if !c.debugEnabled() || account == nil {
return
}
if success {
c.debugLogf("token_recover_success account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
return
}
if err == nil {
c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
return
}
c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s err=%s", account.ID, account.Platform, source, reason, logredact.RedactText(err.Error()))
}
func (c *SoraDirectClient) allowOpenAITokenProvider(account *Account) bool {
if c == nil || c.tokenProvider == nil {
return false
}
if account != nil && account.Platform == PlatformSora {
return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider
}
return true
}
func (c *SoraDirectClient) logTokenSource(account *Account, source string) {
if !c.debugEnabled() || account == nil {
return
}
c.debugLogf(
"token_selected account_id=%d platform=%s account_type=%s source=%s",
account.ID,
account.Platform,
account.Type,
source,
)
}
func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header {
......@@ -570,9 +1461,30 @@ func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header
}
func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, method, urlStr string, headers http.Header, body io.Reader, allowRetry bool) ([]byte, http.Header, error) {
return c.doRequestWithProxy(ctx, account, c.resolveProxyURL(account), method, urlStr, headers, body, allowRetry)
}
func (c *SoraDirectClient) doRequestWithProxy(
ctx context.Context,
account *Account,
proxyURL string,
method,
urlStr string,
headers http.Header,
body io.Reader,
allowRetry bool,
) ([]byte, http.Header, error) {
if strings.TrimSpace(urlStr) == "" {
return nil, nil, errors.New("empty upstream url")
}
proxyURL = strings.TrimSpace(proxyURL)
if proxyURL == "" {
proxyURL = c.resolveProxyURL(account)
}
if cooldownErr := c.checkCloudflareChallengeCooldown(account, proxyURL); cooldownErr != nil {
return nil, nil, cooldownErr
}
traceID, traceProxyKey, traceUAHash := c.requestTraceFields(ctx, proxyURL, headers.Get("User-Agent"))
timeout := 0
if c != nil && c.cfg != nil {
timeout = c.cfg.Sora.Client.TimeoutSeconds
......@@ -600,7 +1512,29 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
}
attempts := maxRetries + 1
authRecovered := false
authRecoverExtraAttemptGranted := false
challengeRetried := false
sawCFChallenge := false
var lastErr error
for attempt := 1; attempt <= attempts; attempt++ {
if c.debugEnabled() {
c.debugLogf(
"request_start trace_id=%s method=%s url=%s attempt=%d/%d timeout_s=%d body_bytes=%d proxy_bound=%t proxy_key=%s ua_hash=%s headers=%s",
traceID,
method,
sanitizeSoraLogURL(urlStr),
attempt,
attempts,
timeout,
len(bodyBytes),
proxyURL != "",
traceProxyKey,
traceUAHash,
formatSoraHeaders(headers),
)
}
var reader io.Reader
if bodyBytes != nil {
reader = bytes.NewReader(bodyBytes)
......@@ -612,13 +1546,24 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
req.Header = headers.Clone()
start := time.Now()
proxyURL := ""
if account != nil && account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := c.doHTTP(req, proxyURL, account)
if err != nil {
lastErr = err
if c.debugEnabled() {
c.debugLogf(
"request_transport_error trace_id=%s method=%s url=%s attempt=%d/%d err=%s",
traceID,
method,
sanitizeSoraLogURL(urlStr),
attempt,
attempts,
logredact.RedactText(err.Error()),
)
}
if attempt < attempts && allowRetry {
if c.debugEnabled() {
c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=transport_error next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), attempt+1, attempts)
}
c.sleepRetry(attempt)
continue
}
......@@ -632,24 +1577,119 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
}
if c.cfg != nil && c.cfg.Sora.Client.Debug {
log.Printf("[SoraClient] %s %s status=%d cost=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, time.Since(start))
c.debugLogf(
"response_received trace_id=%s method=%s url=%s attempt=%d/%d status=%d cost=%s resp_bytes=%d resp_headers=%s",
traceID,
method,
sanitizeSoraLogURL(urlStr),
attempt,
attempts,
resp.StatusCode,
time.Since(start),
len(respBody),
formatSoraHeaders(resp.Header),
)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody)
isCFChallenge := soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, respBody)
if isCFChallenge {
sawCFChallenge = true
c.recordCloudflareChallengeCooldown(account, proxyURL, resp.StatusCode, resp.Header, respBody)
if allowRetry && attempt < attempts && !challengeRetried {
challengeRetried = true
if c.debugEnabled() {
c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=cloudflare_challenge status=%d next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts)
}
c.sleepRetry(attempt)
continue
}
}
if !isCFChallenge && !authRecovered && shouldAttemptSoraTokenRecover(resp.StatusCode, urlStr) && account != nil {
if recovered, recoverErr := c.recoverAccessToken(ctx, account, fmt.Sprintf("upstream_status_%d", resp.StatusCode)); recoverErr == nil && strings.TrimSpace(recovered) != "" {
headers.Set("Authorization", "Bearer "+recovered)
authRecovered = true
if attempt == attempts && !authRecoverExtraAttemptGranted {
attempts++
authRecoverExtraAttemptGranted = true
}
if c.debugEnabled() {
c.debugLogf("request_retry_with_recovered_token trace_id=%s method=%s url=%s status=%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode)
}
continue
} else if recoverErr != nil && c.debugEnabled() {
c.debugLogf("request_recover_token_failed trace_id=%s method=%s url=%s status=%d err=%s", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, logredact.RedactText(recoverErr.Error()))
}
}
if c.debugEnabled() {
c.debugLogf(
"response_non_success trace_id=%s method=%s url=%s attempt=%d/%d status=%d body=%s",
traceID,
method,
sanitizeSoraLogURL(urlStr),
attempt,
attempts,
resp.StatusCode,
summarizeSoraResponseBody(respBody, 512),
)
}
upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody, urlStr)
lastErr = upstreamErr
if isCFChallenge {
return nil, resp.Header, upstreamErr
}
if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) {
if c.debugEnabled() {
c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=status_%d next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts)
}
c.sleepRetry(attempt)
continue
}
return nil, resp.Header, upstreamErr
}
if sawCFChallenge {
c.clearCloudflareChallengeCooldown(account, proxyURL)
}
return respBody, resp.Header, nil
}
if lastErr != nil {
return nil, nil, lastErr
}
return nil, nil, errors.New("upstream retries exhausted")
}
func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool {
switch statusCode {
case http.StatusUnauthorized, http.StatusForbidden:
parsed, err := url.Parse(strings.TrimSpace(rawURL))
if err != nil {
return false
}
host := strings.ToLower(parsed.Hostname())
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
return false
}
// 避免在 ST->AT 转换接口上递归触发 token 恢复导致死循环。
path := strings.ToLower(strings.TrimSpace(parsed.Path))
if path == "/api/auth/session" {
return false
}
return true
default:
return false
}
}
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
enableTLS := c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint
if c != nil && c.cfg != nil && c.cfg.Sora.Client.CurlCFFISidecar.Enabled {
resp, err := c.doHTTPViaCurlCFFISidecar(req, proxyURL, account)
if err != nil {
return nil, err
}
return resp, nil
}
enableTLS := c == nil || c.cfg == nil || !c.cfg.Sora.Client.DisableTLSFingerprint
if c.httpUpstream != nil {
accountID := int64(0)
accountConcurrency := 0
......@@ -670,9 +1710,14 @@ func (c *SoraDirectClient) sleepRetry(attempt int) {
time.Sleep(backoff)
}
func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte) error {
func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte, requestURL string) error {
msg := strings.TrimSpace(extractUpstreamErrorMessage(body))
msg = sanitizeUpstreamErrorMessage(msg)
if status == http.StatusNotFound && strings.Contains(strings.ToLower(msg), "not found") {
if hint := soraBaseURLNotFoundHint(requestURL); hint != "" {
msg = strings.TrimSpace(msg + " " + hint)
}
}
if msg == "" {
msg = truncateForLog(body, 256)
}
......@@ -684,10 +1729,52 @@ func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, b
}
}
func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) {
func normalizeSoraBaseURL(raw string) string {
trimmed := strings.TrimRight(strings.TrimSpace(raw), "/")
if trimmed == "" {
return ""
}
parsed, err := url.Parse(trimmed)
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
return trimmed
}
host := strings.ToLower(parsed.Hostname())
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
return trimmed
}
pathVal := strings.TrimRight(strings.TrimSpace(parsed.Path), "/")
switch pathVal {
case "", "/":
parsed.Path = "/backend"
case "/backend-api":
parsed.Path = "/backend"
}
return strings.TrimRight(parsed.String(), "/")
}
func soraBaseURLNotFoundHint(requestURL string) string {
parsed, err := url.Parse(strings.TrimSpace(requestURL))
if err != nil || parsed.Host == "" {
return ""
}
host := strings.ToLower(parsed.Hostname())
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
return ""
}
pathVal := strings.TrimSpace(parsed.Path)
if strings.HasPrefix(pathVal, "/backend/") || pathVal == "/backend" {
return ""
}
return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)"
}
func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken, userAgent, proxyURL string) (string, error) {
reqID := uuid.NewString()
userAgent := soraRandChoice(soraDesktopUserAgents)
powToken := soraGetPowToken(userAgent)
userAgent = strings.TrimSpace(userAgent)
if userAgent == "" {
userAgent = c.taskUserAgent()
}
powToken := soraPowTokenGenerator(userAgent)
payload := map[string]any{
"p": powToken,
"flow": soraSentinelFlow,
......@@ -708,7 +1795,7 @@ func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *A
}
urlStr := soraChatGPTBaseURL + "/backend-api/sentinel/req"
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, urlStr, headers, bytes.NewReader(body), true)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, urlStr, headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
......@@ -724,16 +1811,6 @@ func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *A
return sentinel, nil
}
func soraRandChoice(items []string) string {
if len(items) == 0 {
return ""
}
soraRandMu.Lock()
idx := soraRand.Intn(len(items))
soraRandMu.Unlock()
return items[idx]
}
func soraGetPowToken(userAgent string) string {
configList := soraBuildPowConfig(userAgent)
seed := strconv.FormatFloat(soraRandFloat(), 'f', -1, 64)
......@@ -748,14 +1825,26 @@ func soraRandFloat() float64 {
return soraRand.Float64()
}
func soraRandInt(max int) int {
if max <= 1 {
return 0
}
soraRandMu.Lock()
defer soraRandMu.Unlock()
return soraRand.Intn(max)
}
func soraBuildPowConfig(userAgent string) []any {
screen := soraRandChoice([]string{
strconv.Itoa(1920 + 1080),
strconv.Itoa(2560 + 1440),
strconv.Itoa(1920 + 1200),
strconv.Itoa(2560 + 1600),
})
screenVal, _ := strconv.Atoi(screen)
userAgent = strings.TrimSpace(userAgent)
if userAgent == "" && len(soraDesktopUserAgents) > 0 {
userAgent = soraDesktopUserAgents[0]
}
screenVal := soraStableChoiceInt([]int{
1920 + 1080,
2560 + 1440,
1920 + 1200,
2560 + 1600,
}, userAgent+"|screen")
perfMs := float64(time.Since(soraPerfStart).Milliseconds())
wallMs := float64(time.Now().UnixNano()) / 1e6
diff := wallMs - perfMs
......@@ -765,32 +1854,47 @@ func soraBuildPowConfig(userAgent string) []any {
4294705152,
0,
userAgent,
soraRandChoice(soraPowScripts),
soraRandChoice(soraPowDPL),
soraStableChoice(soraPowScripts, userAgent+"|script"),
soraStableChoice(soraPowDPL, userAgent+"|dpl"),
"en-US",
"en-US,es-US,en,es",
0,
soraRandChoice(soraPowNavigatorKeys),
soraRandChoice(soraPowDocumentKeys),
soraRandChoice(soraPowWindowKeys),
soraStableChoice(soraPowNavigatorKeys, userAgent+"|navigator"),
soraStableChoice(soraPowDocumentKeys, userAgent+"|document"),
soraStableChoice(soraPowWindowKeys, userAgent+"|window"),
perfMs,
uuid.NewString(),
"",
soraRandChoiceInt(soraPowCores),
soraStableChoiceInt(soraPowCores, userAgent+"|cores"),
diff,
}
}
func soraRandChoiceInt(items []int) int {
func soraStableChoice(items []string, seed string) string {
if len(items) == 0 {
return ""
}
idx := soraStableIndex(seed, len(items))
return items[idx]
}
func soraStableChoiceInt(items []int, seed string) int {
if len(items) == 0 {
return 0
}
soraRandMu.Lock()
idx := soraRand.Intn(len(items))
soraRandMu.Unlock()
idx := soraStableIndex(seed, len(items))
return items[idx]
}
func soraStableIndex(seed string, size int) int {
if size <= 0 {
return 0
}
h := fnv.New32a()
_, _ = h.Write([]byte(seed))
return int(h.Sum32() % uint32(size))
}
func soraPowParseTime() string {
loc := time.FixedZone("EST", -5*3600)
return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (Eastern Standard Time)")
......@@ -890,6 +1994,55 @@ func hexDecodeString(s string) ([]byte, error) {
return dst, err
}
func (c *SoraDirectClient) withRequestTrace(ctx context.Context, account *Account, proxyURL, userAgent string) context.Context {
if ctx == nil {
ctx = context.Background()
}
if existing, ok := ctx.Value(soraRequestTraceContextKey{}).(*soraRequestTrace); ok && existing != nil && existing.ID != "" {
return ctx
}
accountID := int64(0)
if account != nil {
accountID = account.ID
}
seed := fmt.Sprintf("%d|%s|%s|%d", accountID, normalizeSoraProxyKey(proxyURL), strings.TrimSpace(userAgent), time.Now().UnixNano())
trace := &soraRequestTrace{
ID: "sora-" + soraHashForLog(seed),
ProxyKey: normalizeSoraProxyKey(proxyURL),
UAHash: soraHashForLog(strings.TrimSpace(userAgent)),
}
return context.WithValue(ctx, soraRequestTraceContextKey{}, trace)
}
func (c *SoraDirectClient) requestTraceFields(ctx context.Context, proxyURL, userAgent string) (string, string, string) {
proxyKey := normalizeSoraProxyKey(proxyURL)
uaHash := soraHashForLog(strings.TrimSpace(userAgent))
traceID := ""
if ctx != nil {
if trace, ok := ctx.Value(soraRequestTraceContextKey{}).(*soraRequestTrace); ok && trace != nil {
if strings.TrimSpace(trace.ID) != "" {
traceID = strings.TrimSpace(trace.ID)
}
if strings.TrimSpace(trace.ProxyKey) != "" {
proxyKey = strings.TrimSpace(trace.ProxyKey)
}
if strings.TrimSpace(trace.UAHash) != "" {
uaHash = strings.TrimSpace(trace.UAHash)
}
}
}
if traceID == "" {
traceID = "sora-" + soraHashForLog(fmt.Sprintf("%s|%d", proxyKey, time.Now().UnixNano()))
}
return traceID, proxyKey, uaHash
}
func soraHashForLog(raw string) string {
h := fnv.New32a()
_, _ = h.Write([]byte(raw))
return fmt.Sprintf("%08x", h.Sum32())
}
func sanitizeSoraLogURL(raw string) string {
parsed, err := url.Parse(raw)
if err != nil {
......@@ -901,3 +2054,70 @@ func sanitizeSoraLogURL(raw string) string {
parsed.RawQuery = q.Encode()
return parsed.String()
}
func (c *SoraDirectClient) debugEnabled() bool {
return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug
}
func (c *SoraDirectClient) debugLogf(format string, args ...any) {
if !c.debugEnabled() {
return
}
log.Printf("[SoraClient] "+format, args...)
}
func formatSoraHeaders(headers http.Header) string {
if len(headers) == 0 {
return "{}"
}
keys := make([]string, 0, len(headers))
for key := range headers {
keys = append(keys, key)
}
sort.Strings(keys)
out := make(map[string]string, len(keys))
for _, key := range keys {
values := headers.Values(key)
if len(values) == 0 {
continue
}
val := strings.Join(values, ",")
if isSensitiveHeader(key) {
out[key] = "***"
continue
}
out[key] = truncateForLog([]byte(logredact.RedactText(val)), 160)
}
encoded, err := json.Marshal(out)
if err != nil {
return "{}"
}
return string(encoded)
}
func isSensitiveHeader(key string) bool {
k := strings.ToLower(strings.TrimSpace(key))
switch k {
case "authorization", "openai-sentinel-token", "cookie", "set-cookie", "x-api-key":
return true
default:
return false
}
}
func summarizeSoraResponseBody(body []byte, maxLen int) string {
if len(body) == 0 {
return ""
}
var text string
if json.Valid(body) {
text = logredact.RedactJSON(body)
} else {
text = logredact.RedactText(string(body))
}
text = strings.TrimSpace(text)
if maxLen <= 0 || len(text) <= maxLen {
return text
}
return text[:maxLen] + "...(truncated)"
}
......@@ -4,9 +4,16 @@ package service
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
......@@ -85,3 +92,984 @@ func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) {
require.Equal(t, "completed", status.Status)
require.Equal(t, []string{"https://example.com/a.png"}, status.URLs)
}
func TestNormalizeSoraBaseURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
raw string
want string
}{
{
name: "empty",
raw: "",
want: "",
},
{
name: "append_backend_for_sora_host",
raw: "https://sora.chatgpt.com",
want: "https://sora.chatgpt.com/backend",
},
{
name: "convert_backend_api_to_backend",
raw: "https://sora.chatgpt.com/backend-api",
want: "https://sora.chatgpt.com/backend",
},
{
name: "keep_backend",
raw: "https://sora.chatgpt.com/backend",
want: "https://sora.chatgpt.com/backend",
},
{
name: "keep_custom_host",
raw: "https://example.com/custom-path",
want: "https://example.com/custom-path",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeSoraBaseURL(tt.raw)
require.Equal(t, tt.want, got)
})
}
}
func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) {
t.Parallel()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com",
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", client.buildURL("/video_gen"))
}
func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) {
t.Parallel()
client := NewSoraDirectClient(&config.Config{}, nil, nil)
err := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/video_gen")
var upstreamErr *SoraUpstreamError
require.ErrorAs(t, err, &upstreamErr)
require.Contains(t, upstreamErr.Message, "请检查 sora.client.base_url")
errNoHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/backend/video_gen")
require.ErrorAs(t, errNoHint, &upstreamErr)
require.NotContains(t, upstreamErr.Message, "请检查 sora.client.base_url")
}
func TestFormatSoraHeaders_RedactsSensitive(t *testing.T) {
t.Parallel()
headers := http.Header{}
headers.Set("Authorization", "Bearer secret-token")
headers.Set("openai-sentinel-token", "sentinel-secret")
headers.Set("X-Test", "ok")
out := formatSoraHeaders(headers)
require.Contains(t, out, `"Authorization":"***"`)
require.Contains(t, out, `Sentinel-Token":"***"`)
require.Contains(t, out, `"X-Test":"ok"`)
require.NotContains(t, out, "secret-token")
require.NotContains(t, out, "sentinel-secret")
}
func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) {
t.Parallel()
body := []byte(`{"error":{"message":"bad"},"access_token":"abc123"}`)
out := summarizeSoraResponseBody(body, 512)
require.Contains(t, out, `"access_token":"***"`)
require.NotContains(t, out, "abc123")
}
func TestSummarizeSoraResponseBody_Truncates(t *testing.T) {
t.Parallel()
body := []byte(strings.Repeat("x", 100))
out := summarizeSoraResponseBody(body, 10)
require.Contains(t, out, "(truncated)")
}
func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) {
t.Parallel()
cache := newOpenAITokenCacheStub()
provider := NewOpenAITokenProvider(nil, cache, nil)
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, nil, provider)
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "sora-credential-token",
},
}
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "sora-credential-token", token)
require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalled))
}
func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) {
t.Parallel()
cache := newOpenAITokenCacheStub()
account := &Account{
ID: 2,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "sora-credential-token",
},
}
cache.tokens[OpenAITokenCacheKey(account)] = "provider-token"
provider := NewOpenAITokenProvider(nil, cache, nil)
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
UseOpenAITokenProvider: true,
},
},
}
client := NewSoraDirectClient(cfg, nil, provider)
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "provider-token", token)
require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0))
}
func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"accessToken": "session-access-token",
"expires": "2099-01-01T00:00:00Z",
})
}))
defer server.Close()
origin := soraSessionAuthURL
soraSessionAuthURL = server.URL
defer func() { soraSessionAuthURL = origin }()
client := NewSoraDirectClient(&config.Config{}, nil, nil)
account := &Account{
ID: 10,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"session_token": "session-token",
},
}
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "session-access-token", token)
require.Equal(t, "session-access-token", account.GetCredential("access_token"))
}
func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.Equal(t, "/oauth/token", r.URL.Path)
require.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
require.NoError(t, r.ParseForm())
require.Equal(t, "refresh_token", r.FormValue("grant_type"))
require.Equal(t, "refresh-token-old", r.FormValue("refresh_token"))
require.NotEmpty(t, r.FormValue("client_id"))
require.Equal(t, "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", r.FormValue("redirect_uri"))
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "refresh-access-token",
"refresh_token": "refresh-token-new",
"expires_in": 3600,
})
}))
defer server.Close()
origin := soraOAuthTokenURL
soraOAuthTokenURL = server.URL + "/oauth/token"
defer func() { soraOAuthTokenURL = origin }()
client := NewSoraDirectClient(&config.Config{}, nil, nil)
account := &Account{
ID: 11,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"refresh_token": "refresh-token-old",
},
}
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refresh-access-token", token)
require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token"))
require.NotNil(t, account.GetCredentialAsTime("expires_at"))
}
func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Equal(t, "/nf/check", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"rate_limit_and_credit_balance": map[string]any{
"estimated_num_videos_remaining": 0,
"rate_limit_reached": true,
},
})
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: server.URL,
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{
ID: 12,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "ok",
"expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339),
},
}
err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", SoraModelConfig{Type: "video"})
require.Error(t, err)
var upstreamErr *SoraUpstreamError
require.ErrorAs(t, err, &upstreamErr)
require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode)
}
func TestShouldAttemptSoraTokenRecover(t *testing.T) {
t.Parallel()
require.True(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen"))
require.True(t, shouldAttemptSoraTokenRecover(http.StatusForbidden, "https://chatgpt.com/backend/video_gen"))
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session"))
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token"))
require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen"))
}
type soraClientRequestCall struct {
Path string
UserAgent string
ProxyURL string
}
type soraClientRecordingUpstream struct {
calls []soraClientRequestCall
}
func (u *soraClientRecordingUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return nil, errors.New("unexpected Do call")
}
func (u *soraClientRecordingUpstream) DoWithTLS(req *http.Request, proxyURL string, _ int64, _ int, _ bool) (*http.Response, error) {
u.calls = append(u.calls, soraClientRequestCall{
Path: req.URL.Path,
UserAgent: req.Header.Get("User-Agent"),
ProxyURL: proxyURL,
})
switch req.URL.Path {
case "/backend-api/sentinel/req":
return newSoraClientMockResponse(http.StatusOK, `{"token":"sentinel-token","turnstile":{"dx":"ok"}}`), nil
case "/backend/nf/create":
return newSoraClientMockResponse(http.StatusOK, `{"id":"task-123"}`), nil
case "/backend/nf/create/storyboard":
return newSoraClientMockResponse(http.StatusOK, `{"id":"storyboard-123"}`), nil
case "/backend/uploads":
return newSoraClientMockResponse(http.StatusOK, `{"id":"upload-123"}`), nil
case "/backend/nf/check":
return newSoraClientMockResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":1,"rate_limit_reached":false}}`), nil
case "/backend/characters/upload":
return newSoraClientMockResponse(http.StatusOK, `{"id":"cameo-123"}`), nil
case "/backend/project_y/cameos/in_progress/cameo-123":
return newSoraClientMockResponse(http.StatusOK, `{"status":"finalized","status_message":"Completed","username_hint":"foo.bar","display_name_hint":"Bar","profile_asset_url":"https://example.com/avatar.webp"}`), nil
case "/backend/project_y/file/upload":
return newSoraClientMockResponse(http.StatusOK, `{"asset_pointer":"asset-123"}`), nil
case "/backend/characters/finalize":
return newSoraClientMockResponse(http.StatusOK, `{"character":{"character_id":"character-123"}}`), nil
case "/backend/project_y/post":
return newSoraClientMockResponse(http.StatusOK, `{"post":{"id":"s_post"}}`), nil
default:
return newSoraClientMockResponse(http.StatusOK, `{"ok":true}`), nil
}
}
func newSoraClientMockResponse(statusCode int, body string) *http.Response {
return &http.Response{
StatusCode: statusCode,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
func TestSoraDirectClient_TaskUserAgent_DefaultMobileFallback(t *testing.T) {
client := NewSoraDirectClient(&config.Config{}, nil, nil)
ua := client.taskUserAgent()
require.NotEmpty(t, ua)
allowed := append([]string{}, soraMobileUserAgents...)
allowed = append(allowed, soraDesktopUserAgents...)
require.Contains(t, allowed, ua)
}
func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAndCreate(t *testing.T) {
originPowTokenGenerator := soraPowTokenGenerator
soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
defer func() {
soraPowTokenGenerator = originPowTokenGenerator
}()
upstream := &soraClientRecordingUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
proxyID := int64(9)
account := &Account{
ID: 21,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
ProxyID: &proxyID,
Proxy: &Proxy{
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
},
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
taskID, err := client.CreateVideoTask(context.Background(), account, SoraVideoRequest{Prompt: "test"})
require.NoError(t, err)
require.Equal(t, "task-123", taskID)
require.Len(t, upstream.calls, 2)
sentinelCall := upstream.calls[0]
createCall := upstream.calls[1]
require.Equal(t, "/backend-api/sentinel/req", sentinelCall.Path)
require.Equal(t, "/backend/nf/create", createCall.Path)
require.Equal(t, "http://127.0.0.1:8080", sentinelCall.ProxyURL)
require.Equal(t, sentinelCall.ProxyURL, createCall.ProxyURL)
require.NotEmpty(t, sentinelCall.UserAgent)
require.Equal(t, sentinelCall.UserAgent, createCall.UserAgent)
}
func TestSoraDirectClient_UploadImage_UsesTaskUserAgentAndProxy(t *testing.T) {
upstream := &soraClientRecordingUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
proxyID := int64(3)
account := &Account{
ID: 31,
ProxyID: &proxyID,
Proxy: &Proxy{
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
},
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
uploadID, err := client.UploadImage(context.Background(), account, []byte("mock-image"), "a.png")
require.NoError(t, err)
require.Equal(t, "upload-123", uploadID)
require.Len(t, upstream.calls, 1)
require.Equal(t, "/backend/uploads", upstream.calls[0].Path)
require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
require.NotEmpty(t, upstream.calls[0].UserAgent)
}
func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T) {
upstream := &soraClientRecordingUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
proxyID := int64(7)
account := &Account{
ID: 41,
ProxyID: &proxyID,
Proxy: &Proxy{
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
},
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
err := client.PreflightCheck(context.Background(), account, "sora2", SoraModelConfig{Type: "video"})
require.NoError(t, err)
require.Len(t, upstream.calls, 1)
require.Equal(t, "/backend/nf/check", upstream.calls[0].Path)
require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
require.NotEmpty(t, upstream.calls[0].UserAgent)
}
func TestSoraDirectClient_CreateStoryboardTask(t *testing.T) {
originPowTokenGenerator := soraPowTokenGenerator
soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
defer func() { soraPowTokenGenerator = originPowTokenGenerator }()
upstream := &soraClientRecordingUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
account := &Account{
ID: 51,
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
taskID, err := client.CreateStoryboardTask(context.Background(), account, SoraStoryboardRequest{
Prompt: "Shot 1:\nduration: 5sec\nScene: cat",
})
require.NoError(t, err)
require.Equal(t, "storyboard-123", taskID)
require.Len(t, upstream.calls, 2)
require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path)
require.Equal(t, "/backend/nf/create/storyboard", upstream.calls[1].Path)
}
func TestSoraDirectClient_GetVideoTask_ReturnsGenerationID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/nf/pending/v2":
_, _ = w.Write([]byte(`[]`))
case "/project_y/profile/drafts":
_, _ = w.Write([]byte(`{"items":[{"id":"gen_1","task_id":"task-1","kind":"video","downloadable_url":"https://example.com/v.mp4"}]}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: server.URL,
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{Credentials: map[string]any{"access_token": "token"}}
status, err := client.GetVideoTask(context.Background(), account, "task-1")
require.NoError(t, err)
require.Equal(t, "completed", status.Status)
require.Equal(t, "gen_1", status.GenerationID)
require.Equal(t, []string{"https://example.com/v.mp4"}, status.URLs)
}
func TestSoraDirectClient_PostVideoForWatermarkFree(t *testing.T) {
originPowTokenGenerator := soraPowTokenGenerator
soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
defer func() { soraPowTokenGenerator = originPowTokenGenerator }()
upstream := &soraClientRecordingUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
account := &Account{
ID: 52,
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
postID, err := client.PostVideoForWatermarkFree(context.Background(), account, "gen_1")
require.NoError(t, err)
require.Equal(t, "s_post", postID)
require.Len(t, upstream.calls, 2)
require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path)
require.Equal(t, "/backend/project_y/post", upstream.calls[1].Path)
}
type soraClientFallbackUpstream struct {
doWithTLSCalls int32
respBody string
respStatusCode int
err error
}
func (u *soraClientFallbackUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return nil, errors.New("unexpected Do call")
}
func (u *soraClientFallbackUpstream) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
atomic.AddInt32(&u.doWithTLSCalls, 1)
if u.err != nil {
return nil, u.err
}
statusCode := u.respStatusCode
if statusCode <= 0 {
statusCode = http.StatusOK
}
body := u.respBody
if body == "" {
body = `{"ok":true}`
}
return newSoraClientMockResponse(statusCode, body), nil
}
func TestSoraDirectClient_DoHTTP_UsesCurlCFFISidecarWhenEnabled(t *testing.T) {
var captured soraCurlCFFISidecarRequest
sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.Equal(t, "/request", r.URL.Path)
raw, err := io.ReadAll(r.Body)
require.NoError(t, err)
require.NoError(t, json.Unmarshal(raw, &captured))
_ = json.NewEncoder(w).Encode(map[string]any{
"status_code": http.StatusOK,
"headers": map[string]any{
"Content-Type": "application/json",
"X-Sidecar": []string{"yes"},
},
"body_base64": base64.StdEncoding.EncodeToString([]byte(`{"ok":true}`)),
})
}))
defer sidecar.Close()
upstream := &soraClientFallbackUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
BaseURL: sidecar.URL,
Impersonate: "chrome131",
TimeoutSeconds: 15,
SessionReuseEnabled: true,
},
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
req, err := http.NewRequest(http.MethodPost, "https://sora.chatgpt.com/backend/me", strings.NewReader("hello-sidecar"))
require.NoError(t, err)
req.Header.Set("User-Agent", "test-ua")
resp, err := client.doHTTP(req, "http://127.0.0.1:18080", &Account{ID: 1})
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.JSONEq(t, `{"ok":true}`, string(body))
require.Equal(t, int32(0), atomic.LoadInt32(&upstream.doWithTLSCalls))
require.Equal(t, "http://127.0.0.1:18080", captured.ProxyURL)
require.NotEmpty(t, captured.SessionKey)
require.Equal(t, "chrome131", captured.Impersonate)
require.Equal(t, "https://sora.chatgpt.com/backend/me", captured.URL)
decodedReqBody, err := base64.StdEncoding.DecodeString(captured.BodyBase64)
require.NoError(t, err)
require.Equal(t, "hello-sidecar", string(decodedReqBody))
}
func TestSoraDirectClient_DoHTTP_CurlCFFISidecarFailureReturnsError(t *testing.T) {
sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadGateway)
_, _ = w.Write([]byte(`{"error":"boom"}`))
}))
defer sidecar.Close()
upstream := &soraClientFallbackUpstream{respBody: `{"fallback":true}`}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
BaseURL: sidecar.URL,
},
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
req, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil)
require.NoError(t, err)
_, err = client.doHTTP(req, "", &Account{ID: 2})
require.Error(t, err)
require.Contains(t, err.Error(), "sora curl_cffi sidecar")
require.Equal(t, int32(0), atomic.LoadInt32(&upstream.doWithTLSCalls))
}
func TestSoraDirectClient_DoHTTP_CurlCFFISidecarDisabledUsesLegacyStack(t *testing.T) {
upstream := &soraClientFallbackUpstream{respBody: `{"legacy":true}`}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: false,
BaseURL: "http://127.0.0.1:18080",
},
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
req, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil)
require.NoError(t, err)
resp, err := client.doHTTP(req, "", &Account{ID: 3})
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.JSONEq(t, `{"legacy":true}`, string(body))
require.Equal(t, int32(1), atomic.LoadInt32(&upstream.doWithTLSCalls))
}
func TestConvertSidecarHeaderValue_NilAndSlice(t *testing.T) {
require.Nil(t, convertSidecarHeaderValue(nil))
require.Equal(t, []string{"a", "b"}, convertSidecarHeaderValue([]any{"a", " ", "b"}))
}
func TestSoraDirectClient_DoHTTP_SidecarSessionKeyStableForSameAccountProxy(t *testing.T) {
var captured []soraCurlCFFISidecarRequest
sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, err := io.ReadAll(r.Body)
require.NoError(t, err)
var reqPayload soraCurlCFFISidecarRequest
require.NoError(t, json.Unmarshal(raw, &reqPayload))
captured = append(captured, reqPayload)
_ = json.NewEncoder(w).Encode(map[string]any{
"status_code": http.StatusOK,
"headers": map[string]any{
"Content-Type": "application/json",
},
"body": `{"ok":true}`,
})
}))
defer sidecar.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
BaseURL: sidecar.URL,
SessionReuseEnabled: true,
SessionTTLSeconds: 3600,
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{ID: 1001}
req1, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil)
require.NoError(t, err)
_, err = client.doHTTP(req1, "http://127.0.0.1:18080", account)
require.NoError(t, err)
req2, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil)
require.NoError(t, err)
_, err = client.doHTTP(req2, "http://127.0.0.1:18080", account)
require.NoError(t, err)
require.Len(t, captured, 2)
require.NotEmpty(t, captured[0].SessionKey)
require.Equal(t, captured[0].SessionKey, captured[1].SessionKey)
}
func TestSoraDirectClient_DoRequestWithProxy_CloudflareChallengeSetsCooldownAfterSingleRetry(t *testing.T) {
var sidecarCalls int32
sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&sidecarCalls, 1)
_ = json.NewEncoder(w).Encode(map[string]any{
"status_code": http.StatusForbidden,
"headers": map[string]any{
"cf-ray": "9d05d73dec4d8c8e-GRU",
"content-type": "text/html",
},
"body": `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`,
})
}))
defer sidecar.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
MaxRetries: 3,
CloudflareChallengeCooldownSeconds: 60,
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
BaseURL: sidecar.URL,
Impersonate: "chrome131",
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
headers := http.Header{}
_, _, err := client.doRequestWithProxy(
context.Background(),
&Account{ID: 99},
"http://127.0.0.1:18080",
http.MethodGet,
"https://sora.chatgpt.com/backend/me",
headers,
nil,
true,
)
require.Error(t, err)
var upstreamErr *SoraUpstreamError
require.ErrorAs(t, err, &upstreamErr)
require.Equal(t, http.StatusForbidden, upstreamErr.StatusCode)
require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls), "challenge should trigger exactly one same-proxy retry")
_, _, err = client.doRequestWithProxy(
context.Background(),
&Account{ID: 99},
"http://127.0.0.1:18080",
http.MethodGet,
"https://sora.chatgpt.com/backend/me",
headers,
nil,
true,
)
require.Error(t, err)
require.ErrorAs(t, err, &upstreamErr)
require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode)
require.Contains(t, upstreamErr.Message, "cooling down")
require.Contains(t, upstreamErr.Message, "cf-ray")
require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls), "cooldown should block outbound request")
}
func TestSoraDirectClient_DoRequestWithProxy_CloudflareRetrySuccessClearsCooldown(t *testing.T) {
var sidecarCalls int32
sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
call := atomic.AddInt32(&sidecarCalls, 1)
if call == 1 {
_ = json.NewEncoder(w).Encode(map[string]any{
"status_code": http.StatusForbidden,
"headers": map[string]any{
"cf-ray": "9d05d73dec4d8c8e-GRU",
"content-type": "text/html",
},
"body": `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`,
})
return
}
_ = json.NewEncoder(w).Encode(map[string]any{
"status_code": http.StatusOK,
"headers": map[string]any{
"content-type": "application/json",
},
"body": `{"ok":true}`,
})
}))
defer sidecar.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
MaxRetries: 3,
CloudflareChallengeCooldownSeconds: 60,
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
BaseURL: sidecar.URL,
Impersonate: "chrome131",
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
headers := http.Header{}
account := &Account{ID: 109}
proxyURL := "http://127.0.0.1:18080"
body, _, err := client.doRequestWithProxy(
context.Background(),
account,
proxyURL,
http.MethodGet,
"https://sora.chatgpt.com/backend/me",
headers,
nil,
true,
)
require.NoError(t, err)
require.Contains(t, string(body), `"ok":true`)
require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls))
_, _, err = client.doRequestWithProxy(
context.Background(),
account,
proxyURL,
http.MethodGet,
"https://sora.chatgpt.com/backend/me",
headers,
nil,
true,
)
require.NoError(t, err)
require.Equal(t, int32(3), atomic.LoadInt32(&sidecarCalls), "cooldown should be cleared after retry succeeds")
}
func TestSoraComputeChallengeCooldownSeconds(t *testing.T) {
require.Equal(t, 0, soraComputeChallengeCooldownSeconds(0, 3))
require.Equal(t, 10, soraComputeChallengeCooldownSeconds(10, 1))
require.Equal(t, 20, soraComputeChallengeCooldownSeconds(10, 2))
require.Equal(t, 40, soraComputeChallengeCooldownSeconds(10, 4))
require.Equal(t, 40, soraComputeChallengeCooldownSeconds(10, 9), "streak should cap at x4")
require.Equal(t, 3600, soraComputeChallengeCooldownSeconds(1200, 9), "cooldown should cap at 3600s")
}
func TestSoraDirectClient_RecordCloudflareChallengeCooldown_EscalatesStreak(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
CloudflareChallengeCooldownSeconds: 10,
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{ID: 201}
proxyURL := "http://127.0.0.1:18080"
client.recordCloudflareChallengeCooldown(account, proxyURL, http.StatusForbidden, http.Header{"Cf-Ray": []string{"9d05d73dec4d8c8e-GRU"}}, nil)
client.recordCloudflareChallengeCooldown(account, proxyURL, http.StatusForbidden, http.Header{"Cf-Ray": []string{"9d05d73dec4d8c8f-GRU"}}, nil)
key := soraAccountProxyKey(account, proxyURL)
entry, ok := client.challengeCooldowns[key]
require.True(t, ok)
require.Equal(t, 2, entry.ConsecutiveChallenges)
require.Equal(t, "9d05d73dec4d8c8f-GRU", entry.CFRay)
remain := int(entry.Until.Sub(entry.LastChallengeAt).Seconds())
require.GreaterOrEqual(t, remain, 19)
}
func TestSoraDirectClient_SidecarSessionKey_SkipsWhenAccountMissing(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
SessionReuseEnabled: true,
SessionTTLSeconds: 3600,
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
require.Equal(t, "", client.sidecarSessionKey(nil, "http://127.0.0.1:18080"))
require.Empty(t, client.sidecarSessions)
}
func TestSoraDirectClient_SidecarSessionKey_PrunesExpiredAndRecreates(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
SessionReuseEnabled: true,
SessionTTLSeconds: 3600,
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{ID: 123}
key := soraAccountProxyKey(account, "http://127.0.0.1:18080")
client.sidecarSessions[key] = soraSidecarSessionEntry{
SessionKey: "sora-expired",
ExpiresAt: time.Now().Add(-time.Minute),
LastUsedAt: time.Now().Add(-2 * time.Minute),
}
sessionKey := client.sidecarSessionKey(account, "http://127.0.0.1:18080")
require.NotEmpty(t, sessionKey)
require.NotEqual(t, "sora-expired", sessionKey)
require.Len(t, client.sidecarSessions, 1)
}
func TestSoraDirectClient_SidecarSessionKey_TTLZeroKeepsLongLivedSession(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
SessionReuseEnabled: true,
SessionTTLSeconds: 0,
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{ID: 456}
first := client.sidecarSessionKey(account, "http://127.0.0.1:18080")
second := client.sidecarSessionKey(account, "http://127.0.0.1:18080")
require.NotEmpty(t, first)
require.Equal(t, first, second)
key := soraAccountProxyKey(account, "http://127.0.0.1:18080")
entry, ok := client.sidecarSessions[key]
require.True(t, ok)
require.True(t, entry.ExpiresAt.After(time.Now().Add(300*24*time.Hour)))
}
package service
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
)
const soraCurlCFFISidecarDefaultTimeoutSeconds = 60
type soraCurlCFFISidecarRequest struct {
Method string `json:"method"`
URL string `json:"url"`
Headers map[string][]string `json:"headers,omitempty"`
BodyBase64 string `json:"body_base64,omitempty"`
ProxyURL string `json:"proxy_url,omitempty"`
SessionKey string `json:"session_key,omitempty"`
Impersonate string `json:"impersonate,omitempty"`
TimeoutSeconds int `json:"timeout_seconds,omitempty"`
}
type soraCurlCFFISidecarResponse struct {
StatusCode int `json:"status_code"`
Status int `json:"status"`
Headers map[string]any `json:"headers"`
BodyBase64 string `json:"body_base64"`
Body string `json:"body"`
Error string `json:"error"`
}
func (c *SoraDirectClient) doHTTPViaCurlCFFISidecar(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
if req == nil || req.URL == nil {
return nil, errors.New("request url is nil")
}
if c == nil || c.cfg == nil {
return nil, errors.New("sora curl_cffi sidecar config is nil")
}
if !c.cfg.Sora.Client.CurlCFFISidecar.Enabled {
return nil, errors.New("sora curl_cffi sidecar is disabled")
}
endpoint := c.curlCFFISidecarEndpoint()
if endpoint == "" {
return nil, errors.New("sora curl_cffi sidecar base_url is empty")
}
bodyBytes, err := readAndRestoreRequestBody(req)
if err != nil {
return nil, fmt.Errorf("sora curl_cffi sidecar read request body failed: %w", err)
}
headers := make(map[string][]string, len(req.Header)+1)
for key, vals := range req.Header {
copied := make([]string, len(vals))
copy(copied, vals)
headers[key] = copied
}
if strings.TrimSpace(req.Host) != "" {
if _, ok := headers["Host"]; !ok {
headers["Host"] = []string{req.Host}
}
}
payload := soraCurlCFFISidecarRequest{
Method: req.Method,
URL: req.URL.String(),
Headers: headers,
ProxyURL: strings.TrimSpace(proxyURL),
SessionKey: c.sidecarSessionKey(account, proxyURL),
Impersonate: c.curlCFFIImpersonate(),
TimeoutSeconds: c.curlCFFISidecarTimeoutSeconds(),
}
if len(bodyBytes) > 0 {
payload.BodyBase64 = base64.StdEncoding.EncodeToString(bodyBytes)
}
encoded, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("sora curl_cffi sidecar marshal request failed: %w", err)
}
sidecarReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, endpoint, bytes.NewReader(encoded))
if err != nil {
return nil, fmt.Errorf("sora curl_cffi sidecar build request failed: %w", err)
}
sidecarReq.Header.Set("Content-Type", "application/json")
sidecarReq.Header.Set("Accept", "application/json")
httpClient := &http.Client{Timeout: time.Duration(payload.TimeoutSeconds) * time.Second}
sidecarResp, err := httpClient.Do(sidecarReq)
if err != nil {
return nil, fmt.Errorf("sora curl_cffi sidecar request failed: %w", err)
}
defer func() {
_ = sidecarResp.Body.Close()
}()
sidecarRespBody, err := io.ReadAll(io.LimitReader(sidecarResp.Body, 8<<20))
if err != nil {
return nil, fmt.Errorf("sora curl_cffi sidecar read response failed: %w", err)
}
if sidecarResp.StatusCode != http.StatusOK {
redacted := truncateForLog([]byte(logredact.RedactText(string(sidecarRespBody))), 512)
return nil, fmt.Errorf("sora curl_cffi sidecar http status=%d body=%s", sidecarResp.StatusCode, redacted)
}
var payloadResp soraCurlCFFISidecarResponse
if err := json.Unmarshal(sidecarRespBody, &payloadResp); err != nil {
return nil, fmt.Errorf("sora curl_cffi sidecar parse response failed: %w", err)
}
if msg := strings.TrimSpace(payloadResp.Error); msg != "" {
return nil, fmt.Errorf("sora curl_cffi sidecar upstream error: %s", msg)
}
statusCode := payloadResp.StatusCode
if statusCode <= 0 {
statusCode = payloadResp.Status
}
if statusCode <= 0 {
return nil, errors.New("sora curl_cffi sidecar response missing status code")
}
responseBody := []byte(payloadResp.Body)
if strings.TrimSpace(payloadResp.BodyBase64) != "" {
decoded, err := base64.StdEncoding.DecodeString(payloadResp.BodyBase64)
if err != nil {
return nil, fmt.Errorf("sora curl_cffi sidecar decode body failed: %w", err)
}
responseBody = decoded
}
respHeaders := make(http.Header)
for key, rawVal := range payloadResp.Headers {
for _, v := range convertSidecarHeaderValue(rawVal) {
respHeaders.Add(key, v)
}
}
return &http.Response{
StatusCode: statusCode,
Header: respHeaders,
Body: io.NopCloser(bytes.NewReader(responseBody)),
ContentLength: int64(len(responseBody)),
Request: req,
}, nil
}
func readAndRestoreRequestBody(req *http.Request) ([]byte, error) {
if req == nil || req.Body == nil {
return nil, nil
}
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
_ = req.Body.Close()
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.ContentLength = int64(len(bodyBytes))
return bodyBytes, nil
}
func (c *SoraDirectClient) curlCFFISidecarEndpoint() string {
if c == nil || c.cfg == nil {
return ""
}
raw := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.BaseURL)
if raw == "" {
return ""
}
parsed, err := url.Parse(raw)
if err != nil || strings.TrimSpace(parsed.Scheme) == "" || strings.TrimSpace(parsed.Host) == "" {
return raw
}
if path := strings.TrimSpace(parsed.Path); path == "" || path == "/" {
parsed.Path = "/request"
}
return parsed.String()
}
func (c *SoraDirectClient) curlCFFISidecarTimeoutSeconds() int {
if c == nil || c.cfg == nil {
return soraCurlCFFISidecarDefaultTimeoutSeconds
}
timeoutSeconds := c.cfg.Sora.Client.CurlCFFISidecar.TimeoutSeconds
if timeoutSeconds <= 0 {
return soraCurlCFFISidecarDefaultTimeoutSeconds
}
return timeoutSeconds
}
func (c *SoraDirectClient) curlCFFIImpersonate() string {
if c == nil || c.cfg == nil {
return "chrome131"
}
impersonate := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.Impersonate)
if impersonate == "" {
return "chrome131"
}
return impersonate
}
func (c *SoraDirectClient) sidecarSessionReuseEnabled() bool {
if c == nil || c.cfg == nil {
return true
}
return c.cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled
}
func (c *SoraDirectClient) sidecarSessionTTLSeconds() int {
if c == nil || c.cfg == nil {
return 3600
}
ttl := c.cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds
if ttl < 0 {
return 3600
}
return ttl
}
func convertSidecarHeaderValue(raw any) []string {
switch val := raw.(type) {
case nil:
return nil
case string:
if strings.TrimSpace(val) == "" {
return nil
}
return []string{val}
case []any:
out := make([]string, 0, len(val))
for _, item := range val {
s := strings.TrimSpace(fmt.Sprint(item))
if s != "" {
out = append(out, s)
}
}
return out
case []string:
out := make([]string, 0, len(val))
for _, item := range val {
if strings.TrimSpace(item) != "" {
out = append(out, item)
}
}
return out
default:
s := strings.TrimSpace(fmt.Sprint(val))
if s == "" {
return nil
}
return []string{s}
}
}
......@@ -8,10 +8,12 @@ import (
"fmt"
"io"
"log"
"math"
"mime"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
......@@ -23,6 +25,9 @@ import (
const soraImageInputMaxBytes = 20 << 20
const soraImageInputMaxRedirects = 3
const soraImageInputTimeout = 20 * time.Second
const soraVideoInputMaxBytes = 200 << 20
const soraVideoInputMaxRedirects = 3
const soraVideoInputTimeout = 60 * time.Second
var soraImageSizeMap = map[string]string{
"gpt-image": "360",
......@@ -61,6 +66,36 @@ type SoraGatewayService struct {
cfg *config.Config
}
type soraWatermarkOptions struct {
Enabled bool
ParseMethod string
ParseURL string
ParseToken string
FallbackOnFailure bool
DeletePost bool
}
type soraCharacterOptions struct {
SetPublic bool
DeleteAfterGenerate bool
}
type soraCharacterFlowResult struct {
CameoID string
CharacterID string
Username string
DisplayName string
}
var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`)
var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)
var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`)
var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`)
type soraPreflightChecker interface {
PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
}
func NewSoraGatewayService(
soraClient SoraClient,
mediaStorage *SoraMediaStorage,
......@@ -112,29 +147,133 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
return nil, fmt.Errorf("unsupported model: %s", reqModel)
}
if modelCfg.Type == "prompt_enhance" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream)
return nil, fmt.Errorf("prompt-enhance not supported")
}
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
if strings.TrimSpace(prompt) == "" {
prompt = strings.TrimSpace(prompt)
imageInput = strings.TrimSpace(imageInput)
videoInput = strings.TrimSpace(videoInput)
remixTargetID = strings.TrimSpace(remixTargetID)
if videoInput != "" && modelCfg.Type != "video" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream)
return nil, errors.New("video input only supports video models")
}
if videoInput != "" && imageInput != "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream)
return nil, errors.New("image input and video input cannot be used together")
}
characterOnly := videoInput != "" && prompt == ""
if modelCfg.Type == "prompt_enhance" && prompt == "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
return nil, errors.New("prompt is required")
}
if strings.TrimSpace(videoInput) != "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream)
return nil, errors.New("video input not supported")
if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
return nil, errors.New("prompt is required")
}
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
if cancel != nil {
defer cancel()
}
if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly {
if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
}
if modelCfg.Type == "prompt_enhance" {
enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS)
if err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
content := strings.TrimSpace(enhancedPrompt)
if content == "" {
content = prompt
}
var firstTokenMs *int
if clientStream {
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
if streamErr != nil {
return nil, streamErr
}
firstTokenMs = ms
} else if c != nil {
c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
}
return &ForwardResult{
RequestID: "",
Model: reqModel,
Stream: clientStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{},
MediaType: "prompt",
}, nil
}
characterOpts := parseSoraCharacterOptions(reqBody)
watermarkOpts := parseSoraWatermarkOptions(reqBody)
var characterResult *soraCharacterFlowResult
if videoInput != "" {
videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput)
if videoErr != nil {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream)
return nil, videoErr
}
characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts)
if videoErr != nil {
return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream)
}
if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly {
characterID := strings.TrimSpace(characterResult.CharacterID)
defer func() {
cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second)
defer cancelCleanup()
if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil {
log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err)
}
}()
}
if characterOnly {
content := "角色创建成功"
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username))
}
var firstTokenMs *int
if clientStream {
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
if streamErr != nil {
return nil, streamErr
}
firstTokenMs = ms
} else if c != nil {
resp := buildSoraNonStreamResponse(content, reqModel)
if characterResult != nil {
resp["character_id"] = characterResult.CharacterID
resp["cameo_id"] = characterResult.CameoID
resp["character_username"] = characterResult.Username
resp["character_display_name"] = characterResult.DisplayName
}
c.JSON(http.StatusOK, resp)
}
return &ForwardResult{
RequestID: "",
Model: reqModel,
Stream: clientStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{},
MediaType: "prompt",
}, nil
}
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt)
}
}
var imageData []byte
imageFilename := ""
if strings.TrimSpace(imageInput) != "" {
if imageInput != "" {
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
if err != nil {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
......@@ -164,15 +303,27 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
MediaID: mediaID,
})
case "video":
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
Prompt: prompt,
Orientation: modelCfg.Orientation,
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
MediaID: mediaID,
RemixTargetID: remixTargetID,
})
if remixTargetID == "" && isSoraStoryboardPrompt(prompt) {
taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{
Prompt: formatSoraStoryboardPrompt(prompt),
Orientation: modelCfg.Orientation,
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
MediaID: mediaID,
})
} else {
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
Prompt: prompt,
Orientation: modelCfg.Orientation,
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
MediaID: mediaID,
RemixTargetID: remixTargetID,
CameoIDs: extractSoraCameoIDs(reqBody),
})
}
default:
err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
}
......@@ -185,6 +336,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
}
var mediaURLs []string
videoGenerationID := ""
mediaType := modelCfg.Type
imageCount := 0
imageSize := ""
......@@ -198,15 +350,32 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
imageCount = len(urls)
imageSize = soraImageSizeFromModel(reqModel)
case "video":
urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream)
videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream)
if pollErr != nil {
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
}
mediaURLs = urls
if videoStatus != nil {
mediaURLs = videoStatus.URLs
videoGenerationID = strings.TrimSpace(videoStatus.GenerationID)
}
default:
mediaType = "prompt"
}
watermarkPostID := ""
if modelCfg.Type == "video" && watermarkOpts.Enabled {
watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts)
if watermarkErr != nil {
if !watermarkOpts.FallbackOnFailure {
return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream)
}
log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr)
} else if strings.TrimSpace(watermarkURL) != "" {
mediaURLs = []string{strings.TrimSpace(watermarkURL)}
watermarkPostID = strings.TrimSpace(postID)
}
}
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
......@@ -217,6 +386,11 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
finalURLs = s.normalizeSoraMediaURLs(stored)
}
}
if watermarkPostID != "" && watermarkOpts.DeletePost {
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
}
}
content := buildSoraContent(mediaType, finalURLs)
var firstTokenMs *int
......@@ -265,9 +439,270 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
}
func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions {
opts := soraWatermarkOptions{
Enabled: parseBoolWithDefault(body, "watermark_free", false),
ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))),
ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")),
ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")),
FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true),
DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false),
}
if opts.ParseMethod == "" {
opts.ParseMethod = "third_party"
}
return opts
}
func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
return soraCharacterOptions{
SetPublic: parseBoolWithDefault(body, "character_set_public", true),
DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true),
}
}
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
if body == nil {
return def
}
val, ok := body[key]
if !ok {
return def
}
switch typed := val.(type) {
case bool:
return typed
case int:
return typed != 0
case int32:
return typed != 0
case int64:
return typed != 0
case float64:
return typed != 0
case string:
typed = strings.ToLower(strings.TrimSpace(typed))
if typed == "true" || typed == "1" || typed == "yes" {
return true
}
if typed == "false" || typed == "0" || typed == "no" {
return false
}
}
return def
}
func parseStringWithDefault(body map[string]any, key, def string) string {
if body == nil {
return def
}
val, ok := body[key]
if !ok {
return def
}
if str, ok := val.(string); ok {
return str
}
return def
}
func extractSoraCameoIDs(body map[string]any) []string {
if body == nil {
return nil
}
raw, ok := body["cameo_ids"]
if !ok {
return nil
}
switch typed := raw.(type) {
case []string:
out := make([]string, 0, len(typed))
for _, item := range typed {
item = strings.TrimSpace(item)
if item != "" {
out = append(out, item)
}
}
return out
case []any:
out := make([]string, 0, len(typed))
for _, item := range typed {
str, ok := item.(string)
if !ok {
continue
}
str = strings.TrimSpace(str)
if str != "" {
out = append(out, str)
}
}
return out
default:
return nil
}
}
func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) {
cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData)
if err != nil {
return nil, err
}
cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID)
if err != nil {
return nil, err
}
username := processSoraCharacterUsername(cameoStatus.UsernameHint)
displayName := strings.TrimSpace(cameoStatus.DisplayNameHint)
if displayName == "" {
displayName = "Character"
}
profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL)
if profileAssetURL == "" {
return nil, errors.New("profile asset url not found in cameo status")
}
avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL)
if err != nil {
return nil, err
}
assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData)
if err != nil {
return nil, err
}
instructionSet := cameoStatus.InstructionSetHint
if instructionSet == nil {
instructionSet = cameoStatus.InstructionSet
}
characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{
CameoID: strings.TrimSpace(cameoID),
Username: username,
DisplayName: displayName,
ProfileAssetPointer: assetPointer,
InstructionSet: instructionSet,
})
if err != nil {
return nil, err
}
if opts.SetPublic {
if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil {
return nil, err
}
}
return &soraCharacterFlowResult{
CameoID: strings.TrimSpace(cameoID),
CharacterID: strings.TrimSpace(characterID),
Username: strings.TrimSpace(username),
DisplayName: displayName,
}, nil
}
func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
timeout := 10 * time.Minute
interval := 5 * time.Second
maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds()))
if maxAttempts < 1 {
maxAttempts = 1
}
var lastErr error
consecutiveErrors := 0
for attempt := 0; attempt < maxAttempts; attempt++ {
status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID)
if err != nil {
lastErr = err
consecutiveErrors++
if consecutiveErrors >= 3 {
break
}
if attempt < maxAttempts-1 {
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
return nil, sleepErr
}
}
continue
}
consecutiveErrors = 0
if status == nil {
if attempt < maxAttempts-1 {
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
return nil, sleepErr
}
}
continue
}
currentStatus := strings.ToLower(strings.TrimSpace(status.Status))
statusMessage := strings.TrimSpace(status.StatusMessage)
if currentStatus == "failed" {
if statusMessage == "" {
statusMessage = "character creation failed"
}
return nil, errors.New(statusMessage)
}
if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" {
return status, nil
}
if attempt < maxAttempts-1 {
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
return nil, sleepErr
}
}
}
if lastErr != nil {
return nil, fmt.Errorf("poll cameo status failed: %w", lastErr)
}
return nil, errors.New("cameo processing timeout")
}
func processSoraCharacterUsername(usernameHint string) string {
usernameHint = strings.TrimSpace(usernameHint)
if usernameHint == "" {
usernameHint = "character"
}
if strings.Contains(usernameHint, ".") {
parts := strings.Split(usernameHint, ".")
usernameHint = strings.TrimSpace(parts[len(parts)-1])
}
if usernameHint == "" {
usernameHint = "character"
}
return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100)
}
func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
generationID = strings.TrimSpace(generationID)
if generationID == "" {
return "", "", errors.New("generation id is required for watermark-free mode")
}
postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID)
if err != nil {
return "", "", err
}
postID = strings.TrimSpace(postID)
if postID == "" {
return "", "", errors.New("watermark-free publish returned empty post id")
}
switch opts.ParseMethod {
case "custom":
urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID)
if parseErr != nil {
return "", postID, parseErr
}
return strings.TrimSpace(urlVal), postID, nil
case "", "third_party":
return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil
default:
return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod)
}
}
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 402, 403, 429, 529:
case 401, 402, 403, 404, 429, 529:
return true
default:
return statusCode >= 500
......@@ -434,7 +869,18 @@ func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType,
}
if stream {
flusher, _ := c.Writer.(http.Flusher)
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
errorData := map[string]any{
"error": map[string]string{
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
if err != nil {
_ = c.Error(err)
return
}
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
_, _ = fmt.Fprint(c.Writer, errorEvent)
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
if flusher != nil {
......@@ -460,7 +906,15 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
}
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode}
var responseHeaders http.Header
if upstreamErr.Headers != nil {
responseHeaders = upstreamErr.Headers.Clone()
}
return &UpstreamFailoverError{
StatusCode: upstreamErr.StatusCode,
ResponseBody: upstreamErr.Body,
ResponseHeaders: responseHeaders,
}
}
msg := upstreamErr.Message
if override := soraProErrorMessage(model, msg); override != "" {
......@@ -505,7 +959,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context,
return nil, errors.New("sora image generation timeout")
}
func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) {
interval := s.pollInterval()
maxAttempts := s.pollMaxAttempts()
lastPing := time.Now()
......@@ -516,7 +970,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context,
}
switch strings.ToLower(status.Status) {
case "completed", "succeeded":
return status.URLs, nil
return status, nil
case "failed":
if status.ErrorMsg != "" {
return nil, errors.New(status.ErrorMsg)
......@@ -620,7 +1074,7 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
return "", "", "", ""
}
if v, ok := body["remix_target_id"].(string); ok {
remixTargetID = v
remixTargetID = strings.TrimSpace(v)
}
if v, ok := body["image"].(string); ok {
imageInput = v
......@@ -661,6 +1115,10 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
prompt = builder.String()
}
}
if remixTargetID == "" {
remixTargetID = extractRemixTargetIDFromPrompt(prompt)
}
prompt = cleanRemixLinkFromPrompt(prompt)
return prompt, imageInput, videoInput, remixTargetID
}
......@@ -708,6 +1166,69 @@ func parseSoraMessageContent(content any) (text, imageInput, videoInput string)
}
}
func isSoraStoryboardPrompt(prompt string) bool {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return false
}
return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1
}
func formatSoraStoryboardPrompt(prompt string) string {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return ""
}
matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1)
if len(matches) == 0 {
return prompt
}
firstBracketPos := strings.Index(prompt, "[")
instructions := ""
if firstBracketPos > 0 {
instructions = strings.TrimSpace(prompt[:firstBracketPos])
}
shots := make([]string, 0, len(matches))
for i, match := range matches {
if len(match) < 3 {
continue
}
duration := strings.TrimSpace(match[1])
scene := strings.TrimSpace(match[2])
if scene == "" {
continue
}
shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene))
}
if len(shots) == 0 {
return prompt
}
timeline := strings.Join(shots, "\n\n")
if instructions == "" {
return timeline
}
return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions)
}
func extractRemixTargetIDFromPrompt(prompt string) string {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return ""
}
return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt))
}
func cleanRemixLinkFromPrompt(prompt string) string {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return prompt
}
cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "")
cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "")
cleaned = strings.Join(strings.Fields(cleaned), " ")
return strings.TrimSpace(cleaned)
}
func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
raw := strings.TrimSpace(input)
if raw == "" {
......@@ -720,7 +1241,7 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
}
meta := parts[0]
payload := parts[1]
decoded, err := base64.StdEncoding.DecodeString(payload)
decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes)
if err != nil {
return nil, "", err
}
......@@ -739,15 +1260,47 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
return downloadSoraImageInput(ctx, raw)
}
decoded, err := base64.StdEncoding.DecodeString(raw)
decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes)
if err != nil {
return nil, "", errors.New("invalid base64 image")
}
return decoded, "image.png", nil
}
func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) {
raw := strings.TrimSpace(input)
if raw == "" {
return nil, errors.New("empty video input")
}
if strings.HasPrefix(raw, "data:") {
parts := strings.SplitN(raw, ",", 2)
if len(parts) != 2 {
return nil, errors.New("invalid video data url")
}
decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes)
if err != nil {
return nil, errors.New("invalid base64 video")
}
if len(decoded) == 0 {
return nil, errors.New("empty video data")
}
return decoded, nil
}
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
return downloadSoraVideoInput(ctx, raw)
}
decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes)
if err != nil {
return nil, errors.New("invalid base64 video")
}
if len(decoded) == 0 {
return nil, errors.New("empty video data")
}
return decoded, nil
}
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
parsed, err := validateSoraImageURL(rawURL)
parsed, err := validateSoraRemoteURL(rawURL)
if err != nil {
return nil, "", err
}
......@@ -761,7 +1314,7 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
if len(via) >= soraImageInputMaxRedirects {
return errors.New("too many redirects")
}
return validateSoraImageURLValue(req.URL)
return validateSoraRemoteURLValue(req.URL)
},
}
resp, err := client.Do(req)
......@@ -784,51 +1337,103 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
return data, filename, nil
}
func validateSoraImageURL(raw string) (*url.URL, error) {
func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) {
parsed, err := validateSoraRemoteURL(rawURL)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
if err != nil {
return nil, err
}
client := &http.Client{
Timeout: soraVideoInputTimeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= soraVideoInputMaxRedirects {
return errors.New("too many redirects")
}
return validateSoraRemoteURLValue(req.URL)
},
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("download video failed: %d", resp.StatusCode)
}
data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes))
if err != nil {
return nil, err
}
if len(data) == 0 {
return nil, errors.New("empty video content")
}
return data, nil
}
func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) {
if maxBytes <= 0 {
return nil, errors.New("invalid max bytes limit")
}
decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
limited := io.LimitReader(decoder, maxBytes+1)
data, err := io.ReadAll(limited)
if err != nil {
return nil, err
}
if int64(len(data)) > maxBytes {
return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes)
}
return data, nil
}
func validateSoraRemoteURL(raw string) (*url.URL, error) {
if strings.TrimSpace(raw) == "" {
return nil, errors.New("empty image url")
return nil, errors.New("empty remote url")
}
parsed, err := url.Parse(raw)
if err != nil {
return nil, fmt.Errorf("invalid image url: %w", err)
return nil, fmt.Errorf("invalid remote url: %w", err)
}
if err := validateSoraImageURLValue(parsed); err != nil {
if err := validateSoraRemoteURLValue(parsed); err != nil {
return nil, err
}
return parsed, nil
}
func validateSoraImageURLValue(parsed *url.URL) error {
func validateSoraRemoteURLValue(parsed *url.URL) error {
if parsed == nil {
return errors.New("invalid image url")
return errors.New("invalid remote url")
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
if scheme != "http" && scheme != "https" {
return errors.New("only http/https image url is allowed")
return errors.New("only http/https remote url is allowed")
}
if parsed.User != nil {
return errors.New("image url cannot contain userinfo")
return errors.New("remote url cannot contain userinfo")
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host == "" {
return errors.New("image url missing host")
return errors.New("remote url missing host")
}
if _, blocked := soraBlockedHostnames[host]; blocked {
return errors.New("image url is not allowed")
return errors.New("remote url is not allowed")
}
if ip := net.ParseIP(host); ip != nil {
if isSoraBlockedIP(ip) {
return errors.New("image url is not allowed")
return errors.New("remote url is not allowed")
}
return nil
}
ips, err := net.LookupIP(host)
if err != nil {
return fmt.Errorf("resolve image url failed: %w", err)
return fmt.Errorf("resolve remote url failed: %w", err)
}
for _, ip := range ips {
if isSoraBlockedIP(ip) {
return errors.New("image url is not allowed")
return errors.New("remote url is not allowed")
}
}
return nil
......
......@@ -4,10 +4,16 @@ package service
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
......@@ -18,6 +24,13 @@ type stubSoraClientForPoll struct {
videoStatus *SoraVideoTaskStatus
imageCalls int
videoCalls int
enhanced string
enhanceErr error
storyboard bool
videoReq SoraVideoRequest
parseErr error
postCalls int
deleteCalls int
}
func (s *stubSoraClientForPoll) Enabled() bool { return true }
......@@ -28,8 +41,60 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac
return "task-image", nil
}
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
s.videoReq = req
return "task-video", nil
}
func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
s.storyboard = true
return "task-video", nil
}
func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
return "cameo-1", nil
}
func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
return &SoraCameoStatus{
Status: "finalized",
StatusMessage: "Completed",
DisplayNameHint: "Character",
UsernameHint: "user.character",
ProfileAssetURL: "https://example.com/avatar.webp",
}, nil
}
func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
return []byte("avatar"), nil
}
func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
return "asset-pointer", nil
}
func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
return "character-1", nil
}
func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
return nil
}
func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
return nil
}
func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
s.postCalls++
return "s_post", nil
}
func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error {
s.deleteCalls++
return nil
}
func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
if s.parseErr != nil {
return "", s.parseErr
}
return "https://example.com/no-watermark.mp4", nil
}
func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
if s.enhanced != "" {
return s.enhanced, s.enhanceErr
}
return "enhanced prompt", s.enhanceErr
}
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
s.imageCalls++
return s.imageStatus, nil
......@@ -62,6 +127,136 @@ func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
require.Equal(t, 1, client.imageCalls)
}
func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
client := &stubSoraClientForPoll{
enhanced: "cinematic prompt",
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{
ID: 1,
Platform: PlatformSora,
Status: StatusActive,
}
body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "prompt", result.MediaType)
require.Equal(t, "prompt-enhance-short-10s", result.Model)
}
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/v.mp4"},
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, client.storyboard)
}
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
client := &stubSoraClientForPoll{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "prompt", result.MediaType)
require.Equal(t, 0, client.videoCalls)
}
func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/original.mp4"},
GenerationID: "gen_1",
},
parseErr: errors.New("parse failed"),
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "https://example.com/original.mp4", result.MediaURL)
require.Equal(t, 1, client.postCalls)
require.Equal(t, 0, client.deleteCalls)
}
func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/original.mp4"},
GenerationID: "gen_1",
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL)
require.Equal(t, 1, client.postCalls)
require.Equal(t, 1, client.deleteCalls)
}
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
......@@ -79,9 +274,9 @@ func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
}
service := NewSoraGatewayService(client, nil, nil, cfg)
urls, err := service.pollVideoTask(context.Background(), nil, &Account{ID: 1}, "task", false)
status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false)
require.Error(t, err)
require.Empty(t, urls)
require.Nil(t, status)
require.Contains(t, err.Error(), "reject")
require.Equal(t, 1, client.videoCalls)
}
......@@ -175,9 +370,65 @@ func TestSoraProErrorMessage(t *testing.T) {
require.Empty(t, soraProErrorMessage("sora-basic", ""))
}
func TestSoraGatewayService_WriteSoraError_StreamEscapesJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
svc.writeSoraError(c, http.StatusBadGateway, "upstream_error", "invalid \"prompt\"\nline2", true)
body := rec.Body.String()
require.Contains(t, body, "event: error\n")
require.Contains(t, body, "data: [DONE]\n\n")
lines := strings.Split(body, "\n")
require.GreaterOrEqual(t, len(lines), 2)
require.Equal(t, "event: error", lines[0])
require.True(t, strings.HasPrefix(lines[1], "data: "))
data := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(data), &parsed))
errObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "upstream_error", errObj["type"])
require.Equal(t, "invalid \"prompt\"\nline2", errObj["message"])
}
func TestSoraGatewayService_HandleSoraRequestError_FailoverHeadersCloned(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
sourceHeaders := http.Header{}
sourceHeaders.Set("cf-ray", "9d01b0e9ecc35829-SEA")
err := svc.handleSoraRequestError(
context.Background(),
&Account{ID: 1, Platform: PlatformSora},
&SoraUpstreamError{
StatusCode: http.StatusForbidden,
Message: "forbidden",
Headers: sourceHeaders,
Body: []byte(`<!DOCTYPE html><title>Just a moment...</title>`),
},
"sora2-landscape-10s",
nil,
false,
)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.NotNil(t, failoverErr.ResponseHeaders)
require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
sourceHeaders.Set("cf-ray", "mutated-after-return")
require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
}
func TestShouldFailoverUpstreamError(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
require.True(t, svc.shouldFailoverUpstreamError(401))
require.True(t, svc.shouldFailoverUpstreamError(404))
require.True(t, svc.shouldFailoverUpstreamError(429))
require.True(t, svc.shouldFailoverUpstreamError(500))
require.True(t, svc.shouldFailoverUpstreamError(502))
......@@ -257,3 +508,19 @@ func TestDecodeSoraImageInput_DataURL(t *testing.T) {
require.NotEmpty(t, data)
require.Contains(t, filename, ".png")
}
func TestDecodeBase64WithLimit_ExceedLimit(t *testing.T) {
data, err := decodeBase64WithLimit("aGVsbG8=", 3)
require.Error(t, err)
require.Nil(t, data)
}
func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
body := map[string]any{
"watermark_free": float64(1),
"watermark_fallback_on_failure": float64(0),
}
opts := parseSoraWatermarkOptions(body)
require.True(t, opts.Enabled)
require.False(t, opts.FallbackOnFailure)
}
......@@ -17,6 +17,9 @@ type SoraModelConfig struct {
Model string
Size string
RequirePro bool
// Prompt-enhance 专用参数
ExpansionLevel string
DurationS int
}
var soraModelConfigs = map[string]SoraModelConfig{
......@@ -160,31 +163,49 @@ var soraModelConfigs = map[string]SoraModelConfig{
RequirePro: true,
},
"prompt-enhance-short-10s": {
Type: "prompt_enhance",
Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 10,
},
"prompt-enhance-short-15s": {
Type: "prompt_enhance",
Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 15,
},
"prompt-enhance-short-20s": {
Type: "prompt_enhance",
Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 20,
},
"prompt-enhance-medium-10s": {
Type: "prompt_enhance",
Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 10,
},
"prompt-enhance-medium-15s": {
Type: "prompt_enhance",
Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 15,
},
"prompt-enhance-medium-20s": {
Type: "prompt_enhance",
Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 20,
},
"prompt-enhance-long-10s": {
Type: "prompt_enhance",
Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 10,
},
"prompt-enhance-long-15s": {
Type: "prompt_enhance",
Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 15,
},
"prompt-enhance-long-20s": {
Type: "prompt_enhance",
Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 20,
},
}
......
package service
import (
"fmt"
"math"
"net/http"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/google/uuid"
)
type soraChallengeCooldownEntry struct {
Until time.Time
StatusCode int
CFRay string
ConsecutiveChallenges int
LastChallengeAt time.Time
}
type soraSidecarSessionEntry struct {
SessionKey string
ExpiresAt time.Time
LastUsedAt time.Time
}
func (c *SoraDirectClient) cloudflareChallengeCooldownSeconds() int {
if c == nil || c.cfg == nil {
return 900
}
cooldown := c.cfg.Sora.Client.CloudflareChallengeCooldownSeconds
if cooldown <= 0 {
return 0
}
return cooldown
}
func (c *SoraDirectClient) checkCloudflareChallengeCooldown(account *Account, proxyURL string) error {
if c == nil {
return nil
}
if account == nil || account.ID <= 0 {
return nil
}
cooldownSeconds := c.cloudflareChallengeCooldownSeconds()
if cooldownSeconds <= 0 {
return nil
}
key := soraAccountProxyKey(account, proxyURL)
now := time.Now()
c.challengeCooldownMu.RLock()
entry, ok := c.challengeCooldowns[key]
c.challengeCooldownMu.RUnlock()
if !ok {
return nil
}
if !entry.Until.After(now) {
c.challengeCooldownMu.Lock()
delete(c.challengeCooldowns, key)
c.challengeCooldownMu.Unlock()
return nil
}
remaining := int(math.Ceil(entry.Until.Sub(now).Seconds()))
if remaining < 1 {
remaining = 1
}
message := fmt.Sprintf("Sora request cooling down due to recent Cloudflare challenge. Retry in %d seconds.", remaining)
if entry.ConsecutiveChallenges > 1 {
message = fmt.Sprintf("%s (streak=%d)", message, entry.ConsecutiveChallenges)
}
if entry.CFRay != "" {
message = fmt.Sprintf("%s (last cf-ray: %s)", message, entry.CFRay)
}
return &SoraUpstreamError{
StatusCode: http.StatusTooManyRequests,
Message: message,
Headers: make(http.Header),
}
}
func (c *SoraDirectClient) recordCloudflareChallengeCooldown(account *Account, proxyURL string, statusCode int, headers http.Header, body []byte) {
if c == nil {
return
}
if account == nil || account.ID <= 0 {
return
}
cooldownSeconds := c.cloudflareChallengeCooldownSeconds()
if cooldownSeconds <= 0 {
return
}
key := soraAccountProxyKey(account, proxyURL)
now := time.Now()
cfRay := soraerror.ExtractCloudflareRayID(headers, body)
c.challengeCooldownMu.Lock()
c.cleanupExpiredChallengeCooldownsLocked(now)
streak := 1
existing, ok := c.challengeCooldowns[key]
if ok && now.Sub(existing.LastChallengeAt) <= 30*time.Minute {
streak = existing.ConsecutiveChallenges + 1
}
effectiveCooldown := soraComputeChallengeCooldownSeconds(cooldownSeconds, streak)
until := now.Add(time.Duration(effectiveCooldown) * time.Second)
if ok && existing.Until.After(until) {
until = existing.Until
if existing.ConsecutiveChallenges > streak {
streak = existing.ConsecutiveChallenges
}
if cfRay == "" {
cfRay = existing.CFRay
}
}
c.challengeCooldowns[key] = soraChallengeCooldownEntry{
Until: until,
StatusCode: statusCode,
CFRay: cfRay,
ConsecutiveChallenges: streak,
LastChallengeAt: now,
}
c.challengeCooldownMu.Unlock()
if c.debugEnabled() {
remain := int(math.Ceil(until.Sub(now).Seconds()))
if remain < 0 {
remain = 0
}
c.debugLogf("cloudflare_challenge_cooldown_set key=%s status=%d remain_s=%d streak=%d cf_ray=%s", key, statusCode, remain, streak, cfRay)
}
}
func soraComputeChallengeCooldownSeconds(baseSeconds, streak int) int {
if baseSeconds <= 0 {
return 0
}
if streak < 1 {
streak = 1
}
multiplier := streak
if multiplier > 4 {
multiplier = 4
}
cooldown := baseSeconds * multiplier
if cooldown > 3600 {
cooldown = 3600
}
return cooldown
}
func (c *SoraDirectClient) clearCloudflareChallengeCooldown(account *Account, proxyURL string) {
if c == nil {
return
}
if account == nil || account.ID <= 0 {
return
}
key := soraAccountProxyKey(account, proxyURL)
c.challengeCooldownMu.Lock()
_, existed := c.challengeCooldowns[key]
if existed {
delete(c.challengeCooldowns, key)
}
c.challengeCooldownMu.Unlock()
if existed && c.debugEnabled() {
c.debugLogf("cloudflare_challenge_cooldown_cleared key=%s", key)
}
}
func (c *SoraDirectClient) sidecarSessionKey(account *Account, proxyURL string) string {
if c == nil || !c.sidecarSessionReuseEnabled() {
return ""
}
if account == nil || account.ID <= 0 {
return ""
}
key := soraAccountProxyKey(account, proxyURL)
now := time.Now()
ttlSeconds := c.sidecarSessionTTLSeconds()
c.sidecarSessionMu.Lock()
defer c.sidecarSessionMu.Unlock()
c.cleanupExpiredSidecarSessionsLocked(now)
if existing, exists := c.sidecarSessions[key]; exists {
existing.LastUsedAt = now
c.sidecarSessions[key] = existing
return existing.SessionKey
}
expiresAt := now.Add(time.Duration(ttlSeconds) * time.Second)
if ttlSeconds <= 0 {
expiresAt = now.Add(365 * 24 * time.Hour)
}
newEntry := soraSidecarSessionEntry{
SessionKey: "sora-" + uuid.NewString(),
ExpiresAt: expiresAt,
LastUsedAt: now,
}
c.sidecarSessions[key] = newEntry
if c.debugEnabled() {
c.debugLogf("sidecar_session_created key=%s ttl_s=%d", key, ttlSeconds)
}
return newEntry.SessionKey
}
func (c *SoraDirectClient) cleanupExpiredChallengeCooldownsLocked(now time.Time) {
if c == nil || len(c.challengeCooldowns) == 0 {
return
}
for key, entry := range c.challengeCooldowns {
if !entry.Until.After(now) {
delete(c.challengeCooldowns, key)
}
}
}
func (c *SoraDirectClient) cleanupExpiredSidecarSessionsLocked(now time.Time) {
if c == nil || len(c.sidecarSessions) == 0 {
return
}
for key, entry := range c.sidecarSessions {
if !entry.ExpiresAt.After(now) {
delete(c.sidecarSessions, key)
}
}
}
func soraAccountProxyKey(account *Account, proxyURL string) string {
accountID := int64(0)
if account != nil {
accountID = account.ID
}
return fmt.Sprintf("account:%d|proxy:%s", accountID, normalizeSoraProxyKey(proxyURL))
}
func normalizeSoraProxyKey(proxyURL string) string {
raw := strings.TrimSpace(proxyURL)
if raw == "" {
return "direct"
}
parsed, err := url.Parse(raw)
if err != nil {
return strings.ToLower(raw)
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
port := strings.TrimSpace(parsed.Port())
if host == "" {
return strings.ToLower(raw)
}
if (scheme == "http" && port == "80") || (scheme == "https" && port == "443") {
port = ""
}
if port != "" {
host = host + ":" + port
}
if scheme == "" {
scheme = "proxy"
}
return scheme + "://" + host
}
......@@ -43,10 +43,13 @@ func NewTokenRefreshService(
stopCh: make(chan struct{}),
}
openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
// 注册平台特定的刷新器
s.refreshers = []TokenRefresher{
NewClaudeTokenRefresher(oauthService),
NewOpenAITokenRefresher(openaiOAuthService, accountRepo),
openAIRefresher,
NewGeminiTokenRefresher(geminiOAuthService),
NewAntigravityTokenRefresher(antigravityOAuthService),
}
......
......@@ -86,6 +86,7 @@ type OpenAITokenRefresher struct {
openaiOAuthService *OpenAIOAuthService
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
syncLinkedSora bool
}
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
......@@ -103,11 +104,15 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
r.soraAccountRepo = repo
}
// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。
func (r *OpenAITokenRefresher) SetSyncLinkedSoraAccounts(enabled bool) {
r.syncLinkedSora = enabled
}
// CanRefresh 检查是否能处理此账号
// 只处理 openai 平台的 oauth 类型账号
// 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号)
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) &&
account.Type == AccountTypeOAuth
return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
}
// NeedsRefresh 检查token是否需要刷新
......@@ -141,7 +146,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
}
// 异步同步关联的 Sora 账号(不阻塞主流程)
if r.accountRepo != nil {
if r.accountRepo != nil && r.syncLinkedSora {
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
}
......
......@@ -226,3 +226,43 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
})
}
}
func TestOpenAITokenRefresher_CanRefresh(t *testing.T) {
refresher := &OpenAITokenRefresher{}
tests := []struct {
name string
platform string
accType string
want bool
}{
{
name: "openai oauth - can refresh",
platform: PlatformOpenAI,
accType: AccountTypeOAuth,
want: true,
},
{
name: "sora oauth - cannot refresh directly",
platform: PlatformSora,
accType: AccountTypeOAuth,
want: false,
},
{
name: "openai apikey - cannot refresh",
platform: PlatformOpenAI,
accType: AccountTypeAPIKey,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Platform: tt.platform,
Type: tt.accType,
}
require.Equal(t, tt.want, refresher.CanRefresh(account))
})
}
}
......@@ -26,8 +26,8 @@ type UsageLog struct {
CacheCreationTokens int
CacheReadTokens int
CacheCreation5mTokens int
CacheCreation1hTokens int
CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"`
InputCost float64
OutputCost float64
......@@ -46,6 +46,9 @@ type UsageLog struct {
UserAgent *string
IPAddress *string
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
CacheTTLOverridden bool
// 图片生成字段
ImageCount int
ImageSize *string
......
......@@ -206,6 +206,18 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
return NewSoraMediaStorage(cfg)
}
func ProvideSoraDirectClient(
cfg *config.Config,
httpUpstream HTTPUpstream,
tokenProvider *OpenAITokenProvider,
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository,
) *SoraDirectClient {
client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider)
client.SetAccountRepositories(accountRepo, soraAccountRepo)
return client
}
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
svc := NewSoraMediaCleanupService(storage, cfg)
......@@ -255,7 +267,7 @@ var ProviderSet = wire.NewSet(
NewGatewayService,
ProvideSoraMediaStorage,
ProvideSoraMediaCleanupService,
NewSoraDirectClient,
ProvideSoraDirectClient,
wire.Bind(new(SoraClient), new(*SoraDirectClient)),
NewSoraGatewayService,
NewOpenAIGatewayService,
......
package soraerror
import (
"encoding/json"
"fmt"
"net/http"
"regexp"
"strings"
)
var (
cfRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
cRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
htmlChallenge = []string{
"window._cf_chl_opt",
"just a moment",
"enable javascript and cookies to continue",
"__cf_chl_",
"challenge-platform",
}
)
// IsCloudflareChallengeResponse reports whether the upstream response matches Cloudflare challenge behavior.
func IsCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
return false
}
if headers != nil && strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
return true
}
preview := strings.ToLower(TruncateBody(body, 4096))
for _, marker := range htmlChallenge {
if strings.Contains(preview, marker) {
return true
}
}
contentType := ""
if headers != nil {
contentType = strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
}
if strings.Contains(contentType, "text/html") &&
(strings.Contains(preview, "<html") || strings.Contains(preview, "<!doctype html")) &&
(strings.Contains(preview, "cloudflare") || strings.Contains(preview, "challenge")) {
return true
}
return false
}
// ExtractCloudflareRayID extracts cf-ray from headers or response body.
func ExtractCloudflareRayID(headers http.Header, body []byte) string {
if headers != nil {
rayID := strings.TrimSpace(headers.Get("cf-ray"))
if rayID != "" {
return rayID
}
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
if rayID != "" {
return rayID
}
}
preview := TruncateBody(body, 8192)
if matches := cfRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
if matches := cRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
return ""
}
// FormatCloudflareChallengeMessage appends cf-ray info when available.
func FormatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
rayID := ExtractCloudflareRayID(headers, body)
if rayID == "" {
return base
}
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
}
// ExtractUpstreamErrorCodeAndMessage extracts structured error code/message from common JSON layouts.
func ExtractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
trimmed := strings.TrimSpace(string(body))
if trimmed == "" {
return "", ""
}
if !json.Valid([]byte(trimmed)) {
return "", truncateMessage(trimmed, 256)
}
var payload map[string]any
if err := json.Unmarshal([]byte(trimmed), &payload); err != nil {
return "", truncateMessage(trimmed, 256)
}
code := firstNonEmpty(
extractNestedString(payload, "error", "code"),
extractRootString(payload, "code"),
)
message := firstNonEmpty(
extractNestedString(payload, "error", "message"),
extractRootString(payload, "message"),
extractNestedString(payload, "error", "detail"),
extractRootString(payload, "detail"),
)
return strings.TrimSpace(code), truncateMessage(strings.TrimSpace(message), 512)
}
// TruncateBody truncates body text for logging/inspection.
func TruncateBody(body []byte, max int) string {
if max <= 0 {
max = 512
}
raw := strings.TrimSpace(string(body))
if len(raw) <= max {
return raw
}
return raw[:max] + "...(truncated)"
}
func truncateMessage(s string, max int) string {
if max <= 0 {
return ""
}
if len(s) <= max {
return s
}
return s[:max] + "...(truncated)"
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
if strings.TrimSpace(v) != "" {
return v
}
}
return ""
}
func extractRootString(m map[string]any, key string) string {
if m == nil {
return ""
}
v, ok := m[key]
if !ok {
return ""
}
s, _ := v.(string)
return s
}
func extractNestedString(m map[string]any, parent, key string) string {
if m == nil {
return ""
}
node, ok := m[parent]
if !ok {
return ""
}
child, ok := node.(map[string]any)
if !ok {
return ""
}
s, _ := child[key].(string)
return s
}
package soraerror
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestIsCloudflareChallengeResponse(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-mitigated", "challenge")
require.True(t, IsCloudflareChallengeResponse(http.StatusForbidden, headers, []byte(`{"ok":false}`)))
require.True(t, IsCloudflareChallengeResponse(http.StatusTooManyRequests, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>`)))
require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title>`)))
}
func TestExtractCloudflareRayID(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil))
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body))
}
func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) {
code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`))
require.Equal(t, "cf_shield_429", code)
require.Equal(t, "rate limited", msg)
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`))
require.Equal(t, "unsupported_country_code", code)
require.Equal(t, "not available", msg)
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`))
require.Equal(t, "", code)
require.Equal(t, "plain text", msg)
}
func TestFormatCloudflareChallengeMessage(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
msg := FormatCloudflareChallengeMessage("blocked", headers, nil)
require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg)
}
......@@ -86,6 +86,7 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc {
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/setup/") ||
path == "/health" ||
......@@ -209,6 +210,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/setup/") ||
path == "/health" ||
......
......@@ -362,6 +362,7 @@ func TestFrontendServer_Middleware(t *testing.T) {
"/api/v1/users",
"/v1/models",
"/v1beta/chat",
"/sora/v1/models",
"/antigravity/test",
"/setup/init",
"/health",
......@@ -537,6 +538,7 @@ func TestServeEmbeddedFrontend(t *testing.T) {
"/api/users",
"/v1/models",
"/v1beta/chat",
"/sora/v1/models",
"/antigravity/test",
"/setup/init",
"/health",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment