"vscode:/vscode.git/clone" did not exist on "67d028cf50e9f4e00e404ee0c7db330c8af0ac3b"
Commit 3c619a8d authored by huangenjun's avatar huangenjun
Browse files

refactor: 使用 go-sora2api SDK 替代自建 Sora 客户端



使用 go-sora2api v1.1.0 SDK 替代原有 ~2000 行自建 HTTP/PoW/TLS 指纹代码,
SDK 提供高并发性能优化(实例级 rand、PoW 缓冲区复用、context.Context 支持)。

- 新增 SoraSDKClient 适配器实现 SoraClient 接口
- 精简 sora_client.go 为仅保留接口和类型定义
- 更新 Wire 绑定使用 SoraSDKClient
- 删除 SoraDirectClient、sora_curl_cffi_sidecar、sora_request_guard 等旧代码
Co-Authored-By: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent ded9b6c1
......@@ -187,9 +187,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
soraGatewayService := service.NewSoraGatewayService(soraSDKClient, soraMediaStorage, rateLimitService, configConfig)
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
totpHandler := handler.NewTotpHandler(totpService)
......
......@@ -5,6 +5,7 @@ go 1.25.7
require (
entgo.io/ent v0.14.5
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/DouDOU-start/go-sora2api v1.1.0
github.com/alitto/pond/v2 v2.6.2
github.com/cespare/xxhash/v2 v2.3.0
github.com/dgraph-io/ristretto v0.2.0
......@@ -29,10 +30,10 @@ require (
github.com/tidwall/sjson v1.2.5
github.com/zeromicro/go-zero v1.9.4
go.uber.org/zap v1.24.0
golang.org/x/crypto v0.47.0
golang.org/x/crypto v0.48.0
golang.org/x/net v0.49.0
golang.org/x/sync v0.19.0
golang.org/x/term v0.39.0
golang.org/x/term v0.40.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3
......@@ -46,7 +47,14 @@ require (
github.com/agext/levenshtein v1.2.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/bdandy/go-errors v1.2.2 // indirect
github.com/bdandy/go-socks4 v1.2.3 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
github.com/bogdanfinn/fhttp v0.6.8 // indirect
github.com/bogdanfinn/quic-go-utls v1.0.9-utls // indirect
github.com/bogdanfinn/tls-client v1.14.0 // indirect
github.com/bogdanfinn/utls v1.7.7-barnius // indirect
github.com/bogdanfinn/websocket v1.5.5-barnius // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
......@@ -123,6 +131,7 @@ require (
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 // indirect
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
......@@ -144,9 +153,9 @@ require (
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/sys v0.40.0 // indirect
golang.org/x/text v0.33.0 // indirect
golang.org/x/mod v0.32.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
......
......@@ -10,6 +10,8 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOEl
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/DouDOU-start/go-sora2api v1.1.0 h1:PxWiukK77StiHxEngOFwT1rKUn9oTAJJTl07wQUXwiU=
github.com/DouDOU-start/go-sora2api v1.1.0/go.mod h1:dcwpethoKfAsMWskDD9iGgc/3yox2tkthPLSMVGnhkE=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=
......@@ -20,10 +22,24 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
github.com/bdandy/go-socks4 v1.2.3/go.mod h1:98kiVFgpdogR8aIGLWLvjDVZ8XcKPsSI/ypGrO+bqHI=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/bogdanfinn/fhttp v0.6.8 h1:LiQyHOY3i0QoxxNB7nq27/nGNNbtPj0fuBPozhR7Ws4=
github.com/bogdanfinn/fhttp v0.6.8/go.mod h1:A+EKDzMx2hb4IUbMx4TlkoHnaJEiLl8r/1Ss1Y+5e5M=
github.com/bogdanfinn/quic-go-utls v1.0.9-utls h1:tV6eDEiRbRCcepALSzxR94JUVD3N3ACIiRLgyc2Ep8s=
github.com/bogdanfinn/quic-go-utls v1.0.9-utls/go.mod h1:aHph9B9H9yPOt5xnhWKSOum27DJAqpiHzwX+gjvaXcg=
github.com/bogdanfinn/tls-client v1.14.0 h1:vyk7Cn4BIvLAGVuMfb0tP22OqogfO1lYamquQNEZU1A=
github.com/bogdanfinn/tls-client v1.14.0/go.mod h1:LsU6mXVn8MOFDwTkyRfI7V1BZM1p0wf2ZfZsICW/1fM=
github.com/bogdanfinn/utls v1.7.7-barnius h1:OuJ497cc7F3yKNVHRsYPQdGggmk5x6+V5ZlrCR7fOLU=
github.com/bogdanfinn/utls v1.7.7-barnius/go.mod h1:aAK1VZQlpKZClF1WEQeq6kyclbkPq4hz6xTbB5xSlmg=
github.com/bogdanfinn/websocket v1.5.5-barnius h1:bY+qnxpai1qe7Jmjx+Sds/cmOSpuuLoR8x61rWltjOI=
github.com/bogdanfinn/websocket v1.5.5-barnius/go.mod h1:gvvEw6pTKHb7yOiFvIfAFTStQWyrm25BMVCTj5wRSsI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
......@@ -279,6 +295,8 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5/go.mod h1:2JjD2zLQYH5HO74y5+aE3remJQvl6q4Sn6aWA2wD1Ng=
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk=
......@@ -345,18 +363,21 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.0.0-20211104170005-ce137452f963/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
......@@ -364,16 +385,19 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
......
package service
import (
"bytes"
"context"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"hash/fnv"
"io"
"log"
"math/rand"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"net/url"
"path"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"golang.org/x/crypto/sha3"
)
const (
soraChatGPTBaseURL = "https://chatgpt.com"
soraSentinelFlow = "sora_2_create_task"
soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
)
var (
soraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
soraOAuthTokenURL = "https://auth.openai.com/oauth/token"
)
const (
soraPowMaxIteration = 500000
)
var soraPowCores = []int{8, 16, 24, 32}
var soraPowScripts = []string{
"https://cdn.oaistatic.com/_next/static/cXh69klOLzS0Gy2joLDRS/_ssgManifest.js?dpl=453ebaec0d44c2decab71692e1bfe39be35a24b3",
}
var soraPowDPL = []string{
"prod-f501fe933b3edf57aea882da888e1a544df99840",
}
var soraPowNavigatorKeys = []string{
"registerProtocolHandler−function registerProtocolHandler() { [native code] }",
"storage−[object StorageManager]",
"locks−[object LockManager]",
"appCodeName−Mozilla",
"permissions−[object Permissions]",
"webdriver−false",
"vendor−Google Inc.",
"mediaDevices−[object MediaDevices]",
"cookieEnabled−true",
"product−Gecko",
"productSub−20030107",
"hardwareConcurrency−32",
"onLine−true",
}
var soraPowDocumentKeys = []string{
"_reactListeningo743lnnpvdg",
"location",
}
var soraPowWindowKeys = []string{
"0", "window", "self", "document", "name", "location",
"navigator", "screen", "innerWidth", "innerHeight",
"localStorage", "sessionStorage", "crypto", "performance",
"fetch", "setTimeout", "setInterval", "console",
}
var soraDesktopUserAgents = []string{
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36",
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
}
var soraMobileUserAgents = []string{
"Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)",
"Sora/1.2026.007 (Android 14; SM-G998B; build 2600700)",
"Sora/1.2026.007 (Android 15; Pixel 8 Pro; build 2600700)",
"Sora/1.2026.007 (Android 14; Pixel 7; build 2600700)",
"Sora/1.2026.007 (Android 15; 2211133C; build 2600700)",
"Sora/1.2026.007 (Android 14; SM-S918B; build 2600700)",
"Sora/1.2026.007 (Android 15; OnePlus 12; build 2600700)",
}
var soraRand = rand.New(rand.NewSource(time.Now().UnixNano()))
var soraRandMu sync.Mutex
var soraPerfStart = time.Now()
var soraPowTokenGenerator = soraGetPowToken
// SoraClient 定义直连 Sora 的任务操作接口。
type SoraClient interface {
Enabled() bool
......@@ -219,1905 +114,3 @@ func (e *SoraUpstreamError) Error() string {
}
return fmt.Sprintf("sora upstream error: %d", e.StatusCode)
}
// SoraDirectClient 直连 Sora 实现
type SoraDirectClient struct {
cfg *config.Config
httpUpstream HTTPUpstream
tokenProvider *OpenAITokenProvider
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository
baseURL string
challengeCooldownMu sync.RWMutex
challengeCooldowns map[string]soraChallengeCooldownEntry
sidecarSessionMu sync.RWMutex
sidecarSessions map[string]soraSidecarSessionEntry
}
type soraRequestTraceContextKey struct{}
type soraRequestTrace struct {
ID string
ProxyKey string
UAHash string
}
// NewSoraDirectClient 创建 Sora 直连客户端
func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient {
baseURL := ""
if cfg != nil {
rawBaseURL := strings.TrimRight(strings.TrimSpace(cfg.Sora.Client.BaseURL), "/")
baseURL = normalizeSoraBaseURL(rawBaseURL)
if rawBaseURL != "" && baseURL != rawBaseURL {
log.Printf("[SoraClient] normalized base_url from %s to %s", sanitizeSoraLogURL(rawBaseURL), sanitizeSoraLogURL(baseURL))
}
}
return &SoraDirectClient{
cfg: cfg,
httpUpstream: httpUpstream,
tokenProvider: tokenProvider,
baseURL: baseURL,
challengeCooldowns: make(map[string]soraChallengeCooldownEntry),
sidecarSessions: make(map[string]soraSidecarSessionEntry),
}
}
func (c *SoraDirectClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) {
if c == nil {
return
}
c.accountRepo = accountRepo
c.soraAccountRepo = soraAccountRepo
}
// Enabled 判断是否启用 Sora 直连
func (c *SoraDirectClient) Enabled() bool {
if c == nil {
return false
}
if strings.TrimSpace(c.baseURL) != "" {
return true
}
if c.cfg == nil {
return false
}
return strings.TrimSpace(normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)) != ""
}
// PreflightCheck 在创建任务前执行账号能力预检。
// 当前仅对视频模型执行 /nf/check 预检,用于提前识别额度耗尽或能力缺失。
func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error {
if modelCfg.Type != "video" {
return nil
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Accept", "application/json")
body, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false)
if err != nil {
var upstreamErr *SoraUpstreamError
if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound {
return &SoraUpstreamError{
StatusCode: http.StatusForbidden,
Message: "当前账号未开通 Sora2 能力或无可用配额",
Headers: upstreamErr.Headers,
Body: upstreamErr.Body,
}
}
return err
}
rateLimitReached := gjson.GetBytes(body, "rate_limit_and_credit_balance.rate_limit_reached").Bool()
remaining := gjson.GetBytes(body, "rate_limit_and_credit_balance.estimated_num_videos_remaining")
if rateLimitReached || (remaining.Exists() && remaining.Int() <= 0) {
msg := "当前账号 Sora2 可用配额不足"
if requestedModel != "" {
msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel)
}
return &SoraUpstreamError{
StatusCode: http.StatusTooManyRequests,
Message: msg,
Headers: http.Header{},
}
}
return nil
}
func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
if len(data) == 0 {
return "", errors.New("empty image data")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
if filename == "" {
filename = "image.png"
}
var body bytes.Buffer
writer := multipart.NewWriter(&body)
contentType := mime.TypeByExtension(path.Ext(filename))
if contentType == "" {
contentType = "application/octet-stream"
}
partHeader := make(textproto.MIMEHeader)
partHeader.Set("Content-Disposition", fmt.Sprintf(`form-data; name="file"; filename="%s"`, filename))
partHeader.Set("Content-Type", contentType)
part, err := writer.CreatePart(partHeader)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
if err := writer.WriteField("file_name", filename); err != nil {
return "", err
}
if err := writer.Close(); err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", writer.FormDataContentType())
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/uploads"), headers, &body, false)
if err != nil {
return "", err
}
id := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if id == "" {
return "", errors.New("upload response missing id")
}
return id, nil
}
func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
operation := "simple_compose"
inpaintItems := []map[string]any{}
if strings.TrimSpace(req.MediaID) != "" {
operation = "remix"
inpaintItems = append(inpaintItems, map[string]any{
"type": "image",
"frame_index": 0,
"upload_media_id": req.MediaID,
})
}
payload := map[string]any{
"type": "image_gen",
"operation": operation,
"prompt": req.Prompt,
"width": req.Width,
"height": req.Height,
"n_variants": 1,
"n_frames": 1,
"inpaint_items": inpaintItems,
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if taskID == "" {
return "", errors.New("image task response missing id")
}
return taskID, nil
}
func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
orientation := req.Orientation
if orientation == "" {
orientation = "landscape"
}
nFrames := req.Frames
if nFrames <= 0 {
nFrames = 450
}
model := req.Model
if model == "" {
model = "sy_8"
}
size := req.Size
if size == "" {
size = "small"
}
inpaintItems := []map[string]any{}
if strings.TrimSpace(req.MediaID) != "" {
inpaintItems = append(inpaintItems, map[string]any{
"kind": "upload",
"upload_id": req.MediaID,
})
}
payload := map[string]any{
"kind": "video",
"prompt": req.Prompt,
"orientation": orientation,
"size": size,
"n_frames": nFrames,
"model": model,
"inpaint_items": inpaintItems,
}
if strings.TrimSpace(req.RemixTargetID) != "" {
payload["remix_target_id"] = req.RemixTargetID
payload["cameo_ids"] = []string{}
payload["cameo_replacements"] = map[string]any{}
} else if len(req.CameoIDs) > 0 {
payload["cameo_ids"] = req.CameoIDs
payload["cameo_replacements"] = map[string]any{}
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if taskID == "" {
return "", errors.New("video task response missing id")
}
return taskID, nil
}
func (c *SoraDirectClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
orientation := req.Orientation
if orientation == "" {
orientation = "landscape"
}
nFrames := req.Frames
if nFrames <= 0 {
nFrames = 450
}
model := req.Model
if model == "" {
model = "sy_8"
}
size := req.Size
if size == "" {
size = "small"
}
inpaintItems := []map[string]any{}
if strings.TrimSpace(req.MediaID) != "" {
inpaintItems = append(inpaintItems, map[string]any{
"kind": "upload",
"upload_id": req.MediaID,
})
}
payload := map[string]any{
"kind": "video",
"prompt": req.Prompt,
"title": "Draft your video",
"orientation": orientation,
"size": size,
"n_frames": nFrames,
"storyboard_id": nil,
"inpaint_items": inpaintItems,
"remix_target_id": nil,
"model": model,
"metadata": nil,
"style_id": nil,
"cameo_ids": nil,
"cameo_replacements": nil,
"audio_caption": nil,
"audio_transcript": nil,
"video_caption": nil,
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create/storyboard"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if taskID == "" {
return "", errors.New("storyboard task response missing id")
}
return taskID, nil
}
func (c *SoraDirectClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
if len(data) == 0 {
return "", errors.New("empty video data")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
var body bytes.Buffer
writer := multipart.NewWriter(&body)
partHeader := make(textproto.MIMEHeader)
partHeader.Set("Content-Disposition", `form-data; name="file"; filename="video.mp4"`)
partHeader.Set("Content-Type", "video/mp4")
part, err := writer.CreatePart(partHeader)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
if err := writer.WriteField("timestamps", "0,3"); err != nil {
return "", err
}
if err := writer.Close(); err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", writer.FormDataContentType())
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/upload"), headers, &body, false)
if err != nil {
return "", err
}
cameoID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if cameoID == "" {
return "", errors.New("character upload response missing id")
}
return cameoID, nil
}
func (c *SoraDirectClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
respBody, _, err := c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodGet,
c.buildURL("/project_y/cameos/in_progress/"+strings.TrimSpace(cameoID)),
headers,
nil,
false,
)
if err != nil {
return nil, err
}
return &SoraCameoStatus{
Status: strings.TrimSpace(gjson.GetBytes(respBody, "status").String()),
StatusMessage: strings.TrimSpace(gjson.GetBytes(respBody, "status_message").String()),
DisplayNameHint: strings.TrimSpace(gjson.GetBytes(respBody, "display_name_hint").String()),
UsernameHint: strings.TrimSpace(gjson.GetBytes(respBody, "username_hint").String()),
ProfileAssetURL: strings.TrimSpace(gjson.GetBytes(respBody, "profile_asset_url").String()),
InstructionSetHint: gjson.GetBytes(respBody, "instruction_set_hint").Value(),
InstructionSet: gjson.GetBytes(respBody, "instruction_set").Value(),
}, nil
}
func (c *SoraDirectClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Accept", "image/*,*/*;q=0.8")
respBody, _, err := c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodGet,
strings.TrimSpace(imageURL),
headers,
nil,
false,
)
if err != nil {
return nil, err
}
return respBody, nil
}
func (c *SoraDirectClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
if len(data) == 0 {
return "", errors.New("empty character image")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
var body bytes.Buffer
writer := multipart.NewWriter(&body)
partHeader := make(textproto.MIMEHeader)
partHeader.Set("Content-Disposition", `form-data; name="file"; filename="profile.webp"`)
partHeader.Set("Content-Type", "image/webp")
part, err := writer.CreatePart(partHeader)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
if err := writer.WriteField("use_case", "profile"); err != nil {
return "", err
}
if err := writer.Close(); err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", writer.FormDataContentType())
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/file/upload"), headers, &body, false)
if err != nil {
return "", err
}
assetPointer := strings.TrimSpace(gjson.GetBytes(respBody, "asset_pointer").String())
if assetPointer == "" {
return "", errors.New("character image upload response missing asset_pointer")
}
return assetPointer, nil
}
func (c *SoraDirectClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
payload := map[string]any{
"cameo_id": req.CameoID,
"username": req.Username,
"display_name": req.DisplayName,
"profile_asset_pointer": req.ProfileAssetPointer,
"instruction_set": nil,
"safety_instruction_set": nil,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/finalize"), headers, bytes.NewReader(body), false)
if err != nil {
return "", err
}
characterID := strings.TrimSpace(gjson.GetBytes(respBody, "character.character_id").String())
if characterID == "" {
return "", errors.New("character finalize response missing character_id")
}
return characterID, nil
}
func (c *SoraDirectClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
payload := map[string]any{"visibility": "public"}
body, err := json.Marshal(payload)
if err != nil {
return err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
_, _, err = c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodPost,
c.buildURL("/project_y/cameos/by_id/"+strings.TrimSpace(cameoID)+"/update_v2"),
headers,
bytes.NewReader(body),
false,
)
return err
}
func (c *SoraDirectClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
_, _, err = c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodDelete,
c.buildURL("/project_y/characters/"+strings.TrimSpace(characterID)),
headers,
nil,
false,
)
return err
}
func (c *SoraDirectClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
payload := map[string]any{
"attachments_to_create": []map[string]any{
{
"generation_id": generationID,
"kind": "sora",
},
},
"post_text": "",
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/post"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
postID := strings.TrimSpace(gjson.GetBytes(respBody, "post.id").String())
if postID == "" {
return "", errors.New("watermark-free publish response missing post.id")
}
return postID, nil
}
func (c *SoraDirectClient) DeletePost(ctx context.Context, account *Account, postID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
_, _, err = c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodDelete,
c.buildURL("/project_y/post/"+strings.TrimSpace(postID)),
headers,
nil,
false,
)
return err
}
func (c *SoraDirectClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/")
if parseURL == "" {
return "", errors.New("custom parse url is required")
}
if strings.TrimSpace(parseToken) == "" {
return "", errors.New("custom parse token is required")
}
shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID)
payload := map[string]any{
"url": shareURL,
"token": strings.TrimSpace(parseToken),
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
proxyURL := c.resolveProxyURL(account)
accountID := int64(0)
accountConcurrency := 0
if account != nil {
accountID = account.ID
accountConcurrency = account.Concurrency
}
var resp *http.Response
if c.httpUpstream != nil {
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
} else {
resp, err = http.DefaultClient.Do(req)
}
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
if err != nil {
return "", err
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256))
}
downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String())
if downloadLink == "" {
return "", errors.New("custom parse response missing download_link")
}
return downloadLink, nil
}
func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
if strings.TrimSpace(expansionLevel) == "" {
expansionLevel = "medium"
}
if durationS <= 0 {
durationS = 10
}
payload := map[string]any{
"prompt": prompt,
"expansion_level": expansionLevel,
"duration_s": durationS,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Accept", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false)
if err != nil {
return "", err
}
enhancedPrompt := strings.TrimSpace(gjson.GetBytes(respBody, "enhanced_prompt").String())
if enhancedPrompt == "" {
return "", errors.New("enhance_prompt response missing enhanced_prompt")
}
return enhancedPrompt, nil
}
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit())
if err != nil {
return nil, err
}
if found {
return status, nil
}
maxLimit := c.recentTaskLimitMax()
if maxLimit > 0 && maxLimit != c.recentTaskLimit() {
status, found, err = c.fetchRecentImageTask(ctx, account, taskID, maxLimit)
if err != nil {
return nil, err
}
if found {
return status, nil
}
}
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil
}
func (c *SoraDirectClient) fetchRecentImageTask(ctx context.Context, account *Account, taskID string, limit int) (*SoraImageTaskStatus, bool, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, false, err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
if limit <= 0 {
limit = 20
}
endpoint := fmt.Sprintf("/v2/recent_tasks?limit=%d", limit)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL(endpoint), headers, nil, false)
if err != nil {
return nil, false, err
}
var found *SoraImageTaskStatus
gjson.GetBytes(respBody, "task_responses").ForEach(func(_, item gjson.Result) bool {
if item.Get("id").String() != taskID {
return true // continue
}
status := strings.TrimSpace(item.Get("status").String())
progress := item.Get("progress_pct").Float()
var urls []string
item.Get("generations").ForEach(func(_, gen gjson.Result) bool {
if u := strings.TrimSpace(gen.Get("url").String()); u != "" {
urls = append(urls, u)
}
return true
})
found = &SoraImageTaskStatus{
ID: taskID,
Status: status,
ProgressPct: progress,
URLs: urls,
}
return false // break
})
if found != nil {
return found, true, nil
}
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false, nil
}
func (c *SoraDirectClient) recentTaskLimit() int {
if c == nil || c.cfg == nil {
return 20
}
if c.cfg.Sora.Client.RecentTaskLimit > 0 {
return c.cfg.Sora.Client.RecentTaskLimit
}
return 20
}
func (c *SoraDirectClient) recentTaskLimitMax() int {
if c == nil || c.cfg == nil {
return 0
}
if c.cfg.Sora.Client.RecentTaskLimitMax > 0 {
return c.cfg.Sora.Client.RecentTaskLimitMax
}
return 0
}
func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false)
if err != nil {
return nil, err
}
// 搜索 pending 列表(JSON 数组)
pendingResult := gjson.ParseBytes(respBody)
if pendingResult.IsArray() {
var pendingFound *SoraVideoTaskStatus
pendingResult.ForEach(func(_, task gjson.Result) bool {
if task.Get("id").String() != taskID {
return true
}
progress := 0
if v := task.Get("progress_pct"); v.Exists() {
progress = int(v.Float() * 100)
}
status := strings.TrimSpace(task.Get("status").String())
pendingFound = &SoraVideoTaskStatus{
ID: taskID,
Status: status,
ProgressPct: progress,
}
return false
})
if pendingFound != nil {
return pendingFound, nil
}
}
respBody, _, err = c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false)
if err != nil {
return nil, err
}
var draftFound *SoraVideoTaskStatus
gjson.GetBytes(respBody, "items").ForEach(func(_, draft gjson.Result) bool {
if draft.Get("task_id").String() != taskID {
return true
}
generationID := strings.TrimSpace(draft.Get("id").String())
kind := strings.TrimSpace(draft.Get("kind").String())
reason := strings.TrimSpace(draft.Get("reason_str").String())
if reason == "" {
reason = strings.TrimSpace(draft.Get("markdown_reason_str").String())
}
urlStr := strings.TrimSpace(draft.Get("downloadable_url").String())
if urlStr == "" {
urlStr = strings.TrimSpace(draft.Get("url").String())
}
if kind == "sora_content_violation" || reason != "" || urlStr == "" {
msg := reason
if msg == "" {
msg = "Content violates guardrails"
}
draftFound = &SoraVideoTaskStatus{
ID: taskID,
Status: "failed",
GenerationID: generationID,
ErrorMsg: msg,
}
} else {
draftFound = &SoraVideoTaskStatus{
ID: taskID,
Status: "completed",
GenerationID: generationID,
URLs: []string{urlStr},
}
}
return false
})
if draftFound != nil {
return draftFound, nil
}
return &SoraVideoTaskStatus{ID: taskID, Status: "processing"}, nil
}
func (c *SoraDirectClient) buildURL(endpoint string) string {
base := strings.TrimRight(strings.TrimSpace(c.baseURL), "/")
if base == "" && c != nil && c.cfg != nil {
base = normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)
c.baseURL = base
}
if base == "" {
return endpoint
}
if strings.HasPrefix(endpoint, "/") {
return base + endpoint
}
return base + "/" + endpoint
}
func (c *SoraDirectClient) defaultUserAgent() string {
if c == nil || c.cfg == nil {
return soraDefaultUserAgent
}
ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent)
if ua == "" {
return soraDefaultUserAgent
}
return ua
}
func (c *SoraDirectClient) taskUserAgent() string {
if c != nil && c.cfg != nil {
if ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent); ua != "" {
return ua
}
}
if len(soraMobileUserAgents) > 0 {
return soraMobileUserAgents[soraRandInt(len(soraMobileUserAgents))]
}
if len(soraDesktopUserAgents) > 0 {
return soraDesktopUserAgents[soraRandInt(len(soraDesktopUserAgents))]
}
return soraDefaultUserAgent
}
func (c *SoraDirectClient) resolveProxyURL(account *Account) string {
if account == nil || account.ProxyID == nil || account.Proxy == nil {
return ""
}
return strings.TrimSpace(account.Proxy.URL())
}
func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
allowProvider := c.allowOpenAITokenProvider(account)
var providerErr error
if allowProvider && c.tokenProvider != nil {
token, err := c.tokenProvider.GetAccessToken(ctx, account)
if err == nil && strings.TrimSpace(token) != "" {
c.logTokenSource(account, "openai_token_provider")
return token, nil
}
providerErr = err
if err != nil && c.debugEnabled() {
c.debugLogf(
"token_provider_failed account_id=%d platform=%s err=%s",
account.ID,
account.Platform,
logredact.RedactText(err.Error()),
)
}
}
token := strings.TrimSpace(account.GetCredential("access_token"))
if token != "" {
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute {
refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring")
if refreshErr == nil && strings.TrimSpace(refreshed) != "" {
c.logTokenSource(account, "refresh_token_recovered")
return refreshed, nil
}
if refreshErr != nil && c.debugEnabled() {
c.debugLogf("token_refresh_before_use_failed account_id=%d err=%s", account.ID, logredact.RedactText(refreshErr.Error()))
}
}
c.logTokenSource(account, "account_credentials")
return token, nil
}
recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing")
if recoverErr == nil && strings.TrimSpace(recovered) != "" {
c.logTokenSource(account, "session_or_refresh_recovered")
return recovered, nil
}
if recoverErr != nil && c.debugEnabled() {
c.debugLogf("token_recover_failed account_id=%d platform=%s err=%s", account.ID, account.Platform, logredact.RedactText(recoverErr.Error()))
}
if providerErr != nil {
return "", providerErr
}
if c.tokenProvider != nil && !allowProvider {
c.logTokenSource(account, "account_credentials(provider_disabled)")
}
return "", errors.New("access_token not found")
}
func (c *SoraDirectClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" {
accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken)
if err == nil && strings.TrimSpace(accessToken) != "" {
c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken)
c.logTokenRecover(account, "session_token", reason, true, nil)
return accessToken, nil
}
c.logTokenRecover(account, "session_token", reason, false, err)
}
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
if refreshToken == "" {
return "", errors.New("session_token/refresh_token not found")
}
accessToken, newRefreshToken, expiresAt, err := c.exchangeRefreshToken(ctx, account, refreshToken)
if err != nil {
c.logTokenRecover(account, "refresh_token", reason, false, err)
return "", err
}
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("refreshed access_token is empty")
}
c.applyRecoveredToken(ctx, account, accessToken, newRefreshToken, expiresAt, "")
c.logTokenRecover(account, "refresh_token", reason, true, nil)
return accessToken, nil
}
func (c *SoraDirectClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) {
headers := http.Header{}
headers.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken)
headers.Set("Accept", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
headers.Set("User-Agent", c.defaultUserAgent())
body, _, err := c.doRequest(ctx, account, http.MethodGet, soraSessionAuthURL, headers, nil, false)
if err != nil {
return "", "", err
}
accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String())
if accessToken == "" {
return "", "", errors.New("session exchange missing accessToken")
}
expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String())
return accessToken, expiresAt, nil
}
func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Account, refreshToken string) (string, string, string, error) {
clientIDs := []string{
strings.TrimSpace(account.GetCredential("client_id")),
openaioauth.SoraClientID,
openaioauth.ClientID,
}
tried := make(map[string]struct{}, len(clientIDs))
var lastErr error
for _, clientID := range clientIDs {
if clientID == "" {
continue
}
if _, ok := tried[clientID]; ok {
continue
}
tried[clientID] = struct{}{}
formData := url.Values{}
formData.Set("client_id", clientID)
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("redirect_uri", "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback")
headers := http.Header{}
headers.Set("Accept", "application/json")
headers.Set("Content-Type", "application/x-www-form-urlencoded")
headers.Set("User-Agent", c.defaultUserAgent())
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, strings.NewReader(formData.Encode()), false)
if err != nil {
lastErr = err
if c.debugEnabled() {
c.debugLogf("refresh_token_exchange_failed account_id=%d client_id=%s err=%s", account.ID, clientID, logredact.RedactText(err.Error()))
}
continue
}
accessToken := strings.TrimSpace(gjson.GetBytes(respBody, "access_token").String())
if accessToken == "" {
lastErr = errors.New("oauth refresh response missing access_token")
continue
}
newRefreshToken := strings.TrimSpace(gjson.GetBytes(respBody, "refresh_token").String())
expiresIn := gjson.GetBytes(respBody, "expires_in").Int()
expiresAt := ""
if expiresIn > 0 {
expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339)
}
return accessToken, newRefreshToken, expiresAt, nil
}
if lastErr != nil {
return "", "", "", lastErr
}
return "", "", "", errors.New("no available client_id for refresh_token exchange")
}
func (c *SoraDirectClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) {
if account == nil {
return
}
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
if strings.TrimSpace(accessToken) != "" {
account.Credentials["access_token"] = accessToken
}
if strings.TrimSpace(refreshToken) != "" {
account.Credentials["refresh_token"] = refreshToken
}
if strings.TrimSpace(expiresAt) != "" {
account.Credentials["expires_at"] = expiresAt
}
if strings.TrimSpace(sessionToken) != "" {
account.Credentials["session_token"] = sessionToken
}
if c.accountRepo != nil {
if err := c.accountRepo.Update(ctx, account); err != nil {
if c.debugEnabled() {
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
}
}
}
c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken)
}
func (c *SoraDirectClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) {
if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 {
return
}
updates := make(map[string]any)
if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" {
updates["access_token"] = accessToken
updates["refresh_token"] = refreshToken
}
if strings.TrimSpace(sessionToken) != "" {
updates["session_token"] = sessionToken
}
if len(updates) == 0 {
return
}
if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() {
c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
}
}
func (c *SoraDirectClient) logTokenRecover(account *Account, source, reason string, success bool, err error) {
if !c.debugEnabled() || account == nil {
return
}
if success {
c.debugLogf("token_recover_success account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
return
}
if err == nil {
c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
return
}
c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s err=%s", account.ID, account.Platform, source, reason, logredact.RedactText(err.Error()))
}
func (c *SoraDirectClient) allowOpenAITokenProvider(account *Account) bool {
if c == nil || c.tokenProvider == nil {
return false
}
if account != nil && account.Platform == PlatformSora {
return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider
}
return true
}
func (c *SoraDirectClient) logTokenSource(account *Account, source string) {
if !c.debugEnabled() || account == nil {
return
}
c.debugLogf(
"token_selected account_id=%d platform=%s account_type=%s source=%s",
account.ID,
account.Platform,
account.Type,
source,
)
}
func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header {
headers := http.Header{}
if token != "" {
headers.Set("Authorization", "Bearer "+token)
}
if userAgent != "" {
headers.Set("User-Agent", userAgent)
}
if c != nil && c.cfg != nil {
for key, value := range c.cfg.Sora.Client.Headers {
if strings.EqualFold(key, "authorization") || strings.EqualFold(key, "openai-sentinel-token") {
continue
}
headers.Set(key, value)
}
}
return headers
}
func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, method, urlStr string, headers http.Header, body io.Reader, allowRetry bool) ([]byte, http.Header, error) {
return c.doRequestWithProxy(ctx, account, c.resolveProxyURL(account), method, urlStr, headers, body, allowRetry)
}
func (c *SoraDirectClient) doRequestWithProxy(
ctx context.Context,
account *Account,
proxyURL string,
method,
urlStr string,
headers http.Header,
body io.Reader,
allowRetry bool,
) ([]byte, http.Header, error) {
if strings.TrimSpace(urlStr) == "" {
return nil, nil, errors.New("empty upstream url")
}
proxyURL = strings.TrimSpace(proxyURL)
if proxyURL == "" {
proxyURL = c.resolveProxyURL(account)
}
if cooldownErr := c.checkCloudflareChallengeCooldown(account, proxyURL); cooldownErr != nil {
return nil, nil, cooldownErr
}
traceID, traceProxyKey, traceUAHash := c.requestTraceFields(ctx, proxyURL, headers.Get("User-Agent"))
timeout := 0
if c != nil && c.cfg != nil {
timeout = c.cfg.Sora.Client.TimeoutSeconds
}
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
defer cancel()
}
maxRetries := 0
if allowRetry && c != nil && c.cfg != nil {
maxRetries = c.cfg.Sora.Client.MaxRetries
}
if maxRetries < 0 {
maxRetries = 0
}
var bodyBytes []byte
if body != nil {
b, err := io.ReadAll(body)
if err != nil {
return nil, nil, err
}
bodyBytes = b
}
attempts := maxRetries + 1
authRecovered := false
authRecoverExtraAttemptGranted := false
challengeRetried := false
sawCFChallenge := false
var lastErr error
for attempt := 1; attempt <= attempts; attempt++ {
if c.debugEnabled() {
c.debugLogf(
"request_start trace_id=%s method=%s url=%s attempt=%d/%d timeout_s=%d body_bytes=%d proxy_bound=%t proxy_key=%s ua_hash=%s headers=%s",
traceID,
method,
sanitizeSoraLogURL(urlStr),
attempt,
attempts,
timeout,
len(bodyBytes),
proxyURL != "",
traceProxyKey,
traceUAHash,
formatSoraHeaders(headers),
)
}
var reader io.Reader
if bodyBytes != nil {
reader = bytes.NewReader(bodyBytes)
}
req, err := http.NewRequestWithContext(ctx, method, urlStr, reader)
if err != nil {
return nil, nil, err
}
req.Header = headers.Clone()
start := time.Now()
resp, err := c.doHTTP(req, proxyURL, account)
if err != nil {
lastErr = err
if c.debugEnabled() {
c.debugLogf(
"request_transport_error trace_id=%s method=%s url=%s attempt=%d/%d err=%s",
traceID,
method,
sanitizeSoraLogURL(urlStr),
attempt,
attempts,
logredact.RedactText(err.Error()),
)
}
if attempt < attempts && allowRetry {
if c.debugEnabled() {
c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=transport_error next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), attempt+1, attempts)
}
c.sleepRetry(attempt)
continue
}
return nil, nil, err
}
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if readErr != nil {
return nil, resp.Header, readErr
}
if c.cfg != nil && c.cfg.Sora.Client.Debug {
c.debugLogf(
"response_received trace_id=%s method=%s url=%s attempt=%d/%d status=%d cost=%s resp_bytes=%d resp_headers=%s",
traceID,
method,
sanitizeSoraLogURL(urlStr),
attempt,
attempts,
resp.StatusCode,
time.Since(start),
len(respBody),
formatSoraHeaders(resp.Header),
)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
isCFChallenge := soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, respBody)
if isCFChallenge {
sawCFChallenge = true
c.recordCloudflareChallengeCooldown(account, proxyURL, resp.StatusCode, resp.Header, respBody)
if allowRetry && attempt < attempts && !challengeRetried {
challengeRetried = true
if c.debugEnabled() {
c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=cloudflare_challenge status=%d next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts)
}
c.sleepRetry(attempt)
continue
}
}
if !isCFChallenge && !authRecovered && shouldAttemptSoraTokenRecover(resp.StatusCode, urlStr) && account != nil {
if recovered, recoverErr := c.recoverAccessToken(ctx, account, fmt.Sprintf("upstream_status_%d", resp.StatusCode)); recoverErr == nil && strings.TrimSpace(recovered) != "" {
headers.Set("Authorization", "Bearer "+recovered)
authRecovered = true
if attempt == attempts && !authRecoverExtraAttemptGranted {
attempts++
authRecoverExtraAttemptGranted = true
}
if c.debugEnabled() {
c.debugLogf("request_retry_with_recovered_token trace_id=%s method=%s url=%s status=%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode)
}
continue
} else if recoverErr != nil && c.debugEnabled() {
c.debugLogf("request_recover_token_failed trace_id=%s method=%s url=%s status=%d err=%s", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, logredact.RedactText(recoverErr.Error()))
}
}
if c.debugEnabled() {
c.debugLogf(
"response_non_success trace_id=%s method=%s url=%s attempt=%d/%d status=%d body=%s",
traceID,
method,
sanitizeSoraLogURL(urlStr),
attempt,
attempts,
resp.StatusCode,
summarizeSoraResponseBody(respBody, 512),
)
}
upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody, urlStr)
lastErr = upstreamErr
if isCFChallenge {
return nil, resp.Header, upstreamErr
}
if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) {
if c.debugEnabled() {
c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=status_%d next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts)
}
c.sleepRetry(attempt)
continue
}
return nil, resp.Header, upstreamErr
}
if sawCFChallenge {
c.clearCloudflareChallengeCooldown(account, proxyURL)
}
return respBody, resp.Header, nil
}
if lastErr != nil {
return nil, nil, lastErr
}
return nil, nil, errors.New("upstream retries exhausted")
}
func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool {
switch statusCode {
case http.StatusUnauthorized, http.StatusForbidden:
parsed, err := url.Parse(strings.TrimSpace(rawURL))
if err != nil {
return false
}
host := strings.ToLower(parsed.Hostname())
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
return false
}
// 避免在 ST->AT 转换接口上递归触发 token 恢复导致死循环。
path := strings.ToLower(strings.TrimSpace(parsed.Path))
if path == "/api/auth/session" {
return false
}
return true
default:
return false
}
}
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
if c != nil && c.cfg != nil && c.cfg.Sora.Client.CurlCFFISidecar.Enabled {
resp, err := c.doHTTPViaCurlCFFISidecar(req, proxyURL, account)
if err != nil {
return nil, err
}
return resp, nil
}
enableTLS := c == nil || c.cfg == nil || !c.cfg.Sora.Client.DisableTLSFingerprint
if c.httpUpstream != nil {
accountID := int64(0)
accountConcurrency := 0
if account != nil {
accountID = account.ID
accountConcurrency = account.Concurrency
}
return c.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS)
}
return http.DefaultClient.Do(req)
}
func (c *SoraDirectClient) sleepRetry(attempt int) {
backoff := time.Duration(attempt*attempt) * time.Second
if backoff > 10*time.Second {
backoff = 10 * time.Second
}
time.Sleep(backoff)
}
func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte, requestURL string) error {
msg := strings.TrimSpace(extractUpstreamErrorMessage(body))
msg = sanitizeUpstreamErrorMessage(msg)
if status == http.StatusNotFound && strings.Contains(strings.ToLower(msg), "not found") {
if hint := soraBaseURLNotFoundHint(requestURL); hint != "" {
msg = strings.TrimSpace(msg + " " + hint)
}
}
if msg == "" {
msg = truncateForLog(body, 256)
}
return &SoraUpstreamError{
StatusCode: status,
Message: msg,
Headers: headers,
Body: body,
}
}
func normalizeSoraBaseURL(raw string) string {
trimmed := strings.TrimRight(strings.TrimSpace(raw), "/")
if trimmed == "" {
return ""
}
parsed, err := url.Parse(trimmed)
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
return trimmed
}
host := strings.ToLower(parsed.Hostname())
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
return trimmed
}
pathVal := strings.TrimRight(strings.TrimSpace(parsed.Path), "/")
switch pathVal {
case "", "/":
parsed.Path = "/backend"
case "/backend-api":
parsed.Path = "/backend"
}
return strings.TrimRight(parsed.String(), "/")
}
func soraBaseURLNotFoundHint(requestURL string) string {
parsed, err := url.Parse(strings.TrimSpace(requestURL))
if err != nil || parsed.Host == "" {
return ""
}
host := strings.ToLower(parsed.Hostname())
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
return ""
}
pathVal := strings.TrimSpace(parsed.Path)
if strings.HasPrefix(pathVal, "/backend/") || pathVal == "/backend" {
return ""
}
return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)"
}
func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken, userAgent, proxyURL string) (string, error) {
reqID := uuid.NewString()
userAgent = strings.TrimSpace(userAgent)
if userAgent == "" {
userAgent = c.taskUserAgent()
}
powToken := soraPowTokenGenerator(userAgent)
payload := map[string]any{
"p": powToken,
"flow": soraSentinelFlow,
"id": reqID,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
headers := http.Header{}
headers.Set("Accept", "application/json, text/plain, */*")
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
headers.Set("User-Agent", userAgent)
if accessToken != "" {
headers.Set("Authorization", "Bearer "+accessToken)
}
urlStr := soraChatGPTBaseURL + "/backend-api/sentinel/req"
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, urlStr, headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
var resp map[string]any
if err := json.Unmarshal(respBody, &resp); err != nil {
return "", err
}
sentinel := soraBuildSentinelToken(soraSentinelFlow, reqID, powToken, resp, userAgent)
if sentinel == "" {
return "", errors.New("failed to build sentinel token")
}
return sentinel, nil
}
func soraGetPowToken(userAgent string) string {
configList := soraBuildPowConfig(userAgent)
seed := strconv.FormatFloat(soraRandFloat(), 'f', -1, 64)
difficulty := "0fffff"
solution, _ := soraSolvePow(seed, difficulty, configList)
return "gAAAAAC" + solution
}
func soraRandFloat() float64 {
soraRandMu.Lock()
defer soraRandMu.Unlock()
return soraRand.Float64()
}
func soraRandInt(max int) int {
if max <= 1 {
return 0
}
soraRandMu.Lock()
defer soraRandMu.Unlock()
return soraRand.Intn(max)
}
func soraBuildPowConfig(userAgent string) []any {
userAgent = strings.TrimSpace(userAgent)
if userAgent == "" && len(soraDesktopUserAgents) > 0 {
userAgent = soraDesktopUserAgents[0]
}
screenVal := soraStableChoiceInt([]int{
1920 + 1080,
2560 + 1440,
1920 + 1200,
2560 + 1600,
}, userAgent+"|screen")
perfMs := float64(time.Since(soraPerfStart).Milliseconds())
wallMs := float64(time.Now().UnixNano()) / 1e6
diff := wallMs - perfMs
return []any{
screenVal,
soraPowParseTime(),
4294705152,
0,
userAgent,
soraStableChoice(soraPowScripts, userAgent+"|script"),
soraStableChoice(soraPowDPL, userAgent+"|dpl"),
"en-US",
"en-US,es-US,en,es",
0,
soraStableChoice(soraPowNavigatorKeys, userAgent+"|navigator"),
soraStableChoice(soraPowDocumentKeys, userAgent+"|document"),
soraStableChoice(soraPowWindowKeys, userAgent+"|window"),
perfMs,
uuid.NewString(),
"",
soraStableChoiceInt(soraPowCores, userAgent+"|cores"),
diff,
}
}
func soraStableChoice(items []string, seed string) string {
if len(items) == 0 {
return ""
}
idx := soraStableIndex(seed, len(items))
return items[idx]
}
func soraStableChoiceInt(items []int, seed string) int {
if len(items) == 0 {
return 0
}
idx := soraStableIndex(seed, len(items))
return items[idx]
}
func soraStableIndex(seed string, size int) int {
if size <= 0 {
return 0
}
h := fnv.New32a()
_, _ = h.Write([]byte(seed))
return int(h.Sum32() % uint32(size))
}
func soraPowParseTime() string {
loc := time.FixedZone("EST", -5*3600)
return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (Eastern Standard Time)")
}
func soraSolvePow(seed, difficulty string, configList []any) (string, bool) {
diffLen := len(difficulty) / 2
target, err := hexDecodeString(difficulty)
if err != nil {
return "", false
}
seedBytes := []byte(seed)
part1 := mustMarshalJSON(configList[:3])
part2 := mustMarshalJSON(configList[4:9])
part3 := mustMarshalJSON(configList[10:])
staticPart1 := append(part1[:len(part1)-1], ',')
staticPart2 := append([]byte(","), append(part2[1:len(part2)-1], ',')...)
staticPart3 := append([]byte(","), part3[1:]...)
for i := 0; i < soraPowMaxIteration; i++ {
dynamicI := []byte(strconv.Itoa(i))
dynamicJ := []byte(strconv.Itoa(i >> 1))
finalJSON := make([]byte, 0, len(staticPart1)+len(dynamicI)+len(staticPart2)+len(dynamicJ)+len(staticPart3))
finalJSON = append(finalJSON, staticPart1...)
finalJSON = append(finalJSON, dynamicI...)
finalJSON = append(finalJSON, staticPart2...)
finalJSON = append(finalJSON, dynamicJ...)
finalJSON = append(finalJSON, staticPart3...)
b64 := base64.StdEncoding.EncodeToString(finalJSON)
hash := sha3.Sum512(append(seedBytes, []byte(b64)...))
if bytes.Compare(hash[:diffLen], target[:diffLen]) <= 0 {
return b64, true
}
}
errorToken := "wQ8Lk5FbGpA2NcR9dShT6gYjU7VxZ4D" + base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("\"%s\"", seed)))
return errorToken, false
}
func soraBuildSentinelToken(flow, reqID, powToken string, resp map[string]any, userAgent string) string {
finalPow := powToken
proof, _ := resp["proofofwork"].(map[string]any)
if required, _ := proof["required"].(bool); required {
seed, _ := proof["seed"].(string)
difficulty, _ := proof["difficulty"].(string)
if seed != "" && difficulty != "" {
configList := soraBuildPowConfig(userAgent)
solution, _ := soraSolvePow(seed, difficulty, configList)
finalPow = "gAAAAAB" + solution
}
}
if !strings.HasSuffix(finalPow, "~S") {
finalPow += "~S"
}
turnstile, _ := resp["turnstile"].(map[string]any)
tokenPayload := map[string]any{
"p": finalPow,
"t": safeMapString(turnstile, "dx"),
"c": safeString(resp["token"]),
"id": reqID,
"flow": flow,
}
encoded, _ := json.Marshal(tokenPayload)
return string(encoded)
}
func safeMapString(m map[string]any, key string) string {
if m == nil {
return ""
}
if v, ok := m[key]; ok {
return safeString(v)
}
return ""
}
func safeString(v any) string {
switch val := v.(type) {
case string:
return val
default:
return fmt.Sprintf("%v", val)
}
}
func mustMarshalJSON(v any) []byte {
b, _ := json.Marshal(v)
return b
}
func hexDecodeString(s string) ([]byte, error) {
dst := make([]byte, len(s)/2)
_, err := hex.Decode(dst, []byte(s))
return dst, err
}
func (c *SoraDirectClient) withRequestTrace(ctx context.Context, account *Account, proxyURL, userAgent string) context.Context {
if ctx == nil {
ctx = context.Background()
}
if existing, ok := ctx.Value(soraRequestTraceContextKey{}).(*soraRequestTrace); ok && existing != nil && existing.ID != "" {
return ctx
}
accountID := int64(0)
if account != nil {
accountID = account.ID
}
seed := fmt.Sprintf("%d|%s|%s|%d", accountID, normalizeSoraProxyKey(proxyURL), strings.TrimSpace(userAgent), time.Now().UnixNano())
trace := &soraRequestTrace{
ID: "sora-" + soraHashForLog(seed),
ProxyKey: normalizeSoraProxyKey(proxyURL),
UAHash: soraHashForLog(strings.TrimSpace(userAgent)),
}
return context.WithValue(ctx, soraRequestTraceContextKey{}, trace)
}
func (c *SoraDirectClient) requestTraceFields(ctx context.Context, proxyURL, userAgent string) (string, string, string) {
proxyKey := normalizeSoraProxyKey(proxyURL)
uaHash := soraHashForLog(strings.TrimSpace(userAgent))
traceID := ""
if ctx != nil {
if trace, ok := ctx.Value(soraRequestTraceContextKey{}).(*soraRequestTrace); ok && trace != nil {
if strings.TrimSpace(trace.ID) != "" {
traceID = strings.TrimSpace(trace.ID)
}
if strings.TrimSpace(trace.ProxyKey) != "" {
proxyKey = strings.TrimSpace(trace.ProxyKey)
}
if strings.TrimSpace(trace.UAHash) != "" {
uaHash = strings.TrimSpace(trace.UAHash)
}
}
}
if traceID == "" {
traceID = "sora-" + soraHashForLog(fmt.Sprintf("%s|%d", proxyKey, time.Now().UnixNano()))
}
return traceID, proxyKey, uaHash
}
func soraHashForLog(raw string) string {
h := fnv.New32a()
_, _ = h.Write([]byte(raw))
return fmt.Sprintf("%08x", h.Sum32())
}
func sanitizeSoraLogURL(raw string) string {
parsed, err := url.Parse(raw)
if err != nil {
return raw
}
q := parsed.Query()
q.Del("sig")
q.Del("expires")
parsed.RawQuery = q.Encode()
return parsed.String()
}
func (c *SoraDirectClient) debugEnabled() bool {
return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug
}
func (c *SoraDirectClient) debugLogf(format string, args ...any) {
if !c.debugEnabled() {
return
}
log.Printf("[SoraClient] "+format, args...)
}
func formatSoraHeaders(headers http.Header) string {
if len(headers) == 0 {
return "{}"
}
keys := make([]string, 0, len(headers))
for key := range headers {
keys = append(keys, key)
}
sort.Strings(keys)
out := make(map[string]string, len(keys))
for _, key := range keys {
values := headers.Values(key)
if len(values) == 0 {
continue
}
val := strings.Join(values, ",")
if isSensitiveHeader(key) {
out[key] = "***"
continue
}
out[key] = truncateForLog([]byte(logredact.RedactText(val)), 160)
}
encoded, err := json.Marshal(out)
if err != nil {
return "{}"
}
return string(encoded)
}
func isSensitiveHeader(key string) bool {
k := strings.ToLower(strings.TrimSpace(key))
switch k {
case "authorization", "openai-sentinel-token", "cookie", "set-cookie", "x-api-key":
return true
default:
return false
}
}
func summarizeSoraResponseBody(body []byte, maxLen int) string {
if len(body) == 0 {
return ""
}
var text string
if json.Valid(body) {
text = logredact.RedactJSON(body)
} else {
text = logredact.RedactText(string(body))
}
text = strings.TrimSpace(text)
if maxLen <= 0 || len(text) <= maxLen {
return text
}
return text[:maxLen] + "...(truncated)"
}
//go:build unit
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// ---------- 辅助解析函数(复制生产代码中的 gjson 解析逻辑,用于单元测试) ----------
// testParseUploadOrCreateTaskID 模拟 UploadImage / CreateImageTask / CreateVideoTask 中
// 用 gjson.GetBytes(respBody, "id") 提取 id 的逻辑。
func testParseUploadOrCreateTaskID(respBody []byte) (string, error) {
id := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if id == "" {
return "", assert.AnError // 占位错误,表示 "missing id"
}
return id, nil
}
// testParseFetchRecentImageTask 模拟 fetchRecentImageTask 中的 gjson.ForEach 解析逻辑。
func testParseFetchRecentImageTask(respBody []byte, taskID string) (*SoraImageTaskStatus, bool) {
var found *SoraImageTaskStatus
gjson.GetBytes(respBody, "task_responses").ForEach(func(_, item gjson.Result) bool {
if item.Get("id").String() != taskID {
return true // continue
}
status := strings.TrimSpace(item.Get("status").String())
progress := item.Get("progress_pct").Float()
var urls []string
item.Get("generations").ForEach(func(_, gen gjson.Result) bool {
if u := strings.TrimSpace(gen.Get("url").String()); u != "" {
urls = append(urls, u)
}
return true
})
found = &SoraImageTaskStatus{
ID: taskID,
Status: status,
ProgressPct: progress,
URLs: urls,
}
return false // break
})
if found != nil {
return found, true
}
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false
}
// testParseGetVideoTaskPending 模拟 GetVideoTask 中解析 pending 列表的逻辑。
func testParseGetVideoTaskPending(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) {
pendingResult := gjson.ParseBytes(respBody)
if !pendingResult.IsArray() {
return nil, false
}
var pendingFound *SoraVideoTaskStatus
pendingResult.ForEach(func(_, task gjson.Result) bool {
if task.Get("id").String() != taskID {
return true
}
progress := 0
if v := task.Get("progress_pct"); v.Exists() {
progress = int(v.Float() * 100)
}
status := strings.TrimSpace(task.Get("status").String())
pendingFound = &SoraVideoTaskStatus{
ID: taskID,
Status: status,
ProgressPct: progress,
}
return false
})
if pendingFound != nil {
return pendingFound, true
}
return nil, false
}
// testParseGetVideoTaskDrafts 模拟 GetVideoTask 中解析 drafts 列表的逻辑。
func testParseGetVideoTaskDrafts(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) {
var draftFound *SoraVideoTaskStatus
gjson.GetBytes(respBody, "items").ForEach(func(_, draft gjson.Result) bool {
if draft.Get("task_id").String() != taskID {
return true
}
kind := strings.TrimSpace(draft.Get("kind").String())
reason := strings.TrimSpace(draft.Get("reason_str").String())
if reason == "" {
reason = strings.TrimSpace(draft.Get("markdown_reason_str").String())
}
urlStr := strings.TrimSpace(draft.Get("downloadable_url").String())
if urlStr == "" {
urlStr = strings.TrimSpace(draft.Get("url").String())
}
if kind == "sora_content_violation" || reason != "" || urlStr == "" {
msg := reason
if msg == "" {
msg = "Content violates guardrails"
}
draftFound = &SoraVideoTaskStatus{
ID: taskID,
Status: "failed",
ErrorMsg: msg,
}
} else {
draftFound = &SoraVideoTaskStatus{
ID: taskID,
Status: "completed",
URLs: []string{urlStr},
}
}
return false
})
if draftFound != nil {
return draftFound, true
}
return nil, false
}
// ===================== Test 1: TestSoraParseUploadResponse =====================
func TestSoraParseUploadResponse(t *testing.T) {
tests := []struct {
name string
body string
wantID string
wantErr bool
}{
{
name: "正常 id",
body: `{"id":"file-abc123","status":"uploaded"}`,
wantID: "file-abc123",
},
{
name: "空 id",
body: `{"id":"","status":"uploaded"}`,
wantErr: true,
},
{
name: "无 id 字段",
body: `{"status":"uploaded"}`,
wantErr: true,
},
{
name: "id 全为空白",
body: `{"id":" ","status":"uploaded"}`,
wantErr: true,
},
{
name: "id 前后有空白",
body: `{"id":" file-trimmed ","status":"uploaded"}`,
wantID: "file-trimmed",
},
{
name: "空 JSON 对象",
body: `{}`,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
id, err := testParseUploadOrCreateTaskID([]byte(tt.body))
if tt.wantErr {
require.Error(t, err, "应返回错误")
return
}
require.NoError(t, err)
require.Equal(t, tt.wantID, id)
})
}
}
// ===================== Test 2: TestSoraParseCreateTaskResponse =====================
func TestSoraParseCreateTaskResponse(t *testing.T) {
tests := []struct {
name string
body string
wantID string
wantErr bool
}{
{
name: "正常任务 id",
body: `{"id":"task-123"}`,
wantID: "task-123",
},
{
name: "缺失 id",
body: `{"status":"created"}`,
wantErr: true,
},
{
name: "空 id",
body: `{"id":" "}`,
wantErr: true,
},
{
name: "id 为数字(gjson 转字符串)",
body: `{"id":123}`,
wantID: "123",
},
{
name: "id 含特殊字符",
body: `{"id":"task-abc-def-456-ghi"}`,
wantID: "task-abc-def-456-ghi",
},
{
name: "额外字段不影响解析",
body: `{"id":"task-999","type":"image_gen","extra":"data"}`,
wantID: "task-999",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
id, err := testParseUploadOrCreateTaskID([]byte(tt.body))
if tt.wantErr {
require.Error(t, err, "应返回错误")
return
}
require.NoError(t, err)
require.Equal(t, tt.wantID, id)
})
}
}
// ===================== Test 3: TestSoraParseFetchRecentImageTask =====================
func TestSoraParseFetchRecentImageTask(t *testing.T) {
tests := []struct {
name string
body string
taskID string
wantFound bool
wantStatus string
wantProgress float64
wantURLs []string
}{
{
name: "匹配已完成任务",
body: `{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1.0,"generations":[{"url":"https://example.com/img.png"}]}]}`,
taskID: "task-1",
wantFound: true,
wantStatus: "completed",
wantProgress: 1.0,
wantURLs: []string{"https://example.com/img.png"},
},
{
name: "匹配处理中任务",
body: `{"task_responses":[{"id":"task-2","status":"processing","progress_pct":0.5,"generations":[]}]}`,
taskID: "task-2",
wantFound: true,
wantStatus: "processing",
wantProgress: 0.5,
wantURLs: nil,
},
{
name: "无匹配任务",
body: `{"task_responses":[{"id":"other","status":"completed"}]}`,
taskID: "task-1",
wantFound: false,
wantStatus: "processing",
},
{
name: "空 task_responses",
body: `{"task_responses":[]}`,
taskID: "task-1",
wantFound: false,
wantStatus: "processing",
},
{
name: "缺少 task_responses 字段",
body: `{"other":"data"}`,
taskID: "task-1",
wantFound: false,
wantStatus: "processing",
},
{
name: "多个任务中精准匹配",
body: `{"task_responses":[{"id":"task-a","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"}]},{"id":"task-b","status":"processing","progress_pct":0.3,"generations":[]},{"id":"task-c","status":"failed","progress_pct":0}]}`,
taskID: "task-b",
wantFound: true,
wantStatus: "processing",
wantProgress: 0.3,
wantURLs: nil,
},
{
name: "多个 generations",
body: `{"task_responses":[{"id":"task-m","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"},{"url":"https://a.com/2.png"},{"url":""}]}]}`,
taskID: "task-m",
wantFound: true,
wantStatus: "completed",
wantProgress: 1.0,
wantURLs: []string{"https://a.com/1.png", "https://a.com/2.png"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
status, found := testParseFetchRecentImageTask([]byte(tt.body), tt.taskID)
require.Equal(t, tt.wantFound, found, "found 不匹配")
require.NotNil(t, status)
require.Equal(t, tt.taskID, status.ID)
require.Equal(t, tt.wantStatus, status.Status)
if tt.wantFound {
require.InDelta(t, tt.wantProgress, status.ProgressPct, 0.001, "进度不匹配")
require.Equal(t, tt.wantURLs, status.URLs)
}
})
}
}
// ===================== Test 4: TestSoraParseGetVideoTaskPending =====================
func TestSoraParseGetVideoTaskPending(t *testing.T) {
tests := []struct {
name string
body string
taskID string
wantFound bool
wantStatus string
wantProgress int
}{
{
name: "匹配 pending 任务",
body: `[{"id":"task-1","status":"processing","progress_pct":0.5}]`,
taskID: "task-1",
wantFound: true,
wantStatus: "processing",
wantProgress: 50,
},
{
name: "进度为 0",
body: `[{"id":"task-2","status":"queued","progress_pct":0}]`,
taskID: "task-2",
wantFound: true,
wantStatus: "queued",
wantProgress: 0,
},
{
name: "进度为 1(100%)",
body: `[{"id":"task-3","status":"completing","progress_pct":1.0}]`,
taskID: "task-3",
wantFound: true,
wantStatus: "completing",
wantProgress: 100,
},
{
name: "空数组",
body: `[]`,
taskID: "task-1",
wantFound: false,
},
{
name: "无匹配 id",
body: `[{"id":"task-other","status":"processing","progress_pct":0.3}]`,
taskID: "task-1",
wantFound: false,
},
{
name: "多个任务精准匹配",
body: `[{"id":"task-a","status":"processing","progress_pct":0.2},{"id":"task-b","status":"queued","progress_pct":0},{"id":"task-c","status":"processing","progress_pct":0.8}]`,
taskID: "task-c",
wantFound: true,
wantStatus: "processing",
wantProgress: 80,
},
{
name: "非数组 JSON",
body: `{"id":"task-1","status":"processing"}`,
taskID: "task-1",
wantFound: false,
},
{
name: "无 progress_pct 字段",
body: `[{"id":"task-4","status":"pending"}]`,
taskID: "task-4",
wantFound: true,
wantStatus: "pending",
wantProgress: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
status, found := testParseGetVideoTaskPending([]byte(tt.body), tt.taskID)
require.Equal(t, tt.wantFound, found, "found 不匹配")
if tt.wantFound {
require.NotNil(t, status)
require.Equal(t, tt.taskID, status.ID)
require.Equal(t, tt.wantStatus, status.Status)
require.Equal(t, tt.wantProgress, status.ProgressPct)
}
})
}
}
// ===================== Test 5: TestSoraParseGetVideoTaskDrafts =====================
func TestSoraParseGetVideoTaskDrafts(t *testing.T) {
tests := []struct {
name string
body string
taskID string
wantFound bool
wantStatus string
wantURLs []string
wantErr string
}{
{
name: "正常完成的视频",
body: `{"items":[{"task_id":"task-1","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`,
taskID: "task-1",
wantFound: true,
wantStatus: "completed",
wantURLs: []string{"https://example.com/video.mp4"},
},
{
name: "使用 url 字段回退",
body: `{"items":[{"task_id":"task-2","kind":"video","url":"https://example.com/fallback.mp4"}]}`,
taskID: "task-2",
wantFound: true,
wantStatus: "completed",
wantURLs: []string{"https://example.com/fallback.mp4"},
},
{
name: "内容违规",
body: `{"items":[{"task_id":"task-3","kind":"sora_content_violation","reason_str":"Content policy violation"}]}`,
taskID: "task-3",
wantFound: true,
wantStatus: "failed",
wantErr: "Content policy violation",
},
{
name: "内容违规 - markdown_reason_str 回退",
body: `{"items":[{"task_id":"task-4","kind":"sora_content_violation","markdown_reason_str":"Markdown reason"}]}`,
taskID: "task-4",
wantFound: true,
wantStatus: "failed",
wantErr: "Markdown reason",
},
{
name: "内容违规 - 无 reason 使用默认消息",
body: `{"items":[{"task_id":"task-5","kind":"sora_content_violation"}]}`,
taskID: "task-5",
wantFound: true,
wantStatus: "failed",
wantErr: "Content violates guardrails",
},
{
name: "有 reason_str 但非 violation kind(仍判定失败)",
body: `{"items":[{"task_id":"task-6","kind":"video","reason_str":"Some error occurred"}]}`,
taskID: "task-6",
wantFound: true,
wantStatus: "failed",
wantErr: "Some error occurred",
},
{
name: "空 URL 判定为失败",
body: `{"items":[{"task_id":"task-7","kind":"video","downloadable_url":"","url":""}]}`,
taskID: "task-7",
wantFound: true,
wantStatus: "failed",
wantErr: "Content violates guardrails",
},
{
name: "无匹配 task_id",
body: `{"items":[{"task_id":"task-other","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`,
taskID: "task-1",
wantFound: false,
},
{
name: "空 items",
body: `{"items":[]}`,
taskID: "task-1",
wantFound: false,
},
{
name: "缺少 items 字段",
body: `{"other":"data"}`,
taskID: "task-1",
wantFound: false,
},
{
name: "多个 items 精准匹配",
body: `{"items":[{"task_id":"task-a","kind":"video","downloadable_url":"https://a.com/a.mp4"},{"task_id":"task-b","kind":"sora_content_violation","reason_str":"Bad content"},{"task_id":"task-c","kind":"video","downloadable_url":"https://c.com/c.mp4"}]}`,
taskID: "task-b",
wantFound: true,
wantStatus: "failed",
wantErr: "Bad content",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
status, found := testParseGetVideoTaskDrafts([]byte(tt.body), tt.taskID)
require.Equal(t, tt.wantFound, found, "found 不匹配")
if !tt.wantFound {
return
}
require.NotNil(t, status)
require.Equal(t, tt.taskID, status.ID)
require.Equal(t, tt.wantStatus, status.Status)
if tt.wantErr != "" {
require.Equal(t, tt.wantErr, status.ErrorMsg)
}
if tt.wantURLs != nil {
require.Equal(t, tt.wantURLs, status.URLs)
}
})
}
}
//go:build unit
package service
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestSoraDirectClient_DoRequestSuccess(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{BaseURL: server.URL},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
body, _, err := client.doRequest(context.Background(), &Account{ID: 1}, http.MethodGet, server.URL, http.Header{}, nil, false)
require.NoError(t, err)
require.Contains(t, string(body), "ok")
}
func TestSoraDirectClient_BuildBaseHeaders(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
Headers: map[string]string{
"X-Test": "yes",
"Authorization": "should-ignore",
"openai-sentinel-token": "skip",
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
headers := client.buildBaseHeaders("token-123", "UA")
require.Equal(t, "Bearer token-123", headers.Get("Authorization"))
require.Equal(t, "UA", headers.Get("User-Agent"))
require.Equal(t, "yes", headers.Get("X-Test"))
require.Empty(t, headers.Get("openai-sentinel-token"))
}
func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
limit := r.URL.Query().Get("limit")
w.Header().Set("Content-Type", "application/json")
switch limit {
case "1":
_, _ = w.Write([]byte(`{"task_responses":[]}`))
case "2":
_, _ = w.Write([]byte(`{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1,"generations":[{"url":"https://example.com/a.png"}]}]}`))
default:
_, _ = w.Write([]byte(`{"task_responses":[]}`))
}
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: server.URL,
RecentTaskLimit: 1,
RecentTaskLimitMax: 2,
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{Credentials: map[string]any{"access_token": "token"}}
status, err := client.GetImageTask(context.Background(), account, "task-1")
require.NoError(t, err)
require.Equal(t, "completed", status.Status)
require.Equal(t, []string{"https://example.com/a.png"}, status.URLs)
}
func TestNormalizeSoraBaseURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
raw string
want string
}{
{
name: "empty",
raw: "",
want: "",
},
{
name: "append_backend_for_sora_host",
raw: "https://sora.chatgpt.com",
want: "https://sora.chatgpt.com/backend",
},
{
name: "convert_backend_api_to_backend",
raw: "https://sora.chatgpt.com/backend-api",
want: "https://sora.chatgpt.com/backend",
},
{
name: "keep_backend",
raw: "https://sora.chatgpt.com/backend",
want: "https://sora.chatgpt.com/backend",
},
{
name: "keep_custom_host",
raw: "https://example.com/custom-path",
want: "https://example.com/custom-path",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeSoraBaseURL(tt.raw)
require.Equal(t, tt.want, got)
})
}
}
func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) {
t.Parallel()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com",
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", client.buildURL("/video_gen"))
}
func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) {
t.Parallel()
client := NewSoraDirectClient(&config.Config{}, nil, nil)
err := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/video_gen")
var upstreamErr *SoraUpstreamError
require.ErrorAs(t, err, &upstreamErr)
require.Contains(t, upstreamErr.Message, "请检查 sora.client.base_url")
errNoHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/backend/video_gen")
require.ErrorAs(t, errNoHint, &upstreamErr)
require.NotContains(t, upstreamErr.Message, "请检查 sora.client.base_url")
}
func TestFormatSoraHeaders_RedactsSensitive(t *testing.T) {
t.Parallel()
headers := http.Header{}
headers.Set("Authorization", "Bearer secret-token")
headers.Set("openai-sentinel-token", "sentinel-secret")
headers.Set("X-Test", "ok")
out := formatSoraHeaders(headers)
require.Contains(t, out, `"Authorization":"***"`)
require.Contains(t, out, `Sentinel-Token":"***"`)
require.Contains(t, out, `"X-Test":"ok"`)
require.NotContains(t, out, "secret-token")
require.NotContains(t, out, "sentinel-secret")
}
func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) {
t.Parallel()
body := []byte(`{"error":{"message":"bad"},"access_token":"abc123"}`)
out := summarizeSoraResponseBody(body, 512)
require.Contains(t, out, `"access_token":"***"`)
require.NotContains(t, out, "abc123")
}
func TestSummarizeSoraResponseBody_Truncates(t *testing.T) {
t.Parallel()
body := []byte(strings.Repeat("x", 100))
out := summarizeSoraResponseBody(body, 10)
require.Contains(t, out, "(truncated)")
}
func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) {
t.Parallel()
cache := newOpenAITokenCacheStub()
provider := NewOpenAITokenProvider(nil, cache, nil)
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, nil, provider)
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "sora-credential-token",
},
}
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "sora-credential-token", token)
require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalled))
}
func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) {
t.Parallel()
cache := newOpenAITokenCacheStub()
account := &Account{
ID: 2,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "sora-credential-token",
},
}
cache.tokens[OpenAITokenCacheKey(account)] = "provider-token"
provider := NewOpenAITokenProvider(nil, cache, nil)
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
UseOpenAITokenProvider: true,
},
},
}
client := NewSoraDirectClient(cfg, nil, provider)
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "provider-token", token)
require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0))
}
func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"accessToken": "session-access-token",
"expires": "2099-01-01T00:00:00Z",
})
}))
defer server.Close()
origin := soraSessionAuthURL
soraSessionAuthURL = server.URL
defer func() { soraSessionAuthURL = origin }()
client := NewSoraDirectClient(&config.Config{}, nil, nil)
account := &Account{
ID: 10,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"session_token": "session-token",
},
}
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "session-access-token", token)
require.Equal(t, "session-access-token", account.GetCredential("access_token"))
}
func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.Equal(t, "/oauth/token", r.URL.Path)
require.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
require.NoError(t, r.ParseForm())
require.Equal(t, "refresh_token", r.FormValue("grant_type"))
require.Equal(t, "refresh-token-old", r.FormValue("refresh_token"))
require.NotEmpty(t, r.FormValue("client_id"))
require.Equal(t, "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", r.FormValue("redirect_uri"))
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "refresh-access-token",
"refresh_token": "refresh-token-new",
"expires_in": 3600,
})
}))
defer server.Close()
origin := soraOAuthTokenURL
soraOAuthTokenURL = server.URL + "/oauth/token"
defer func() { soraOAuthTokenURL = origin }()
client := NewSoraDirectClient(&config.Config{}, nil, nil)
account := &Account{
ID: 11,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"refresh_token": "refresh-token-old",
},
}
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refresh-access-token", token)
require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token"))
require.NotNil(t, account.GetCredentialAsTime("expires_at"))
}
func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Equal(t, "/nf/check", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"rate_limit_and_credit_balance": map[string]any{
"estimated_num_videos_remaining": 0,
"rate_limit_reached": true,
},
})
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: server.URL,
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{
ID: 12,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "ok",
"expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339),
},
}
err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", SoraModelConfig{Type: "video"})
require.Error(t, err)
var upstreamErr *SoraUpstreamError
require.ErrorAs(t, err, &upstreamErr)
require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode)
}
func TestShouldAttemptSoraTokenRecover(t *testing.T) {
t.Parallel()
require.True(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen"))
require.True(t, shouldAttemptSoraTokenRecover(http.StatusForbidden, "https://chatgpt.com/backend/video_gen"))
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session"))
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token"))
require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen"))
}
type soraClientRequestCall struct {
Path string
UserAgent string
ProxyURL string
}
type soraClientRecordingUpstream struct {
calls []soraClientRequestCall
}
func (u *soraClientRecordingUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return nil, errors.New("unexpected Do call")
}
func (u *soraClientRecordingUpstream) DoWithTLS(req *http.Request, proxyURL string, _ int64, _ int, _ bool) (*http.Response, error) {
u.calls = append(u.calls, soraClientRequestCall{
Path: req.URL.Path,
UserAgent: req.Header.Get("User-Agent"),
ProxyURL: proxyURL,
})
switch req.URL.Path {
case "/backend-api/sentinel/req":
return newSoraClientMockResponse(http.StatusOK, `{"token":"sentinel-token","turnstile":{"dx":"ok"}}`), nil
case "/backend/nf/create":
return newSoraClientMockResponse(http.StatusOK, `{"id":"task-123"}`), nil
case "/backend/nf/create/storyboard":
return newSoraClientMockResponse(http.StatusOK, `{"id":"storyboard-123"}`), nil
case "/backend/uploads":
return newSoraClientMockResponse(http.StatusOK, `{"id":"upload-123"}`), nil
case "/backend/nf/check":
return newSoraClientMockResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":1,"rate_limit_reached":false}}`), nil
case "/backend/characters/upload":
return newSoraClientMockResponse(http.StatusOK, `{"id":"cameo-123"}`), nil
case "/backend/project_y/cameos/in_progress/cameo-123":
return newSoraClientMockResponse(http.StatusOK, `{"status":"finalized","status_message":"Completed","username_hint":"foo.bar","display_name_hint":"Bar","profile_asset_url":"https://example.com/avatar.webp"}`), nil
case "/backend/project_y/file/upload":
return newSoraClientMockResponse(http.StatusOK, `{"asset_pointer":"asset-123"}`), nil
case "/backend/characters/finalize":
return newSoraClientMockResponse(http.StatusOK, `{"character":{"character_id":"character-123"}}`), nil
case "/backend/project_y/post":
return newSoraClientMockResponse(http.StatusOK, `{"post":{"id":"s_post"}}`), nil
default:
return newSoraClientMockResponse(http.StatusOK, `{"ok":true}`), nil
}
}
func newSoraClientMockResponse(statusCode int, body string) *http.Response {
return &http.Response{
StatusCode: statusCode,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
func TestSoraDirectClient_TaskUserAgent_DefaultMobileFallback(t *testing.T) {
client := NewSoraDirectClient(&config.Config{}, nil, nil)
ua := client.taskUserAgent()
require.NotEmpty(t, ua)
allowed := append([]string{}, soraMobileUserAgents...)
allowed = append(allowed, soraDesktopUserAgents...)
require.Contains(t, allowed, ua)
}
func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAndCreate(t *testing.T) {
originPowTokenGenerator := soraPowTokenGenerator
soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
defer func() {
soraPowTokenGenerator = originPowTokenGenerator
}()
upstream := &soraClientRecordingUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
proxyID := int64(9)
account := &Account{
ID: 21,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
ProxyID: &proxyID,
Proxy: &Proxy{
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
},
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
taskID, err := client.CreateVideoTask(context.Background(), account, SoraVideoRequest{Prompt: "test"})
require.NoError(t, err)
require.Equal(t, "task-123", taskID)
require.Len(t, upstream.calls, 2)
sentinelCall := upstream.calls[0]
createCall := upstream.calls[1]
require.Equal(t, "/backend-api/sentinel/req", sentinelCall.Path)
require.Equal(t, "/backend/nf/create", createCall.Path)
require.Equal(t, "http://127.0.0.1:8080", sentinelCall.ProxyURL)
require.Equal(t, sentinelCall.ProxyURL, createCall.ProxyURL)
require.NotEmpty(t, sentinelCall.UserAgent)
require.Equal(t, sentinelCall.UserAgent, createCall.UserAgent)
}
func TestSoraDirectClient_UploadImage_UsesTaskUserAgentAndProxy(t *testing.T) {
upstream := &soraClientRecordingUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
proxyID := int64(3)
account := &Account{
ID: 31,
ProxyID: &proxyID,
Proxy: &Proxy{
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
},
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
uploadID, err := client.UploadImage(context.Background(), account, []byte("mock-image"), "a.png")
require.NoError(t, err)
require.Equal(t, "upload-123", uploadID)
require.Len(t, upstream.calls, 1)
require.Equal(t, "/backend/uploads", upstream.calls[0].Path)
require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
require.NotEmpty(t, upstream.calls[0].UserAgent)
}
func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T) {
upstream := &soraClientRecordingUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
proxyID := int64(7)
account := &Account{
ID: 41,
ProxyID: &proxyID,
Proxy: &Proxy{
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
},
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
err := client.PreflightCheck(context.Background(), account, "sora2", SoraModelConfig{Type: "video"})
require.NoError(t, err)
require.Len(t, upstream.calls, 1)
require.Equal(t, "/backend/nf/check", upstream.calls[0].Path)
require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
require.NotEmpty(t, upstream.calls[0].UserAgent)
}
func TestSoraDirectClient_CreateStoryboardTask(t *testing.T) {
originPowTokenGenerator := soraPowTokenGenerator
soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
defer func() { soraPowTokenGenerator = originPowTokenGenerator }()
upstream := &soraClientRecordingUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
account := &Account{
ID: 51,
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
taskID, err := client.CreateStoryboardTask(context.Background(), account, SoraStoryboardRequest{
Prompt: "Shot 1:\nduration: 5sec\nScene: cat",
})
require.NoError(t, err)
require.Equal(t, "storyboard-123", taskID)
require.Len(t, upstream.calls, 2)
require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path)
require.Equal(t, "/backend/nf/create/storyboard", upstream.calls[1].Path)
}
func TestSoraDirectClient_GetVideoTask_ReturnsGenerationID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/nf/pending/v2":
_, _ = w.Write([]byte(`[]`))
case "/project_y/profile/drafts":
_, _ = w.Write([]byte(`{"items":[{"id":"gen_1","task_id":"task-1","kind":"video","downloadable_url":"https://example.com/v.mp4"}]}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: server.URL,
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{Credentials: map[string]any{"access_token": "token"}}
status, err := client.GetVideoTask(context.Background(), account, "task-1")
require.NoError(t, err)
require.Equal(t, "completed", status.Status)
require.Equal(t, "gen_1", status.GenerationID)
require.Equal(t, []string{"https://example.com/v.mp4"}, status.URLs)
}
func TestSoraDirectClient_PostVideoForWatermarkFree(t *testing.T) {
originPowTokenGenerator := soraPowTokenGenerator
soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
defer func() { soraPowTokenGenerator = originPowTokenGenerator }()
upstream := &soraClientRecordingUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
account := &Account{
ID: 52,
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
postID, err := client.PostVideoForWatermarkFree(context.Background(), account, "gen_1")
require.NoError(t, err)
require.Equal(t, "s_post", postID)
require.Len(t, upstream.calls, 2)
require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path)
require.Equal(t, "/backend/project_y/post", upstream.calls[1].Path)
}
type soraClientFallbackUpstream struct {
doWithTLSCalls int32
respBody string
respStatusCode int
err error
}
func (u *soraClientFallbackUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return nil, errors.New("unexpected Do call")
}
func (u *soraClientFallbackUpstream) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
atomic.AddInt32(&u.doWithTLSCalls, 1)
if u.err != nil {
return nil, u.err
}
statusCode := u.respStatusCode
if statusCode <= 0 {
statusCode = http.StatusOK
}
body := u.respBody
if body == "" {
body = `{"ok":true}`
}
return newSoraClientMockResponse(statusCode, body), nil
}
func TestSoraDirectClient_DoHTTP_UsesCurlCFFISidecarWhenEnabled(t *testing.T) {
var captured soraCurlCFFISidecarRequest
sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.Equal(t, "/request", r.URL.Path)
raw, err := io.ReadAll(r.Body)
require.NoError(t, err)
require.NoError(t, json.Unmarshal(raw, &captured))
_ = json.NewEncoder(w).Encode(map[string]any{
"status_code": http.StatusOK,
"headers": map[string]any{
"Content-Type": "application/json",
"X-Sidecar": []string{"yes"},
},
"body_base64": base64.StdEncoding.EncodeToString([]byte(`{"ok":true}`)),
})
}))
defer sidecar.Close()
upstream := &soraClientFallbackUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
BaseURL: sidecar.URL,
Impersonate: "chrome131",
TimeoutSeconds: 15,
SessionReuseEnabled: true,
},
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
req, err := http.NewRequest(http.MethodPost, "https://sora.chatgpt.com/backend/me", strings.NewReader("hello-sidecar"))
require.NoError(t, err)
req.Header.Set("User-Agent", "test-ua")
resp, err := client.doHTTP(req, "http://127.0.0.1:18080", &Account{ID: 1})
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.JSONEq(t, `{"ok":true}`, string(body))
require.Equal(t, int32(0), atomic.LoadInt32(&upstream.doWithTLSCalls))
require.Equal(t, "http://127.0.0.1:18080", captured.ProxyURL)
require.NotEmpty(t, captured.SessionKey)
require.Equal(t, "chrome131", captured.Impersonate)
require.Equal(t, "https://sora.chatgpt.com/backend/me", captured.URL)
decodedReqBody, err := base64.StdEncoding.DecodeString(captured.BodyBase64)
require.NoError(t, err)
require.Equal(t, "hello-sidecar", string(decodedReqBody))
}
func TestSoraDirectClient_DoHTTP_CurlCFFISidecarFailureReturnsError(t *testing.T) {
sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadGateway)
_, _ = w.Write([]byte(`{"error":"boom"}`))
}))
defer sidecar.Close()
upstream := &soraClientFallbackUpstream{respBody: `{"fallback":true}`}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
BaseURL: sidecar.URL,
},
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
req, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil)
require.NoError(t, err)
_, err = client.doHTTP(req, "", &Account{ID: 2})
require.Error(t, err)
require.Contains(t, err.Error(), "sora curl_cffi sidecar")
require.Equal(t, int32(0), atomic.LoadInt32(&upstream.doWithTLSCalls))
}
func TestSoraDirectClient_DoHTTP_CurlCFFISidecarDisabledUsesLegacyStack(t *testing.T) {
upstream := &soraClientFallbackUpstream{respBody: `{"legacy":true}`}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: false,
BaseURL: "http://127.0.0.1:18080",
},
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
req, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil)
require.NoError(t, err)
resp, err := client.doHTTP(req, "", &Account{ID: 3})
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.JSONEq(t, `{"legacy":true}`, string(body))
require.Equal(t, int32(1), atomic.LoadInt32(&upstream.doWithTLSCalls))
}
func TestConvertSidecarHeaderValue_NilAndSlice(t *testing.T) {
require.Nil(t, convertSidecarHeaderValue(nil))
require.Equal(t, []string{"a", "b"}, convertSidecarHeaderValue([]any{"a", " ", "b"}))
}
func TestSoraDirectClient_DoHTTP_SidecarSessionKeyStableForSameAccountProxy(t *testing.T) {
var captured []soraCurlCFFISidecarRequest
sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, err := io.ReadAll(r.Body)
require.NoError(t, err)
var reqPayload soraCurlCFFISidecarRequest
require.NoError(t, json.Unmarshal(raw, &reqPayload))
captured = append(captured, reqPayload)
_ = json.NewEncoder(w).Encode(map[string]any{
"status_code": http.StatusOK,
"headers": map[string]any{
"Content-Type": "application/json",
},
"body": `{"ok":true}`,
})
}))
defer sidecar.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
BaseURL: sidecar.URL,
SessionReuseEnabled: true,
SessionTTLSeconds: 3600,
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{ID: 1001}
req1, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil)
require.NoError(t, err)
_, err = client.doHTTP(req1, "http://127.0.0.1:18080", account)
require.NoError(t, err)
req2, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil)
require.NoError(t, err)
_, err = client.doHTTP(req2, "http://127.0.0.1:18080", account)
require.NoError(t, err)
require.Len(t, captured, 2)
require.NotEmpty(t, captured[0].SessionKey)
require.Equal(t, captured[0].SessionKey, captured[1].SessionKey)
}
func TestSoraDirectClient_DoRequestWithProxy_CloudflareChallengeSetsCooldownAfterSingleRetry(t *testing.T) {
var sidecarCalls int32
sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&sidecarCalls, 1)
_ = json.NewEncoder(w).Encode(map[string]any{
"status_code": http.StatusForbidden,
"headers": map[string]any{
"cf-ray": "9d05d73dec4d8c8e-GRU",
"content-type": "text/html",
},
"body": `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`,
})
}))
defer sidecar.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
MaxRetries: 3,
CloudflareChallengeCooldownSeconds: 60,
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
BaseURL: sidecar.URL,
Impersonate: "chrome131",
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
headers := http.Header{}
_, _, err := client.doRequestWithProxy(
context.Background(),
&Account{ID: 99},
"http://127.0.0.1:18080",
http.MethodGet,
"https://sora.chatgpt.com/backend/me",
headers,
nil,
true,
)
require.Error(t, err)
var upstreamErr *SoraUpstreamError
require.ErrorAs(t, err, &upstreamErr)
require.Equal(t, http.StatusForbidden, upstreamErr.StatusCode)
require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls), "challenge should trigger exactly one same-proxy retry")
_, _, err = client.doRequestWithProxy(
context.Background(),
&Account{ID: 99},
"http://127.0.0.1:18080",
http.MethodGet,
"https://sora.chatgpt.com/backend/me",
headers,
nil,
true,
)
require.Error(t, err)
require.ErrorAs(t, err, &upstreamErr)
require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode)
require.Contains(t, upstreamErr.Message, "cooling down")
require.Contains(t, upstreamErr.Message, "cf-ray")
require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls), "cooldown should block outbound request")
}
func TestSoraDirectClient_DoRequestWithProxy_CloudflareRetrySuccessClearsCooldown(t *testing.T) {
var sidecarCalls int32
sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
call := atomic.AddInt32(&sidecarCalls, 1)
if call == 1 {
_ = json.NewEncoder(w).Encode(map[string]any{
"status_code": http.StatusForbidden,
"headers": map[string]any{
"cf-ray": "9d05d73dec4d8c8e-GRU",
"content-type": "text/html",
},
"body": `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`,
})
return
}
_ = json.NewEncoder(w).Encode(map[string]any{
"status_code": http.StatusOK,
"headers": map[string]any{
"content-type": "application/json",
},
"body": `{"ok":true}`,
})
}))
defer sidecar.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
MaxRetries: 3,
CloudflareChallengeCooldownSeconds: 60,
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
BaseURL: sidecar.URL,
Impersonate: "chrome131",
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
headers := http.Header{}
account := &Account{ID: 109}
proxyURL := "http://127.0.0.1:18080"
body, _, err := client.doRequestWithProxy(
context.Background(),
account,
proxyURL,
http.MethodGet,
"https://sora.chatgpt.com/backend/me",
headers,
nil,
true,
)
require.NoError(t, err)
require.Contains(t, string(body), `"ok":true`)
require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls))
_, _, err = client.doRequestWithProxy(
context.Background(),
account,
proxyURL,
http.MethodGet,
"https://sora.chatgpt.com/backend/me",
headers,
nil,
true,
)
require.NoError(t, err)
require.Equal(t, int32(3), atomic.LoadInt32(&sidecarCalls), "cooldown should be cleared after retry succeeds")
}
func TestSoraComputeChallengeCooldownSeconds(t *testing.T) {
require.Equal(t, 0, soraComputeChallengeCooldownSeconds(0, 3))
require.Equal(t, 10, soraComputeChallengeCooldownSeconds(10, 1))
require.Equal(t, 20, soraComputeChallengeCooldownSeconds(10, 2))
require.Equal(t, 40, soraComputeChallengeCooldownSeconds(10, 4))
require.Equal(t, 40, soraComputeChallengeCooldownSeconds(10, 9), "streak should cap at x4")
require.Equal(t, 3600, soraComputeChallengeCooldownSeconds(1200, 9), "cooldown should cap at 3600s")
}
func TestSoraDirectClient_RecordCloudflareChallengeCooldown_EscalatesStreak(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
CloudflareChallengeCooldownSeconds: 10,
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{ID: 201}
proxyURL := "http://127.0.0.1:18080"
client.recordCloudflareChallengeCooldown(account, proxyURL, http.StatusForbidden, http.Header{"Cf-Ray": []string{"9d05d73dec4d8c8e-GRU"}}, nil)
client.recordCloudflareChallengeCooldown(account, proxyURL, http.StatusForbidden, http.Header{"Cf-Ray": []string{"9d05d73dec4d8c8f-GRU"}}, nil)
key := soraAccountProxyKey(account, proxyURL)
entry, ok := client.challengeCooldowns[key]
require.True(t, ok)
require.Equal(t, 2, entry.ConsecutiveChallenges)
require.Equal(t, "9d05d73dec4d8c8f-GRU", entry.CFRay)
remain := int(entry.Until.Sub(entry.LastChallengeAt).Seconds())
require.GreaterOrEqual(t, remain, 19)
}
func TestSoraDirectClient_SidecarSessionKey_SkipsWhenAccountMissing(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
SessionReuseEnabled: true,
SessionTTLSeconds: 3600,
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
require.Equal(t, "", client.sidecarSessionKey(nil, "http://127.0.0.1:18080"))
require.Empty(t, client.sidecarSessions)
}
func TestSoraDirectClient_SidecarSessionKey_PrunesExpiredAndRecreates(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
SessionReuseEnabled: true,
SessionTTLSeconds: 3600,
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{ID: 123}
key := soraAccountProxyKey(account, "http://127.0.0.1:18080")
client.sidecarSessions[key] = soraSidecarSessionEntry{
SessionKey: "sora-expired",
ExpiresAt: time.Now().Add(-time.Minute),
LastUsedAt: time.Now().Add(-2 * time.Minute),
}
sessionKey := client.sidecarSessionKey(account, "http://127.0.0.1:18080")
require.NotEmpty(t, sessionKey)
require.NotEqual(t, "sora-expired", sessionKey)
require.Len(t, client.sidecarSessions, 1)
}
func TestSoraDirectClient_SidecarSessionKey_TTLZeroKeepsLongLivedSession(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
Enabled: true,
SessionReuseEnabled: true,
SessionTTLSeconds: 0,
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{ID: 456}
first := client.sidecarSessionKey(account, "http://127.0.0.1:18080")
second := client.sidecarSessionKey(account, "http://127.0.0.1:18080")
require.NotEmpty(t, first)
require.Equal(t, first, second)
key := soraAccountProxyKey(account, "http://127.0.0.1:18080")
entry, ok := client.sidecarSessions[key]
require.True(t, ok)
require.True(t, entry.ExpiresAt.After(time.Now().Add(300*24*time.Hour)))
}
package service
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
)
const soraCurlCFFISidecarDefaultTimeoutSeconds = 60
type soraCurlCFFISidecarRequest struct {
Method string `json:"method"`
URL string `json:"url"`
Headers map[string][]string `json:"headers,omitempty"`
BodyBase64 string `json:"body_base64,omitempty"`
ProxyURL string `json:"proxy_url,omitempty"`
SessionKey string `json:"session_key,omitempty"`
Impersonate string `json:"impersonate,omitempty"`
TimeoutSeconds int `json:"timeout_seconds,omitempty"`
}
type soraCurlCFFISidecarResponse struct {
StatusCode int `json:"status_code"`
Status int `json:"status"`
Headers map[string]any `json:"headers"`
BodyBase64 string `json:"body_base64"`
Body string `json:"body"`
Error string `json:"error"`
}
func (c *SoraDirectClient) doHTTPViaCurlCFFISidecar(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
if req == nil || req.URL == nil {
return nil, errors.New("request url is nil")
}
if c == nil || c.cfg == nil {
return nil, errors.New("sora curl_cffi sidecar config is nil")
}
if !c.cfg.Sora.Client.CurlCFFISidecar.Enabled {
return nil, errors.New("sora curl_cffi sidecar is disabled")
}
endpoint := c.curlCFFISidecarEndpoint()
if endpoint == "" {
return nil, errors.New("sora curl_cffi sidecar base_url is empty")
}
bodyBytes, err := readAndRestoreRequestBody(req)
if err != nil {
return nil, fmt.Errorf("sora curl_cffi sidecar read request body failed: %w", err)
}
headers := make(map[string][]string, len(req.Header)+1)
for key, vals := range req.Header {
copied := make([]string, len(vals))
copy(copied, vals)
headers[key] = copied
}
if strings.TrimSpace(req.Host) != "" {
if _, ok := headers["Host"]; !ok {
headers["Host"] = []string{req.Host}
}
}
payload := soraCurlCFFISidecarRequest{
Method: req.Method,
URL: req.URL.String(),
Headers: headers,
ProxyURL: strings.TrimSpace(proxyURL),
SessionKey: c.sidecarSessionKey(account, proxyURL),
Impersonate: c.curlCFFIImpersonate(),
TimeoutSeconds: c.curlCFFISidecarTimeoutSeconds(),
}
if len(bodyBytes) > 0 {
payload.BodyBase64 = base64.StdEncoding.EncodeToString(bodyBytes)
}
encoded, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("sora curl_cffi sidecar marshal request failed: %w", err)
}
sidecarReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, endpoint, bytes.NewReader(encoded))
if err != nil {
return nil, fmt.Errorf("sora curl_cffi sidecar build request failed: %w", err)
}
sidecarReq.Header.Set("Content-Type", "application/json")
sidecarReq.Header.Set("Accept", "application/json")
httpClient := &http.Client{Timeout: time.Duration(payload.TimeoutSeconds) * time.Second}
sidecarResp, err := httpClient.Do(sidecarReq)
if err != nil {
return nil, fmt.Errorf("sora curl_cffi sidecar request failed: %w", err)
}
defer func() {
_ = sidecarResp.Body.Close()
}()
sidecarRespBody, err := io.ReadAll(io.LimitReader(sidecarResp.Body, 8<<20))
if err != nil {
return nil, fmt.Errorf("sora curl_cffi sidecar read response failed: %w", err)
}
if sidecarResp.StatusCode != http.StatusOK {
redacted := truncateForLog([]byte(logredact.RedactText(string(sidecarRespBody))), 512)
return nil, fmt.Errorf("sora curl_cffi sidecar http status=%d body=%s", sidecarResp.StatusCode, redacted)
}
var payloadResp soraCurlCFFISidecarResponse
if err := json.Unmarshal(sidecarRespBody, &payloadResp); err != nil {
return nil, fmt.Errorf("sora curl_cffi sidecar parse response failed: %w", err)
}
if msg := strings.TrimSpace(payloadResp.Error); msg != "" {
return nil, fmt.Errorf("sora curl_cffi sidecar upstream error: %s", msg)
}
statusCode := payloadResp.StatusCode
if statusCode <= 0 {
statusCode = payloadResp.Status
}
if statusCode <= 0 {
return nil, errors.New("sora curl_cffi sidecar response missing status code")
}
responseBody := []byte(payloadResp.Body)
if strings.TrimSpace(payloadResp.BodyBase64) != "" {
decoded, err := base64.StdEncoding.DecodeString(payloadResp.BodyBase64)
if err != nil {
return nil, fmt.Errorf("sora curl_cffi sidecar decode body failed: %w", err)
}
responseBody = decoded
}
respHeaders := make(http.Header)
for key, rawVal := range payloadResp.Headers {
for _, v := range convertSidecarHeaderValue(rawVal) {
respHeaders.Add(key, v)
}
}
return &http.Response{
StatusCode: statusCode,
Header: respHeaders,
Body: io.NopCloser(bytes.NewReader(responseBody)),
ContentLength: int64(len(responseBody)),
Request: req,
}, nil
}
func readAndRestoreRequestBody(req *http.Request) ([]byte, error) {
if req == nil || req.Body == nil {
return nil, nil
}
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
_ = req.Body.Close()
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.ContentLength = int64(len(bodyBytes))
return bodyBytes, nil
}
func (c *SoraDirectClient) curlCFFISidecarEndpoint() string {
if c == nil || c.cfg == nil {
return ""
}
raw := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.BaseURL)
if raw == "" {
return ""
}
parsed, err := url.Parse(raw)
if err != nil || strings.TrimSpace(parsed.Scheme) == "" || strings.TrimSpace(parsed.Host) == "" {
return raw
}
if path := strings.TrimSpace(parsed.Path); path == "" || path == "/" {
parsed.Path = "/request"
}
return parsed.String()
}
func (c *SoraDirectClient) curlCFFISidecarTimeoutSeconds() int {
if c == nil || c.cfg == nil {
return soraCurlCFFISidecarDefaultTimeoutSeconds
}
timeoutSeconds := c.cfg.Sora.Client.CurlCFFISidecar.TimeoutSeconds
if timeoutSeconds <= 0 {
return soraCurlCFFISidecarDefaultTimeoutSeconds
}
return timeoutSeconds
}
func (c *SoraDirectClient) curlCFFIImpersonate() string {
if c == nil || c.cfg == nil {
return "chrome131"
}
impersonate := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.Impersonate)
if impersonate == "" {
return "chrome131"
}
return impersonate
}
func (c *SoraDirectClient) sidecarSessionReuseEnabled() bool {
if c == nil || c.cfg == nil {
return true
}
return c.cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled
}
func (c *SoraDirectClient) sidecarSessionTTLSeconds() int {
if c == nil || c.cfg == nil {
return 3600
}
ttl := c.cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds
if ttl < 0 {
return 3600
}
return ttl
}
func convertSidecarHeaderValue(raw any) []string {
switch val := raw.(type) {
case nil:
return nil
case string:
if strings.TrimSpace(val) == "" {
return nil
}
return []string{val}
case []any:
out := make([]string, 0, len(val))
for _, item := range val {
s := strings.TrimSpace(fmt.Sprint(item))
if s != "" {
out = append(out, s)
}
}
return out
case []string:
out := make([]string, 0, len(val))
for _, item := range val {
if strings.TrimSpace(item) != "" {
out = append(out, item)
}
}
return out
default:
s := strings.TrimSpace(fmt.Sprint(val))
if s == "" {
return nil
}
return []string{s}
}
}
......@@ -9,6 +9,7 @@ import (
"io"
"log"
"math"
"math/rand"
"mime"
"net"
"net/http"
......@@ -669,7 +670,7 @@ func processSoraCharacterUsername(usernameHint string) string {
if usernameHint == "" {
usernameHint = "character"
}
return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100)
return fmt.Sprintf("%s%d", usernameHint, rand.Intn(900)+100)
}
func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
......
......@@ -181,7 +181,7 @@ func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawU
return relative, nil
}
if s.debug {
log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeSoraLogURL(rawURL), err)
log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeMediaLogURL(rawURL), err)
}
if attempt < retries {
time.Sleep(time.Duration(attempt*attempt) * time.Second)
......@@ -252,7 +252,7 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
relative := path.Join("/", mediaType, datePath, filename)
if s.debug {
log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeSoraLogURL(rawURL), relative)
log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeMediaLogURL(rawURL), relative)
}
return relative, nil
}
......@@ -305,3 +305,19 @@ func removePartialDownload(root *os.Root, filePath string) {
}
_ = root.Remove(filePath)
}
// sanitizeMediaLogURL 脱敏 URL 用于日志记录(去除 query 参数中可能的 token 信息)
func sanitizeMediaLogURL(rawURL string) string {
parsed, err := url.Parse(rawURL)
if err != nil {
if len(rawURL) > 80 {
return rawURL[:80] + "..."
}
return rawURL
}
safe := parsed.Scheme + "://" + parsed.Host + parsed.Path
if len(safe) > 120 {
return safe[:120] + "..."
}
return safe
}
package service
import (
"fmt"
"math"
"net/http"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/google/uuid"
)
type soraChallengeCooldownEntry struct {
Until time.Time
StatusCode int
CFRay string
ConsecutiveChallenges int
LastChallengeAt time.Time
}
type soraSidecarSessionEntry struct {
SessionKey string
ExpiresAt time.Time
LastUsedAt time.Time
}
func (c *SoraDirectClient) cloudflareChallengeCooldownSeconds() int {
if c == nil || c.cfg == nil {
return 900
}
cooldown := c.cfg.Sora.Client.CloudflareChallengeCooldownSeconds
if cooldown <= 0 {
return 0
}
return cooldown
}
func (c *SoraDirectClient) checkCloudflareChallengeCooldown(account *Account, proxyURL string) error {
if c == nil {
return nil
}
if account == nil || account.ID <= 0 {
return nil
}
cooldownSeconds := c.cloudflareChallengeCooldownSeconds()
if cooldownSeconds <= 0 {
return nil
}
key := soraAccountProxyKey(account, proxyURL)
now := time.Now()
c.challengeCooldownMu.RLock()
entry, ok := c.challengeCooldowns[key]
c.challengeCooldownMu.RUnlock()
if !ok {
return nil
}
if !entry.Until.After(now) {
c.challengeCooldownMu.Lock()
delete(c.challengeCooldowns, key)
c.challengeCooldownMu.Unlock()
return nil
}
remaining := int(math.Ceil(entry.Until.Sub(now).Seconds()))
if remaining < 1 {
remaining = 1
}
message := fmt.Sprintf("Sora request cooling down due to recent Cloudflare challenge. Retry in %d seconds.", remaining)
if entry.ConsecutiveChallenges > 1 {
message = fmt.Sprintf("%s (streak=%d)", message, entry.ConsecutiveChallenges)
}
if entry.CFRay != "" {
message = fmt.Sprintf("%s (last cf-ray: %s)", message, entry.CFRay)
}
return &SoraUpstreamError{
StatusCode: http.StatusTooManyRequests,
Message: message,
Headers: make(http.Header),
}
}
func (c *SoraDirectClient) recordCloudflareChallengeCooldown(account *Account, proxyURL string, statusCode int, headers http.Header, body []byte) {
if c == nil {
return
}
if account == nil || account.ID <= 0 {
return
}
cooldownSeconds := c.cloudflareChallengeCooldownSeconds()
if cooldownSeconds <= 0 {
return
}
key := soraAccountProxyKey(account, proxyURL)
now := time.Now()
cfRay := soraerror.ExtractCloudflareRayID(headers, body)
c.challengeCooldownMu.Lock()
c.cleanupExpiredChallengeCooldownsLocked(now)
streak := 1
existing, ok := c.challengeCooldowns[key]
if ok && now.Sub(existing.LastChallengeAt) <= 30*time.Minute {
streak = existing.ConsecutiveChallenges + 1
}
effectiveCooldown := soraComputeChallengeCooldownSeconds(cooldownSeconds, streak)
until := now.Add(time.Duration(effectiveCooldown) * time.Second)
if ok && existing.Until.After(until) {
until = existing.Until
if existing.ConsecutiveChallenges > streak {
streak = existing.ConsecutiveChallenges
}
if cfRay == "" {
cfRay = existing.CFRay
}
}
c.challengeCooldowns[key] = soraChallengeCooldownEntry{
Until: until,
StatusCode: statusCode,
CFRay: cfRay,
ConsecutiveChallenges: streak,
LastChallengeAt: now,
}
c.challengeCooldownMu.Unlock()
if c.debugEnabled() {
remain := int(math.Ceil(until.Sub(now).Seconds()))
if remain < 0 {
remain = 0
}
c.debugLogf("cloudflare_challenge_cooldown_set key=%s status=%d remain_s=%d streak=%d cf_ray=%s", key, statusCode, remain, streak, cfRay)
}
}
func soraComputeChallengeCooldownSeconds(baseSeconds, streak int) int {
if baseSeconds <= 0 {
return 0
}
if streak < 1 {
streak = 1
}
multiplier := streak
if multiplier > 4 {
multiplier = 4
}
cooldown := baseSeconds * multiplier
if cooldown > 3600 {
cooldown = 3600
}
return cooldown
}
func (c *SoraDirectClient) clearCloudflareChallengeCooldown(account *Account, proxyURL string) {
if c == nil {
return
}
if account == nil || account.ID <= 0 {
return
}
key := soraAccountProxyKey(account, proxyURL)
c.challengeCooldownMu.Lock()
_, existed := c.challengeCooldowns[key]
if existed {
delete(c.challengeCooldowns, key)
}
c.challengeCooldownMu.Unlock()
if existed && c.debugEnabled() {
c.debugLogf("cloudflare_challenge_cooldown_cleared key=%s", key)
}
}
func (c *SoraDirectClient) sidecarSessionKey(account *Account, proxyURL string) string {
if c == nil || !c.sidecarSessionReuseEnabled() {
return ""
}
if account == nil || account.ID <= 0 {
return ""
}
key := soraAccountProxyKey(account, proxyURL)
now := time.Now()
ttlSeconds := c.sidecarSessionTTLSeconds()
c.sidecarSessionMu.Lock()
defer c.sidecarSessionMu.Unlock()
c.cleanupExpiredSidecarSessionsLocked(now)
if existing, exists := c.sidecarSessions[key]; exists {
existing.LastUsedAt = now
c.sidecarSessions[key] = existing
return existing.SessionKey
}
expiresAt := now.Add(time.Duration(ttlSeconds) * time.Second)
if ttlSeconds <= 0 {
expiresAt = now.Add(365 * 24 * time.Hour)
}
newEntry := soraSidecarSessionEntry{
SessionKey: "sora-" + uuid.NewString(),
ExpiresAt: expiresAt,
LastUsedAt: now,
}
c.sidecarSessions[key] = newEntry
if c.debugEnabled() {
c.debugLogf("sidecar_session_created key=%s ttl_s=%d", key, ttlSeconds)
}
return newEntry.SessionKey
}
func (c *SoraDirectClient) cleanupExpiredChallengeCooldownsLocked(now time.Time) {
if c == nil || len(c.challengeCooldowns) == 0 {
return
}
for key, entry := range c.challengeCooldowns {
if !entry.Until.After(now) {
delete(c.challengeCooldowns, key)
}
}
}
func (c *SoraDirectClient) cleanupExpiredSidecarSessionsLocked(now time.Time) {
if c == nil || len(c.sidecarSessions) == 0 {
return
}
for key, entry := range c.sidecarSessions {
if !entry.ExpiresAt.After(now) {
delete(c.sidecarSessions, key)
}
}
}
func soraAccountProxyKey(account *Account, proxyURL string) string {
accountID := int64(0)
if account != nil {
accountID = account.ID
}
return fmt.Sprintf("account:%d|proxy:%s", accountID, normalizeSoraProxyKey(proxyURL))
}
func normalizeSoraProxyKey(proxyURL string) string {
raw := strings.TrimSpace(proxyURL)
if raw == "" {
return "direct"
}
parsed, err := url.Parse(raw)
if err != nil {
return strings.ToLower(raw)
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
port := strings.TrimSpace(parsed.Port())
if host == "" {
return strings.ToLower(raw)
}
if (scheme == "http" && port == "80") || (scheme == "https" && port == "443") {
port = ""
}
if port != "" {
host = host + ":" + port
}
if scheme == "" {
scheme = "proxy"
}
return scheme + "://" + host
}
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync"
"time"
"github.com/DouDOU-start/go-sora2api/sora"
"github.com/Wei-Shaw/sub2api/internal/config"
openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/tidwall/gjson"
)
// SoraSDKClient 基于 go-sora2api SDK 的 Sora 客户端实现。
// 它实现了 SoraClient 接口,用 SDK 替代原有的自建 HTTP/PoW/TLS 指纹逻辑。
type SoraSDKClient struct {
cfg *config.Config
httpUpstream HTTPUpstream
tokenProvider *OpenAITokenProvider
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository
// 每个 proxyURL 对应一个 SDK 客户端实例
sdkClients sync.Map // key: proxyURL (string), value: *sora.Client
}
// NewSoraSDKClient 创建基于 SDK 的 Sora 客户端
func NewSoraSDKClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraSDKClient {
return &SoraSDKClient{
cfg: cfg,
httpUpstream: httpUpstream,
tokenProvider: tokenProvider,
}
}
// SetAccountRepositories 设置账号和 Sora 扩展仓库(用于 token 持久化)
func (c *SoraSDKClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) {
if c == nil {
return
}
c.accountRepo = accountRepo
c.soraAccountRepo = soraAccountRepo
}
// Enabled 判断是否启用 Sora
func (c *SoraSDKClient) Enabled() bool {
if c == nil || c.cfg == nil {
return false
}
return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != ""
}
// PreflightCheck 在创建任务前执行账号能力预检。
// 当前仅对视频模型执行预检,用于提前识别额度耗尽或能力缺失。
func (c *SoraSDKClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error {
if modelCfg.Type != "video" {
return nil
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return err
}
balance, err := sdkClient.GetCreditBalance(ctx, token)
if err != nil {
return &SoraUpstreamError{
StatusCode: http.StatusForbidden,
Message: "当前账号未开通 Sora2 能力或无可用配额",
}
}
if balance.RateLimitReached || balance.RemainingCount <= 0 {
msg := "当前账号 Sora2 可用配额不足"
if requestedModel != "" {
msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel)
}
return &SoraUpstreamError{
StatusCode: http.StatusTooManyRequests,
Message: msg,
}
}
return nil
}
func (c *SoraSDKClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
if len(data) == 0 {
return "", errors.New("empty image data")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
if filename == "" {
filename = "image.png"
}
mediaID, err := sdkClient.UploadImage(ctx, token, data, filename)
if err != nil {
return "", c.wrapSDKError(err, account)
}
return mediaID, nil
}
func (c *SoraSDKClient) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
if err != nil {
return "", c.wrapSDKError(err, account)
}
var taskID string
if strings.TrimSpace(req.MediaID) != "" {
taskID, err = sdkClient.CreateImageTaskWithImage(ctx, token, sentinel, req.Prompt, req.Width, req.Height, req.MediaID)
} else {
taskID, err = sdkClient.CreateImageTask(ctx, token, sentinel, req.Prompt, req.Width, req.Height)
}
if err != nil {
return "", c.wrapSDKError(err, account)
}
return taskID, nil
}
func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
if err != nil {
return "", c.wrapSDKError(err, account)
}
orientation := req.Orientation
if orientation == "" {
orientation = "landscape"
}
nFrames := req.Frames
if nFrames <= 0 {
nFrames = 450
}
model := req.Model
if model == "" {
model = "sy_8"
}
size := req.Size
if size == "" {
size = "small"
}
// Remix 模式
if strings.TrimSpace(req.RemixTargetID) != "" {
styleID := "" // SDK ExtractStyle 可从 prompt 中提取
taskID, err := sdkClient.RemixVideo(ctx, token, sentinel, req.RemixTargetID, req.Prompt, orientation, nFrames, styleID)
if err != nil {
return "", c.wrapSDKError(err, account)
}
return taskID, nil
}
// 普通视频(文生视频或图生视频)
taskID, err := sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "")
if err != nil {
return "", c.wrapSDKError(err, account)
}
return taskID, nil
}
func (c *SoraSDKClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
if err != nil {
return "", c.wrapSDKError(err, account)
}
orientation := req.Orientation
if orientation == "" {
orientation = "landscape"
}
nFrames := req.Frames
if nFrames <= 0 {
nFrames = 450
}
taskID, err := sdkClient.CreateStoryboardTask(ctx, token, sentinel, req.Prompt, orientation, nFrames, req.MediaID, "")
if err != nil {
return "", c.wrapSDKError(err, account)
}
return taskID, nil
}
func (c *SoraSDKClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
if len(data) == 0 {
return "", errors.New("empty video data")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
cameoID, err := sdkClient.UploadCharacterVideo(ctx, token, data)
if err != nil {
return "", c.wrapSDKError(err, account)
}
return cameoID, nil
}
func (c *SoraSDKClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return nil, err
}
status, err := sdkClient.GetCameoStatus(ctx, token, cameoID)
if err != nil {
return nil, c.wrapSDKError(err, account)
}
return &SoraCameoStatus{
Status: status.Status,
DisplayNameHint: status.DisplayNameHint,
UsernameHint: status.UsernameHint,
ProfileAssetURL: status.ProfileAssetURL,
}, nil
}
func (c *SoraSDKClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
sdkClient, err := c.getSDKClient(account)
if err != nil {
return nil, err
}
data, err := sdkClient.DownloadCharacterImage(ctx, imageURL)
if err != nil {
return nil, c.wrapSDKError(err, account)
}
return data, nil
}
func (c *SoraSDKClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
if len(data) == 0 {
return "", errors.New("empty character image")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
assetPointer, err := sdkClient.UploadCharacterImage(ctx, token, data)
if err != nil {
return "", c.wrapSDKError(err, account)
}
return assetPointer, nil
}
func (c *SoraSDKClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
characterID, err := sdkClient.FinalizeCharacter(ctx, token, req.CameoID, req.Username, req.DisplayName, req.ProfileAssetPointer)
if err != nil {
return "", c.wrapSDKError(err, account)
}
return characterID, nil
}
func (c *SoraSDKClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return err
}
if err := sdkClient.SetCharacterPublic(ctx, token, cameoID); err != nil {
return c.wrapSDKError(err, account)
}
return nil
}
func (c *SoraSDKClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return err
}
if err := sdkClient.DeleteCharacter(ctx, token, characterID); err != nil {
return c.wrapSDKError(err, account)
}
return nil
}
func (c *SoraSDKClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
if err != nil {
return "", c.wrapSDKError(err, account)
}
postID, err := sdkClient.PublishVideo(ctx, token, sentinel, generationID)
if err != nil {
return "", c.wrapSDKError(err, account)
}
return postID, nil
}
func (c *SoraSDKClient) DeletePost(ctx context.Context, account *Account, postID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return err
}
if err := sdkClient.DeletePost(ctx, token, postID); err != nil {
return c.wrapSDKError(err, account)
}
return nil
}
// GetWatermarkFreeURLCustom 使用自定义第三方解析服务获取去水印链接。
// SDK 不涉及此功能,保留自建实现。
func (c *SoraSDKClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/")
if parseURL == "" {
return "", errors.New("custom parse url is required")
}
if strings.TrimSpace(parseToken) == "" {
return "", errors.New("custom parse token is required")
}
shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID)
payload := map[string]any{
"url": shareURL,
"token": strings.TrimSpace(parseToken),
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
proxyURL := c.resolveProxyURL(account)
accountID := int64(0)
accountConcurrency := 0
if account != nil {
accountID = account.ID
accountConcurrency = account.Concurrency
}
var resp *http.Response
if c.httpUpstream != nil {
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
} else {
resp, err = http.DefaultClient.Do(req)
}
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
if err != nil {
return "", err
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256))
}
downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String())
if downloadLink == "" {
return "", errors.New("custom parse response missing download_link")
}
return downloadLink, nil
}
func (c *SoraSDKClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
if strings.TrimSpace(expansionLevel) == "" {
expansionLevel = "medium"
}
if durationS <= 0 {
durationS = 10
}
enhanced, err := sdkClient.EnhancePrompt(ctx, token, prompt, expansionLevel, durationS)
if err != nil {
return "", c.wrapSDKError(err, account)
}
return enhanced, nil
}
func (c *SoraSDKClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return nil, err
}
result := sdkClient.QueryImageTaskOnce(ctx, token, taskID, time.Now().Add(-10*time.Second))
if result.Err != nil {
return &SoraImageTaskStatus{
ID: taskID,
Status: "failed",
ErrorMsg: result.Err.Error(),
}, nil
}
if result.Done && result.ImageURL != "" {
return &SoraImageTaskStatus{
ID: taskID,
Status: "succeeded",
URLs: []string{result.ImageURL},
}, nil
}
status := result.Progress.Status
if status == "" {
status = "processing"
}
return &SoraImageTaskStatus{
ID: taskID,
Status: status,
ProgressPct: float64(result.Progress.Percent) / 100.0,
}, nil
}
func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return nil, err
}
// 先查询 pending 列表
result := sdkClient.QueryVideoTaskOnce(ctx, token, taskID, time.Now().Add(-10*time.Second), 0)
if result.Err != nil {
return &SoraVideoTaskStatus{
ID: taskID,
Status: "failed",
ErrorMsg: result.Err.Error(),
}, nil
}
if !result.Done {
return &SoraVideoTaskStatus{
ID: taskID,
Status: result.Progress.Status,
ProgressPct: result.Progress.Percent,
}, nil
}
// 任务不在 pending 中,查询 drafts 获取下载链接
downloadURL, err := sdkClient.GetDownloadURL(ctx, token, taskID)
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "内容违规") || strings.Contains(errMsg, "Content violates") {
return &SoraVideoTaskStatus{
ID: taskID,
Status: "failed",
ErrorMsg: errMsg,
}, nil
}
// 可能还在处理中
return &SoraVideoTaskStatus{
ID: taskID,
Status: "processing",
}, nil
}
return &SoraVideoTaskStatus{
ID: taskID,
Status: "completed",
URLs: []string{downloadURL},
}, nil
}
// --- 内部方法 ---
// getSDKClient 获取或创建指定代理的 SDK 客户端实例
func (c *SoraSDKClient) getSDKClient(account *Account) (*sora.Client, error) {
proxyURL := c.resolveProxyURL(account)
if v, ok := c.sdkClients.Load(proxyURL); ok {
return v.(*sora.Client), nil
}
client, err := sora.New(proxyURL)
if err != nil {
return nil, fmt.Errorf("创建 Sora SDK 客户端失败: %w", err)
}
actual, _ := c.sdkClients.LoadOrStore(proxyURL, client)
return actual.(*sora.Client), nil
}
func (c *SoraSDKClient) resolveProxyURL(account *Account) string {
if account == nil || account.ProxyID == nil || account.Proxy == nil {
return ""
}
return strings.TrimSpace(account.Proxy.URL())
}
// getAccessToken 获取账号的 access_token,支持多种 token 来源和自动刷新。
// 此方法保留了原 SoraDirectClient 的 token 管理逻辑。
func (c *SoraSDKClient) getAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
// 优先尝试 OpenAI Token Provider
allowProvider := c.allowOpenAITokenProvider(account)
var providerErr error
if allowProvider && c.tokenProvider != nil {
token, err := c.tokenProvider.GetAccessToken(ctx, account)
if err == nil && strings.TrimSpace(token) != "" {
c.debugLogf("token_selected account_id=%d source=openai_token_provider", account.ID)
return token, nil
}
providerErr = err
if err != nil && c.debugEnabled() {
c.debugLogf("token_provider_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
}
}
// 尝试直接使用 credentials 中的 access_token
token := strings.TrimSpace(account.GetCredential("access_token"))
if token != "" {
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute {
refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring")
if refreshErr == nil && strings.TrimSpace(refreshed) != "" {
return refreshed, nil
}
}
return token, nil
}
// 尝试通过 session_token 或 refresh_token 恢复
recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing")
if recoverErr == nil && strings.TrimSpace(recovered) != "" {
return recovered, nil
}
if providerErr != nil {
return "", providerErr
}
return "", errors.New("access_token not found")
}
// recoverAccessToken 通过 session_token 或 refresh_token 恢复 access_token
func (c *SoraSDKClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
// 先尝试 session_token
if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" {
accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken)
if err == nil && strings.TrimSpace(accessToken) != "" {
c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken)
return accessToken, nil
}
}
// 再尝试 refresh_token
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
if refreshToken == "" {
return "", errors.New("session_token/refresh_token not found")
}
sdkClient, err := c.getSDKClient(account)
if err != nil {
return "", err
}
// 尝试多个 client_id
clientIDs := []string{
strings.TrimSpace(account.GetCredential("client_id")),
openaioauth.SoraClientID,
openaioauth.ClientID,
}
tried := make(map[string]struct{}, len(clientIDs))
var lastErr error
for _, clientID := range clientIDs {
if clientID == "" {
continue
}
if _, ok := tried[clientID]; ok {
continue
}
tried[clientID] = struct{}{}
newAccess, newRefresh, refreshErr := sdkClient.RefreshAccessToken(ctx, refreshToken, clientID)
if refreshErr != nil {
lastErr = refreshErr
continue
}
if strings.TrimSpace(newAccess) == "" {
lastErr = errors.New("refreshed access_token is empty")
continue
}
c.applyRecoveredToken(ctx, account, newAccess, newRefresh, "", "")
return newAccess, nil
}
if lastErr != nil {
return "", lastErr
}
return "", errors.New("no available client_id for refresh_token exchange")
}
// exchangeSessionToken 通过 session_token 换取 access_token
func (c *SoraSDKClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://sora.chatgpt.com/api/auth/session", nil)
if err != nil {
return "", "", err
}
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken)
req.Header.Set("Accept", "application/json")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
proxyURL := c.resolveProxyURL(account)
accountID := int64(0)
accountConcurrency := 0
if account != nil {
accountID = account.ID
accountConcurrency = account.Concurrency
}
var resp *http.Response
if c.httpUpstream != nil {
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
} else {
resp, err = http.DefaultClient.Do(req)
}
if err != nil {
return "", "", err
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if err != nil {
return "", "", err
}
if resp.StatusCode != http.StatusOK {
return "", "", fmt.Errorf("session exchange failed: %d", resp.StatusCode)
}
accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String())
if accessToken == "" {
return "", "", errors.New("session exchange missing accessToken")
}
expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String())
return accessToken, expiresAt, nil
}
// applyRecoveredToken 将恢复的 token 写入账号内存和数据库
func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) {
if account == nil {
return
}
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
if strings.TrimSpace(accessToken) != "" {
account.Credentials["access_token"] = accessToken
}
if strings.TrimSpace(refreshToken) != "" {
account.Credentials["refresh_token"] = refreshToken
}
if strings.TrimSpace(expiresAt) != "" {
account.Credentials["expires_at"] = expiresAt
}
if strings.TrimSpace(sessionToken) != "" {
account.Credentials["session_token"] = sessionToken
}
if c.accountRepo != nil {
if err := c.accountRepo.Update(ctx, account); err != nil && c.debugEnabled() {
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
}
}
c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken)
}
func (c *SoraSDKClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) {
if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 {
return
}
updates := make(map[string]any)
if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" {
updates["access_token"] = accessToken
updates["refresh_token"] = refreshToken
}
if strings.TrimSpace(sessionToken) != "" {
updates["session_token"] = sessionToken
}
if len(updates) == 0 {
return
}
if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() {
c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
}
}
func (c *SoraSDKClient) allowOpenAITokenProvider(account *Account) bool {
if c == nil || c.tokenProvider == nil {
return false
}
if account != nil && account.Platform == PlatformSora {
return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider
}
return true
}
// wrapSDKError 将 SDK 错误包装为 SoraUpstreamError
func (c *SoraSDKClient) wrapSDKError(err error, account *Account) error {
if err == nil {
return nil
}
msg := err.Error()
statusCode := http.StatusBadGateway
if strings.Contains(msg, "HTTP 401") || strings.Contains(msg, "HTTP 403") {
statusCode = http.StatusUnauthorized
} else if strings.Contains(msg, "HTTP 429") {
statusCode = http.StatusTooManyRequests
} else if strings.Contains(msg, "HTTP 404") {
statusCode = http.StatusNotFound
}
return &SoraUpstreamError{
StatusCode: statusCode,
Message: msg,
}
}
func (c *SoraSDKClient) debugEnabled() bool {
return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug
}
func (c *SoraSDKClient) debugLogf(format string, args ...any) {
if c.debugEnabled() {
log.Printf("[SoraSDK] "+format, args...)
}
}
......@@ -206,14 +206,14 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
return NewSoraMediaStorage(cfg)
}
func ProvideSoraDirectClient(
func ProvideSoraSDKClient(
cfg *config.Config,
httpUpstream HTTPUpstream,
tokenProvider *OpenAITokenProvider,
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository,
) *SoraDirectClient {
client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider)
) *SoraSDKClient {
client := NewSoraSDKClient(cfg, httpUpstream, tokenProvider)
client.SetAccountRepositories(accountRepo, soraAccountRepo)
return client
}
......@@ -306,8 +306,8 @@ var ProviderSet = wire.NewSet(
NewGatewayService,
ProvideSoraMediaStorage,
ProvideSoraMediaCleanupService,
ProvideSoraDirectClient,
wire.Bind(new(SoraClient), new(*SoraDirectClient)),
ProvideSoraSDKClient,
wire.Bind(new(SoraClient), new(*SoraSDKClient)),
NewSoraGatewayService,
NewOpenAIGatewayService,
NewOAuthService,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment