Commit 40498aac authored by yangjianbo's avatar yangjianbo
Browse files

feat(sora): 对齐sora2api分镜角色去水印与挑战错误治理

parent 440b8709
......@@ -184,7 +184,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig)
soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider)
soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)
......
......@@ -12,7 +12,6 @@ import (
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
......@@ -22,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
......@@ -29,8 +29,6 @@ import (
"go.uber.org/zap"
)
var soraCloudflareRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
// SoraGatewayHandler handles Sora chat completions requests
type SoraGatewayHandler struct {
gatewayService *service.GatewayService
......@@ -385,6 +383,10 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders ht
}
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
if strings.EqualFold(upstreamCode, "cf_shield_429") {
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
}
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
switch statusCode {
case 401, 403, 404, 500, 502, 503, 504:
......@@ -416,27 +418,7 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders ht
}
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
return false
}
if strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
return true
}
preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
if strings.Contains(preview, "window._cf_chl_opt") ||
strings.Contains(preview, "just a moment") ||
strings.Contains(preview, "enable javascript and cookies to continue") ||
strings.Contains(preview, "__cf_chl_") ||
strings.Contains(preview, "challenge-platform") {
return true
}
contentType := strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
if strings.Contains(contentType, "text/html") &&
(strings.Contains(preview, "<html") || strings.Contains(preview, "<!doctype html")) &&
(strings.Contains(preview, "cloudflare") || strings.Contains(preview, "challenge")) {
return true
}
return false
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
......@@ -454,76 +436,11 @@ func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
}
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
rayID := extractSoraCloudflareRayID(headers, body)
if rayID == "" {
return base
}
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
}
func extractSoraCloudflareRayID(headers http.Header, body []byte) string {
if headers != nil {
rayID := strings.TrimSpace(headers.Get("cf-ray"))
if rayID != "" {
return rayID
}
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
if rayID != "" {
return rayID
}
}
preview := truncateSoraErrorBody(body, 8192)
matches := soraCloudflareRayPattern.FindStringSubmatch(preview)
if len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
return ""
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
trimmed := strings.TrimSpace(string(body))
if trimmed == "" {
return "", ""
}
if !gjson.Valid(trimmed) {
return "", truncateSoraErrorMessage(trimmed, 256)
}
code := strings.TrimSpace(gjson.Get(trimmed, "error.code").String())
if code == "" {
code = strings.TrimSpace(gjson.Get(trimmed, "code").String())
}
message := strings.TrimSpace(gjson.Get(trimmed, "error.message").String())
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "message").String())
}
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "error.detail").String())
}
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "detail").String())
}
return code, truncateSoraErrorMessage(message, 512)
}
func truncateSoraErrorMessage(s string, maxLen int) string {
if maxLen <= 0 {
return ""
}
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "...(truncated)"
}
func truncateSoraErrorBody(body []byte, maxLen int) string {
if maxLen <= 0 {
maxLen = 512
}
raw := strings.TrimSpace(string(body))
if len(raw) <= maxLen {
return raw
}
return raw[:maxLen] + "...(truncated)"
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
}
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
......
......@@ -43,6 +43,45 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "cameo-1", nil
}
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
return &service.SoraCameoStatus{
Status: "finalized",
StatusMessage: "Completed",
DisplayNameHint: "Character",
UsernameHint: "user.character",
ProfileAssetURL: "https://example.com/avatar.webp",
}, nil
}
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
return []byte("avatar"), nil
}
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "asset-pointer", nil
}
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
return "character-1", nil
}
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
return nil
}
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
return nil
}
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
return "s_post", nil
}
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
return nil
}
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
return "https://example.com/no-watermark.mp4", nil
}
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
return "enhanced prompt", nil
}
......@@ -607,3 +646,31 @@ func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T
require.Contains(t, msg, "Cloudflare challenge")
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
}
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
headers := http.Header{}
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
h := &SoraGatewayHandler{}
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "rate_limit_error", errorObj["type"])
msg, _ := errorObj["message"].(string)
require.Contains(t, msg, "Cloudflare shield")
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
}
......@@ -15,11 +15,14 @@ import (
"net/url"
"regexp"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
......@@ -28,7 +31,6 @@ import (
// sseDataPrefix matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
var cloudflareRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
......@@ -45,6 +47,9 @@ type TestEvent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Status string `json:"status,omitempty"`
Code string `json:"code,omitempty"`
Data any `json:"data,omitempty"`
Success bool `json:"success,omitempty"`
Error string `json:"error,omitempty"`
}
......@@ -56,8 +61,13 @@ type AccountTestService struct {
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
cfg *config.Config
soraTestGuardMu sync.Mutex
soraTestLastRun map[int64]time.Time
soraTestCooldown time.Duration
}
const defaultSoraTestCooldown = 10 * time.Second
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(
accountRepo AccountRepository,
......@@ -72,6 +82,8 @@ func NewAccountTestService(
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
cfg: cfg,
soraTestLastRun: make(map[int64]time.Time),
soraTestCooldown: defaultSoraTestCooldown,
}
}
......@@ -473,13 +485,129 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
return s.processGeminiStream(c, resp.Body)
}
type soraProbeStep struct {
Name string `json:"name"`
Status string `json:"status"`
HTTPStatus int `json:"http_status,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
Message string `json:"message,omitempty"`
}
type soraProbeSummary struct {
Status string `json:"status"`
Steps []soraProbeStep `json:"steps"`
}
type soraProbeRecorder struct {
steps []soraProbeStep
}
func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
r.steps = append(r.steps, soraProbeStep{
Name: name,
Status: status,
HTTPStatus: httpStatus,
ErrorCode: strings.TrimSpace(errorCode),
Message: strings.TrimSpace(message),
})
}
func (r *soraProbeRecorder) finalize() soraProbeSummary {
meSuccess := false
partial := false
for _, step := range r.steps {
if step.Name == "me" {
meSuccess = strings.EqualFold(step.Status, "success")
continue
}
if strings.EqualFold(step.Status, "failed") {
partial = true
}
}
status := "success"
if !meSuccess {
status = "failed"
} else if partial {
status = "partial_success"
}
return soraProbeSummary{
Status: status,
Steps: append([]soraProbeStep(nil), r.steps...),
}
}
func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
if rec == nil {
return
}
summary := rec.finalize()
code := ""
for _, step := range summary.Steps {
if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
code = step.ErrorCode
break
}
}
s.sendEvent(c, TestEvent{
Type: "sora_test_result",
Status: summary.Status,
Code: code,
Data: summary,
})
}
func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
if accountID <= 0 {
return 0, true
}
s.soraTestGuardMu.Lock()
defer s.soraTestGuardMu.Unlock()
if s.soraTestLastRun == nil {
s.soraTestLastRun = make(map[int64]time.Time)
}
cooldown := s.soraTestCooldown
if cooldown <= 0 {
cooldown = defaultSoraTestCooldown
}
now := time.Now()
if lastRun, ok := s.soraTestLastRun[accountID]; ok {
elapsed := now.Sub(lastRun)
if elapsed < cooldown {
return cooldown - elapsed, false
}
}
s.soraTestLastRun[accountID] = now
return 0, true
}
func ceilSeconds(d time.Duration) int {
if d <= 0 {
return 1
}
sec := int(d / time.Second)
if d%time.Second != 0 {
sec++
}
if sec < 1 {
sec = 1
}
return sec
}
// testSoraAccountConnection 测试 Sora 账号的连接
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token)
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
ctx := c.Request.Context()
recorder := &soraProbeRecorder{}
authToken := account.GetCredential("access_token")
if authToken == "" {
recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "No access token available")
}
......@@ -490,11 +618,20 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, msg)
}
// Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
if err != nil {
recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Failed to create request")
}
......@@ -515,6 +652,8 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
if err != nil {
recorder.addStep("me", "failed", 0, "network_error", err.Error())
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
......@@ -522,12 +661,33 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
if isCloudflareChallengeResponse(resp.StatusCode, body) {
if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
s.emitSoraProbeSummary(c, recorder)
s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage("Sora request blocked by Cloudflare challenge (HTTP 403). Please switch to a clean proxy/network and retry.", resp.Header, body))
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
switch {
case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
case strings.EqualFold(upstreamCode, "unsupported_country_code"):
recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
case strings.TrimSpace(upstreamMessage) != "":
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
default:
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
}
recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
// 解析 /me 响应,提取用户信息
var meResp map[string]any
......@@ -557,21 +717,26 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
if subErr != nil {
recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
} else {
subBody, _ := io.ReadAll(subResp.Body)
_ = subResp.Body.Close()
if subResp.StatusCode == http.StatusOK {
recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
}
} else {
if isCloudflareChallengeResponse(subResp.StatusCode, subBody) {
if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Subscription check blocked by Cloudflare challenge (HTTP 403)", subResp.Header, subBody)})
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
} else {
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
}
}
......@@ -579,8 +744,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
}
// 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint)
s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint, recorder)
s.emitSoraProbeSummary(c, recorder)
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
......@@ -592,6 +758,7 @@ func (s *AccountTestService) testSora2Capabilities(
authToken string,
proxyURL string,
enableTLSFingerprint bool,
recorder *soraProbeRecorder,
) {
inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
ctx,
......@@ -602,6 +769,9 @@ func (s *AccountTestService) testSora2Capabilities(
enableTLSFingerprint,
)
if err != nil {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
return
}
......@@ -616,6 +786,9 @@ func (s *AccountTestService) testSora2Capabilities(
enableTLSFingerprint,
)
if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
if recorder != nil {
recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
}
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
ctx,
......@@ -626,20 +799,42 @@ func (s *AccountTestService) testSora2Capabilities(
enableTLSFingerprint,
)
if err != nil {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
return
}
} else if recorder != nil {
code := ""
msg := ""
if bootstrapErr != nil {
code = "network_error"
msg = bootstrapErr.Error()
}
recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
}
}
if inviteStatus != http.StatusOK {
if isCloudflareChallengeResponse(inviteStatus, inviteBody) {
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Sora2 invite check blocked by Cloudflare challenge (HTTP 403)", inviteHeader, inviteBody)})
if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
}
s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
return
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
if recorder != nil {
recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
return
}
if recorder != nil {
recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
}
if summary := parseSoraInviteSummary(inviteBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
......@@ -656,17 +851,31 @@ func (s *AccountTestService) testSora2Capabilities(
enableTLSFingerprint,
)
if remainingErr != nil {
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
return
}
if remainingStatus != http.StatusOK {
if isCloudflareChallengeResponse(remainingStatus, remainingBody) {
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Sora2 remaining check blocked by Cloudflare challenge (HTTP 403)", remainingHeader, remainingBody)})
if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
}
s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
return
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
return
}
if recorder != nil {
recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
}
if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
......@@ -789,42 +998,16 @@ func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool {
return !s.cfg.Sora.Client.DisableTLSFingerprint
}
func isCloudflareChallengeResponse(statusCode int, body []byte) bool {
if statusCode != http.StatusForbidden {
return false
}
preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
return strings.Contains(preview, "window._cf_chl_opt") ||
strings.Contains(preview, "just a moment") ||
strings.Contains(preview, "enable javascript and cookies to continue")
func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
rayID := extractCloudflareRayID(headers, body)
if rayID == "" {
return base
}
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractCloudflareRayID(headers http.Header, body []byte) string {
if headers != nil {
rayID := strings.TrimSpace(headers.Get("cf-ray"))
if rayID != "" {
return rayID
}
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
if rayID != "" {
return rayID
}
}
preview := truncateSoraErrorBody(body, 8192)
matches := cloudflareRayPattern.FindStringSubmatch(preview)
if len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
return ""
return soraerror.ExtractCloudflareRayID(headers, body)
}
func extractSoraEgressIPHint(headers http.Header) string {
......@@ -897,14 +1080,7 @@ func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyU
}
func truncateSoraErrorBody(body []byte, max int) string {
if max <= 0 {
max = 512
}
raw := strings.TrimSpace(string(body))
if len(raw) <= max {
return raw
}
return raw[:max] + "...(truncated)"
return soraerror.TruncateBody(body, max)
}
// testAntigravityAccountConnection tests an Antigravity account's connection
......
......@@ -7,6 +7,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
......@@ -109,6 +110,8 @@ func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testin
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
......@@ -141,6 +144,8 @@ func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuc
require.Contains(t, body, "Sora connection OK - User: demo-user")
require.Contains(t, body, "Subscription check returned 403")
require.Contains(t, body, "Sora2 invite check returned 401")
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"partial_success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
......@@ -173,6 +178,97 @@ func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *tes
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
}
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponseWithHeader(http.StatusTooManyRequests, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body></body></html>`, "cf-mitigated", "challenge"),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "Cloudflare challenge")
require.Contains(t, err.Error(), "HTTP 429")
body := rec.Body.String()
require.Contains(t, body, "Cloudflare challenge")
}
func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "token_invalidated")
body := rec.Body.String()
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"failed"`)
require.Contains(t, body, "token_invalidated")
require.NotContains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
},
}
svc := &AccountTestService{
httpUpstream: upstream,
soraTestCooldown: time.Hour,
}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c1, _ := newSoraTestContext()
err := svc.testSoraAccountConnection(c1, account)
require.NoError(t, err)
c2, rec2 := newSoraTestContext()
err = svc.testSoraAccountConnection(c2, account)
require.Error(t, err)
require.Contains(t, err.Error(), "测试过于频繁")
body := rec2.Body.String()
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"code":"test_rate_limited"`)
require.Contains(t, body, `"status":"failed"`)
require.NotContains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
......
......@@ -95,6 +95,16 @@ var soraDesktopUserAgents = []string{
"Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
}
var soraMobileUserAgents = []string{
"Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)",
"Sora/1.2026.007 (Android 14; SM-G998B; build 2600700)",
"Sora/1.2026.007 (Android 15; Pixel 8 Pro; build 2600700)",
"Sora/1.2026.007 (Android 14; Pixel 7; build 2600700)",
"Sora/1.2026.007 (Android 15; 2211133C; build 2600700)",
"Sora/1.2026.007 (Android 14; SM-S918B; build 2600700)",
"Sora/1.2026.007 (Android 15; OnePlus 12; build 2600700)",
}
var soraRand = rand.New(rand.NewSource(time.Now().UnixNano()))
var soraRandMu sync.Mutex
var soraPerfStart = time.Now()
......@@ -106,6 +116,17 @@ type SoraClient interface {
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error)
UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error)
GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error)
DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error)
UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error)
FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error)
SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error
DeleteCharacter(ctx context.Context, account *Account, characterID string) error
PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error)
DeletePost(ctx context.Context, account *Account, postID string) error
GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error)
EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
......@@ -128,6 +149,17 @@ type SoraVideoRequest struct {
Size string
MediaID string
RemixTargetID string
CameoIDs []string
}
// SoraStoryboardRequest 分镜视频生成请求参数
type SoraStoryboardRequest struct {
Prompt string
Orientation string
Frames int
Model string
Size string
MediaID string
}
// SoraImageTaskStatus 图片任务状态
......@@ -141,11 +173,32 @@ type SoraImageTaskStatus struct {
// SoraVideoTaskStatus 视频任务状态
type SoraVideoTaskStatus struct {
ID string
Status string
ProgressPct int
URLs []string
ErrorMsg string
ID string
Status string
ProgressPct int
URLs []string
GenerationID string
ErrorMsg string
}
// SoraCameoStatus 角色处理中间态
type SoraCameoStatus struct {
Status string
StatusMessage string
DisplayNameHint string
UsernameHint string
ProfileAssetURL string
InstructionSetHint any
InstructionSet any
}
// SoraCharacterFinalizeRequest 角色定稿请求参数
type SoraCharacterFinalizeRequest struct {
CameoID string
Username string
DisplayName string
ProfileAssetPointer string
InstructionSet any
}
// SoraUpstreamError 上游错误
......@@ -407,6 +460,9 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
payload["remix_target_id"] = req.RemixTargetID
payload["cameo_ids"] = []string{}
payload["cameo_replacements"] = map[string]any{}
} else if len(req.CameoIDs) > 0 {
payload["cameo_ids"] = req.CameoIDs
payload["cameo_replacements"] = map[string]any{}
}
headers := c.buildBaseHeaders(token, userAgent)
......@@ -434,6 +490,425 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
return taskID, nil
}
func (c *SoraDirectClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
orientation := req.Orientation
if orientation == "" {
orientation = "landscape"
}
nFrames := req.Frames
if nFrames <= 0 {
nFrames = 450
}
model := req.Model
if model == "" {
model = "sy_8"
}
size := req.Size
if size == "" {
size = "small"
}
inpaintItems := []map[string]any{}
if strings.TrimSpace(req.MediaID) != "" {
inpaintItems = append(inpaintItems, map[string]any{
"kind": "upload",
"upload_id": req.MediaID,
})
}
payload := map[string]any{
"kind": "video",
"prompt": req.Prompt,
"title": "Draft your video",
"orientation": orientation,
"size": size,
"n_frames": nFrames,
"storyboard_id": nil,
"inpaint_items": inpaintItems,
"remix_target_id": nil,
"model": model,
"metadata": nil,
"style_id": nil,
"cameo_ids": nil,
"cameo_replacements": nil,
"audio_caption": nil,
"audio_transcript": nil,
"video_caption": nil,
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create/storyboard"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if taskID == "" {
return "", errors.New("storyboard task response missing id")
}
return taskID, nil
}
func (c *SoraDirectClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
if len(data) == 0 {
return "", errors.New("empty video data")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
var body bytes.Buffer
writer := multipart.NewWriter(&body)
partHeader := make(textproto.MIMEHeader)
partHeader.Set("Content-Disposition", `form-data; name="file"; filename="video.mp4"`)
partHeader.Set("Content-Type", "video/mp4")
part, err := writer.CreatePart(partHeader)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
if err := writer.WriteField("timestamps", "0,3"); err != nil {
return "", err
}
if err := writer.Close(); err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", writer.FormDataContentType())
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/upload"), headers, &body, false)
if err != nil {
return "", err
}
cameoID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if cameoID == "" {
return "", errors.New("character upload response missing id")
}
return cameoID, nil
}
func (c *SoraDirectClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
respBody, _, err := c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodGet,
c.buildURL("/project_y/cameos/in_progress/"+strings.TrimSpace(cameoID)),
headers,
nil,
false,
)
if err != nil {
return nil, err
}
return &SoraCameoStatus{
Status: strings.TrimSpace(gjson.GetBytes(respBody, "status").String()),
StatusMessage: strings.TrimSpace(gjson.GetBytes(respBody, "status_message").String()),
DisplayNameHint: strings.TrimSpace(gjson.GetBytes(respBody, "display_name_hint").String()),
UsernameHint: strings.TrimSpace(gjson.GetBytes(respBody, "username_hint").String()),
ProfileAssetURL: strings.TrimSpace(gjson.GetBytes(respBody, "profile_asset_url").String()),
InstructionSetHint: gjson.GetBytes(respBody, "instruction_set_hint").Value(),
InstructionSet: gjson.GetBytes(respBody, "instruction_set").Value(),
}, nil
}
func (c *SoraDirectClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Accept", "image/*,*/*;q=0.8")
respBody, _, err := c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodGet,
strings.TrimSpace(imageURL),
headers,
nil,
false,
)
if err != nil {
return nil, err
}
return respBody, nil
}
func (c *SoraDirectClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
if len(data) == 0 {
return "", errors.New("empty character image")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
var body bytes.Buffer
writer := multipart.NewWriter(&body)
partHeader := make(textproto.MIMEHeader)
partHeader.Set("Content-Disposition", `form-data; name="file"; filename="profile.webp"`)
partHeader.Set("Content-Type", "image/webp")
part, err := writer.CreatePart(partHeader)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
if err := writer.WriteField("use_case", "profile"); err != nil {
return "", err
}
if err := writer.Close(); err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", writer.FormDataContentType())
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/file/upload"), headers, &body, false)
if err != nil {
return "", err
}
assetPointer := strings.TrimSpace(gjson.GetBytes(respBody, "asset_pointer").String())
if assetPointer == "" {
return "", errors.New("character image upload response missing asset_pointer")
}
return assetPointer, nil
}
func (c *SoraDirectClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
payload := map[string]any{
"cameo_id": req.CameoID,
"username": req.Username,
"display_name": req.DisplayName,
"profile_asset_pointer": req.ProfileAssetPointer,
"instruction_set": nil,
"safety_instruction_set": nil,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/finalize"), headers, bytes.NewReader(body), false)
if err != nil {
return "", err
}
characterID := strings.TrimSpace(gjson.GetBytes(respBody, "character.character_id").String())
if characterID == "" {
return "", errors.New("character finalize response missing character_id")
}
return characterID, nil
}
func (c *SoraDirectClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
payload := map[string]any{"visibility": "public"}
body, err := json.Marshal(payload)
if err != nil {
return err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
_, _, err = c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodPost,
c.buildURL("/project_y/cameos/by_id/"+strings.TrimSpace(cameoID)+"/update_v2"),
headers,
bytes.NewReader(body),
false,
)
return err
}
func (c *SoraDirectClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
_, _, err = c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodDelete,
c.buildURL("/project_y/characters/"+strings.TrimSpace(characterID)),
headers,
nil,
false,
)
return err
}
func (c *SoraDirectClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
payload := map[string]any{
"attachments_to_create": []map[string]any{
{
"generation_id": generationID,
"kind": "sora",
},
},
"post_text": "",
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/post"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
postID := strings.TrimSpace(gjson.GetBytes(respBody, "post.id").String())
if postID == "" {
return "", errors.New("watermark-free publish response missing post.id")
}
return postID, nil
}
func (c *SoraDirectClient) DeletePost(ctx context.Context, account *Account, postID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
_, _, err = c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodDelete,
c.buildURL("/project_y/post/"+strings.TrimSpace(postID)),
headers,
nil,
false,
)
return err
}
func (c *SoraDirectClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/")
if parseURL == "" {
return "", errors.New("custom parse url is required")
}
if strings.TrimSpace(parseToken) == "" {
return "", errors.New("custom parse token is required")
}
shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID)
payload := map[string]any{
"url": shareURL,
"token": strings.TrimSpace(parseToken),
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
proxyURL := c.resolveProxyURL(account)
accountID := int64(0)
accountConcurrency := 0
if account != nil {
accountID = account.ID
accountConcurrency = account.Concurrency
}
var resp *http.Response
if c.httpUpstream != nil {
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
} else {
resp, err = http.DefaultClient.Do(req)
}
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
if err != nil {
return "", err
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256))
}
downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String())
if downloadLink == "" {
return "", errors.New("custom parse response missing download_link")
}
return downloadLink, nil
}
func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
......@@ -607,6 +1082,7 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
if draft.Get("task_id").String() != taskID {
return true
}
generationID := strings.TrimSpace(draft.Get("id").String())
kind := strings.TrimSpace(draft.Get("kind").String())
reason := strings.TrimSpace(draft.Get("reason_str").String())
if reason == "" {
......@@ -623,15 +1099,17 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
msg = "Content violates guardrails"
}
draftFound = &SoraVideoTaskStatus{
ID: taskID,
Status: "failed",
ErrorMsg: msg,
ID: taskID,
Status: "failed",
GenerationID: generationID,
ErrorMsg: msg,
}
} else {
draftFound = &SoraVideoTaskStatus{
ID: taskID,
Status: "completed",
URLs: []string{urlStr},
ID: taskID,
Status: "completed",
GenerationID: generationID,
URLs: []string{urlStr},
}
}
return false
......@@ -675,8 +1153,11 @@ func (c *SoraDirectClient) taskUserAgent() string {
return ua
}
}
if len(soraMobileUserAgents) > 0 {
return soraMobileUserAgents[soraRandInt(len(soraMobileUserAgents))]
}
if len(soraDesktopUserAgents) > 0 {
return soraDesktopUserAgents[0]
return soraDesktopUserAgents[soraRandInt(len(soraDesktopUserAgents))]
}
return soraDefaultUserAgent
}
......@@ -1149,10 +1630,7 @@ func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool {
}
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
enableTLS := true
if c != nil && c.cfg != nil && c.cfg.Sora.Client.DisableTLSFingerprint {
enableTLS = false
}
enableTLS := c == nil || c.cfg == nil || !c.cfg.Sora.Client.DisableTLSFingerprint
if c.httpUpstream != nil {
accountID := int64(0)
accountConcurrency := 0
......@@ -1288,6 +1766,15 @@ func soraRandFloat() float64 {
return soraRand.Float64()
}
func soraRandInt(max int) int {
if max <= 1 {
return 0
}
soraRandMu.Lock()
defer soraRandMu.Unlock()
return soraRand.Intn(max)
}
func soraBuildPowConfig(userAgent string) []any {
userAgent = strings.TrimSpace(userAgent)
if userAgent == "" && len(soraDesktopUserAgents) > 0 {
......
......@@ -393,10 +393,22 @@ func (u *soraClientRecordingUpstream) DoWithTLS(req *http.Request, proxyURL stri
return newSoraClientMockResponse(http.StatusOK, `{"token":"sentinel-token","turnstile":{"dx":"ok"}}`), nil
case "/backend/nf/create":
return newSoraClientMockResponse(http.StatusOK, `{"id":"task-123"}`), nil
case "/backend/nf/create/storyboard":
return newSoraClientMockResponse(http.StatusOK, `{"id":"storyboard-123"}`), nil
case "/backend/uploads":
return newSoraClientMockResponse(http.StatusOK, `{"id":"upload-123"}`), nil
case "/backend/nf/check":
return newSoraClientMockResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":1,"rate_limit_reached":false}}`), nil
case "/backend/characters/upload":
return newSoraClientMockResponse(http.StatusOK, `{"id":"cameo-123"}`), nil
case "/backend/project_y/cameos/in_progress/cameo-123":
return newSoraClientMockResponse(http.StatusOK, `{"status":"finalized","status_message":"Completed","username_hint":"foo.bar","display_name_hint":"Bar","profile_asset_url":"https://example.com/avatar.webp"}`), nil
case "/backend/project_y/file/upload":
return newSoraClientMockResponse(http.StatusOK, `{"asset_pointer":"asset-123"}`), nil
case "/backend/characters/finalize":
return newSoraClientMockResponse(http.StatusOK, `{"character":{"character_id":"character-123"}}`), nil
case "/backend/project_y/post":
return newSoraClientMockResponse(http.StatusOK, `{"post":{"id":"s_post"}}`), nil
default:
return newSoraClientMockResponse(http.StatusOK, `{"ok":true}`), nil
}
......@@ -410,9 +422,13 @@ func newSoraClientMockResponse(statusCode int, body string) *http.Response {
}
}
func TestSoraDirectClient_TaskUserAgent_DefaultDesktopFallback(t *testing.T) {
func TestSoraDirectClient_TaskUserAgent_DefaultMobileFallback(t *testing.T) {
client := NewSoraDirectClient(&config.Config{}, nil, nil)
require.Equal(t, soraDesktopUserAgents[0], client.taskUserAgent())
ua := client.taskUserAgent()
require.NotEmpty(t, ua)
allowed := append([]string{}, soraMobileUserAgents...)
allowed = append(allowed, soraDesktopUserAgents...)
require.Contains(t, allowed, ua)
}
func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAndCreate(t *testing.T) {
......@@ -460,7 +476,7 @@ func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAn
require.Equal(t, "/backend/nf/create", createCall.Path)
require.Equal(t, "http://127.0.0.1:8080", sentinelCall.ProxyURL)
require.Equal(t, sentinelCall.ProxyURL, createCall.ProxyURL)
require.Equal(t, soraDesktopUserAgents[0], sentinelCall.UserAgent)
require.NotEmpty(t, sentinelCall.UserAgent)
require.Equal(t, sentinelCall.UserAgent, createCall.UserAgent)
}
......@@ -495,7 +511,7 @@ func TestSoraDirectClient_UploadImage_UsesTaskUserAgentAndProxy(t *testing.T) {
require.Len(t, upstream.calls, 1)
require.Equal(t, "/backend/uploads", upstream.calls[0].Path)
require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
require.Equal(t, soraDesktopUserAgents[0], upstream.calls[0].UserAgent)
require.NotEmpty(t, upstream.calls[0].UserAgent)
}
func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T) {
......@@ -528,5 +544,98 @@ func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T)
require.Len(t, upstream.calls, 1)
require.Equal(t, "/backend/nf/check", upstream.calls[0].Path)
require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
require.Equal(t, soraDesktopUserAgents[0], upstream.calls[0].UserAgent)
require.NotEmpty(t, upstream.calls[0].UserAgent)
}
func TestSoraDirectClient_CreateStoryboardTask(t *testing.T) {
originPowTokenGenerator := soraPowTokenGenerator
soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
defer func() { soraPowTokenGenerator = originPowTokenGenerator }()
upstream := &soraClientRecordingUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
account := &Account{
ID: 51,
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
taskID, err := client.CreateStoryboardTask(context.Background(), account, SoraStoryboardRequest{
Prompt: "Shot 1:\nduration: 5sec\nScene: cat",
})
require.NoError(t, err)
require.Equal(t, "storyboard-123", taskID)
require.Len(t, upstream.calls, 2)
require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path)
require.Equal(t, "/backend/nf/create/storyboard", upstream.calls[1].Path)
}
func TestSoraDirectClient_GetVideoTask_ReturnsGenerationID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/nf/pending/v2":
_, _ = w.Write([]byte(`[]`))
case "/project_y/profile/drafts":
_, _ = w.Write([]byte(`{"items":[{"id":"gen_1","task_id":"task-1","kind":"video","downloadable_url":"https://example.com/v.mp4"}]}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: server.URL,
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{Credentials: map[string]any{"access_token": "token"}}
status, err := client.GetVideoTask(context.Background(), account, "task-1")
require.NoError(t, err)
require.Equal(t, "completed", status.Status)
require.Equal(t, "gen_1", status.GenerationID)
require.Equal(t, []string{"https://example.com/v.mp4"}, status.URLs)
}
func TestSoraDirectClient_PostVideoForWatermarkFree(t *testing.T) {
originPowTokenGenerator := soraPowTokenGenerator
soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
defer func() { soraPowTokenGenerator = originPowTokenGenerator }()
upstream := &soraClientRecordingUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
account := &Account{
ID: 52,
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
postID, err := client.PostVideoForWatermarkFree(context.Background(), account, "gen_1")
require.NoError(t, err)
require.Equal(t, "s_post", postID)
require.Len(t, upstream.calls, 2)
require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path)
require.Equal(t, "/backend/project_y/post", upstream.calls[1].Path)
}
......@@ -5,6 +5,7 @@ package service
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
......@@ -25,6 +26,11 @@ type stubSoraClientForPoll struct {
videoCalls int
enhanced string
enhanceErr error
storyboard bool
videoReq SoraVideoRequest
parseErr error
postCalls int
deleteCalls int
}
func (s *stubSoraClientForPoll) Enabled() bool { return true }
......@@ -35,8 +41,54 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac
return "task-image", nil
}
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
s.videoReq = req
return "task-video", nil
}
func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
s.storyboard = true
return "task-video", nil
}
func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
return "cameo-1", nil
}
func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
return &SoraCameoStatus{
Status: "finalized",
StatusMessage: "Completed",
DisplayNameHint: "Character",
UsernameHint: "user.character",
ProfileAssetURL: "https://example.com/avatar.webp",
}, nil
}
func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
return []byte("avatar"), nil
}
func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
return "asset-pointer", nil
}
func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
return "character-1", nil
}
func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
return nil
}
func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
return nil
}
func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
s.postCalls++
return "s_post", nil
}
func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error {
s.deleteCalls++
return nil
}
func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
if s.parseErr != nil {
return "", s.parseErr
}
return "https://example.com/no-watermark.mp4", nil
}
func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
if s.enhanced != "" {
return s.enhanced, s.enhanceErr
......@@ -102,6 +154,109 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
require.Equal(t, "prompt-enhance-short-10s", result.Model)
}
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/v.mp4"},
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, client.storyboard)
}
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
client := &stubSoraClientForPoll{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "prompt", result.MediaType)
require.Equal(t, 0, client.videoCalls)
}
func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/original.mp4"},
GenerationID: "gen_1",
},
parseErr: errors.New("parse failed"),
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "https://example.com/original.mp4", result.MediaURL)
require.Equal(t, 1, client.postCalls)
require.Equal(t, 0, client.deleteCalls)
}
func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/original.mp4"},
GenerationID: "gen_1",
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL)
require.Equal(t, 1, client.postCalls)
require.Equal(t, 1, client.deleteCalls)
}
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
......@@ -119,9 +274,9 @@ func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
}
service := NewSoraGatewayService(client, nil, nil, cfg)
urls, err := service.pollVideoTask(context.Background(), nil, &Account{ID: 1}, "task", false)
status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false)
require.Error(t, err)
require.Empty(t, urls)
require.Nil(t, status)
require.Contains(t, err.Error(), "reject")
require.Equal(t, 1, client.videoCalls)
}
......@@ -325,3 +480,19 @@ func TestDecodeSoraImageInput_DataURL(t *testing.T) {
require.NotEmpty(t, data)
require.Contains(t, filename, ".png")
}
func TestDecodeBase64WithLimit_ExceedLimit(t *testing.T) {
data, err := decodeBase64WithLimit("aGVsbG8=", 3)
require.Error(t, err)
require.Nil(t, data)
}
func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
body := map[string]any{
"watermark_free": float64(1),
"watermark_fallback_on_failure": float64(0),
}
opts := parseSoraWatermarkOptions(body)
require.True(t, opts.Enabled)
require.False(t, opts.FallbackOnFailure)
}
package soraerror
import (
"encoding/json"
"fmt"
"net/http"
"regexp"
"strings"
)
var (
cfRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
cRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
htmlChallenge = []string{
"window._cf_chl_opt",
"just a moment",
"enable javascript and cookies to continue",
"__cf_chl_",
"challenge-platform",
}
)
// IsCloudflareChallengeResponse reports whether the upstream response matches Cloudflare challenge behavior.
func IsCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
return false
}
if headers != nil && strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
return true
}
preview := strings.ToLower(TruncateBody(body, 4096))
for _, marker := range htmlChallenge {
if strings.Contains(preview, marker) {
return true
}
}
contentType := ""
if headers != nil {
contentType = strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
}
if strings.Contains(contentType, "text/html") &&
(strings.Contains(preview, "<html") || strings.Contains(preview, "<!doctype html")) &&
(strings.Contains(preview, "cloudflare") || strings.Contains(preview, "challenge")) {
return true
}
return false
}
// ExtractCloudflareRayID extracts cf-ray from headers or response body.
func ExtractCloudflareRayID(headers http.Header, body []byte) string {
if headers != nil {
rayID := strings.TrimSpace(headers.Get("cf-ray"))
if rayID != "" {
return rayID
}
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
if rayID != "" {
return rayID
}
}
preview := TruncateBody(body, 8192)
if matches := cfRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
if matches := cRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
return ""
}
// FormatCloudflareChallengeMessage appends cf-ray info when available.
func FormatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
rayID := ExtractCloudflareRayID(headers, body)
if rayID == "" {
return base
}
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
}
// ExtractUpstreamErrorCodeAndMessage extracts structured error code/message from common JSON layouts.
func ExtractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
trimmed := strings.TrimSpace(string(body))
if trimmed == "" {
return "", ""
}
if !json.Valid([]byte(trimmed)) {
return "", truncateMessage(trimmed, 256)
}
var payload map[string]any
if err := json.Unmarshal([]byte(trimmed), &payload); err != nil {
return "", truncateMessage(trimmed, 256)
}
code := firstNonEmpty(
extractNestedString(payload, "error", "code"),
extractRootString(payload, "code"),
)
message := firstNonEmpty(
extractNestedString(payload, "error", "message"),
extractRootString(payload, "message"),
extractNestedString(payload, "error", "detail"),
extractRootString(payload, "detail"),
)
return strings.TrimSpace(code), truncateMessage(strings.TrimSpace(message), 512)
}
// TruncateBody truncates body text for logging/inspection.
func TruncateBody(body []byte, max int) string {
if max <= 0 {
max = 512
}
raw := strings.TrimSpace(string(body))
if len(raw) <= max {
return raw
}
return raw[:max] + "...(truncated)"
}
func truncateMessage(s string, max int) string {
if max <= 0 {
return ""
}
if len(s) <= max {
return s
}
return s[:max] + "...(truncated)"
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
if strings.TrimSpace(v) != "" {
return v
}
}
return ""
}
func extractRootString(m map[string]any, key string) string {
if m == nil {
return ""
}
v, ok := m[key]
if !ok {
return ""
}
s, _ := v.(string)
return s
}
func extractNestedString(m map[string]any, parent, key string) string {
if m == nil {
return ""
}
node, ok := m[parent]
if !ok {
return ""
}
child, ok := node.(map[string]any)
if !ok {
return ""
}
s, _ := child[key].(string)
return s
}
package soraerror
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestIsCloudflareChallengeResponse(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-mitigated", "challenge")
require.True(t, IsCloudflareChallengeResponse(http.StatusForbidden, headers, []byte(`{"ok":false}`)))
require.True(t, IsCloudflareChallengeResponse(http.StatusTooManyRequests, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>`)))
require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title>`)))
}
func TestExtractCloudflareRayID(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil))
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body))
}
func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) {
code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`))
require.Equal(t, "cf_shield_429", code)
require.Equal(t, "rate limited", msg)
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`))
require.Equal(t, "unsupported_country_code", code)
require.Equal(t, "not available", msg)
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`))
require.Equal(t, "", code)
require.Equal(t, "plain text", msg)
}
func TestFormatCloudflareChallengeMessage(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
msg := FormatCloudflareChallengeMessage("blocked", headers, nil)
require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment