Commit eb2dce92 authored by 陈曦's avatar 陈曦
Browse files

升级v1.0.8 解决冲突

parents 7b83d6e7 339d906e
...@@ -158,30 +158,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() { ...@@ -158,30 +158,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs) require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
} }
// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID == openai.SoraClientID {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
return
}
w.WriteHeader(http.StatusBadRequest)
}))
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
require.NoError(s.T(), err, "RefreshTokenWithClientID")
require.Equal(s.T(), "at-sora", resp.AccessToken)
require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() { func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
const customClientID = "custom-client-id" const customClientID = "custom-client-id"
var seenClientIDs []string var seenClientIDs []string
...@@ -276,7 +252,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() { ...@@ -276,7 +252,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
} }
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() { func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
wantClientID := openai.SoraClientID wantClientID := "custom-exchange-client-id"
errCh := make(chan string, 1) errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm() _ = r.ParseForm()
......
package repository
import (
"context"
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraAccountRepository 实现 service.SoraAccountRepository 接口。
// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
//
// 设计说明:
// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除
type soraAccountRepository struct {
sql *sql.DB
}
// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository {
return &soraAccountRepository{sql: sqlDB}
}
// Upsert 创建或更新 Sora 账号扩展信息
// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
accessToken, accessOK := updates["access_token"].(string)
refreshToken, refreshOK := updates["refresh_token"].(string)
sessionToken, sessionOK := updates["session_token"].(string)
if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" {
if !sessionOK {
return errors.New("缺少 access_token/refresh_token,且未提供可更新字段")
}
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_accounts
SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
updated_at = NOW()
WHERE account_id = $1
`, accountID, sessionToken)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return errors.New("sora_accounts 记录不存在,无法仅更新 session_token")
}
return nil
}
_, err := r.sql.ExecContext(ctx, `
INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
ON CONFLICT (account_id) DO UPDATE SET
access_token = EXCLUDED.access_token,
refresh_token = EXCLUDED.refresh_token,
session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
updated_at = NOW()
`, accountID, accessToken, refreshToken, sessionToken)
return err
}
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
FROM sora_accounts
WHERE account_id = $1
`, accountID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return nil, nil // 记录不存在
}
var sa service.SoraAccount
if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil {
return nil, err
}
return &sa, nil
}
// Delete 删除 Sora 账号扩展信息
func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error {
_, err := r.sql.ExecContext(ctx, `
DELETE FROM sora_accounts WHERE account_id = $1
`, accountID)
return err
}
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
// 使用原生 SQL 操作 sora_generations 表。
type soraGenerationRepository struct {
sql *sql.DB
}
// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
return &soraGenerationRepository{sql: sqlDB}
}
func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
err := r.sql.QueryRowContext(ctx, `
INSERT INTO sora_generations (
user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, created_at
`,
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
).Scan(&gen.ID, &gen.CreatedAt)
return err
}
// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
func (r *soraGenerationRepository) CreatePendingWithLimit(
ctx context.Context,
gen *service.SoraGeneration,
activeStatuses []string,
maxActive int64,
) error {
if gen == nil {
return fmt.Errorf("generation is nil")
}
if maxActive <= 0 {
return r.Create(ctx, gen)
}
if len(activeStatuses) == 0 {
activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
}
tx, err := r.sql.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
// 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
return err
}
placeholders := make([]string, len(activeStatuses))
args := make([]any, 0, 1+len(activeStatuses))
args = append(args, gen.UserID)
for i, s := range activeStatuses {
placeholders[i] = fmt.Sprintf("$%d", i+2)
args = append(args, s)
}
countQuery := fmt.Sprintf(
`SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
strings.Join(placeholders, ","),
)
var activeCount int64
if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
return err
}
if activeCount >= maxActive {
return service.ErrSoraGenerationConcurrencyLimit
}
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
if err := tx.QueryRowContext(ctx, `
INSERT INTO sora_generations (
user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, created_at
`,
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
).Scan(&gen.ID, &gen.CreatedAt); err != nil {
return err
}
return tx.Commit()
}
func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
gen := &service.SoraGeneration{}
var mediaURLsJSON, s3KeysJSON []byte
var completedAt sql.NullTime
var apiKeyID sql.NullInt64
err := r.sql.QueryRowContext(ctx, `
SELECT id, user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message,
created_at, completed_at
FROM sora_generations WHERE id = $1
`, id).Scan(
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
&gen.CreatedAt, &completedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("生成记录不存在")
}
return nil, err
}
if apiKeyID.Valid {
gen.APIKeyID = &apiKeyID.Int64
}
if completedAt.Valid {
gen.CompletedAt = &completedAt.Time
}
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
return gen, nil
}
func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
var completedAt *time.Time
if gen.CompletedAt != nil {
completedAt = gen.CompletedAt
}
_, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations SET
status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
error_message = $9, completed_at = $10
WHERE id = $1
`,
gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
gen.ErrorMessage, completedAt,
)
return err
}
// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2, upstream_task_id = $3
WHERE id = $1 AND status = $4
`,
id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
func (r *soraGenerationRepository) UpdateCompletedIfActive(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
completedAt time.Time,
) (bool, error) {
mediaURLsJSON, _ := json.Marshal(mediaURLs)
s3KeysJSON, _ := json.Marshal(s3Keys)
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2,
media_url = $3,
media_urls = $4,
file_size_bytes = $5,
storage_type = $6,
s3_object_keys = $7,
error_message = '',
completed_at = $8
WHERE id = $1 AND status IN ($9, $10)
`,
id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
func (r *soraGenerationRepository) UpdateFailedIfActive(
ctx context.Context,
id int64,
errMsg string,
completedAt time.Time,
) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2,
error_message = $3,
completed_at = $4
WHERE id = $1 AND status IN ($5, $6)
`,
id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2, completed_at = $3
WHERE id = $1 AND status IN ($4, $5)
`,
id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。
func (r *soraGenerationRepository) UpdateStorageIfCompleted(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
) (bool, error) {
mediaURLsJSON, _ := json.Marshal(mediaURLs)
s3KeysJSON, _ := json.Marshal(s3Keys)
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET media_url = $2,
media_urls = $3,
file_size_bytes = $4,
storage_type = $5,
s3_object_keys = $6
WHERE id = $1 AND status = $7
`,
id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id)
return err
}
func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
// 构建 WHERE 条件
conditions := []string{"user_id = $1"}
args := []any{params.UserID}
argIdx := 2
if params.Status != "" {
// 支持逗号分隔的多状态
statuses := strings.Split(params.Status, ",")
placeholders := make([]string, len(statuses))
for i, s := range statuses {
placeholders[i] = fmt.Sprintf("$%d", argIdx)
args = append(args, strings.TrimSpace(s))
argIdx++
}
conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
}
if params.StorageType != "" {
storageTypes := strings.Split(params.StorageType, ",")
placeholders := make([]string, len(storageTypes))
for i, s := range storageTypes {
placeholders[i] = fmt.Sprintf("$%d", argIdx)
args = append(args, strings.TrimSpace(s))
argIdx++
}
conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
}
if params.MediaType != "" {
conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
args = append(args, params.MediaType)
argIdx++
}
whereClause := "WHERE " + strings.Join(conditions, " AND ")
// 计数
var total int64
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
return nil, 0, err
}
// 分页查询
offset := (params.Page - 1) * params.PageSize
listQuery := fmt.Sprintf(`
SELECT id, user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message,
created_at, completed_at
FROM sora_generations %s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d
`, whereClause, argIdx, argIdx+1)
args = append(args, params.PageSize, offset)
rows, err := r.sql.QueryContext(ctx, listQuery, args...)
if err != nil {
return nil, 0, err
}
defer func() {
_ = rows.Close()
}()
var results []*service.SoraGeneration
for rows.Next() {
gen := &service.SoraGeneration{}
var mediaURLsJSON, s3KeysJSON []byte
var completedAt sql.NullTime
var apiKeyID sql.NullInt64
if err := rows.Scan(
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
&gen.CreatedAt, &completedAt,
); err != nil {
return nil, 0, err
}
if apiKeyID.Valid {
gen.APIKeyID = &apiKeyID.Int64
}
if completedAt.Valid {
gen.CompletedAt = &completedAt.Time
}
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
results = append(results, gen)
}
return results, total, rows.Err()
}
func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
if len(statuses) == 0 {
return 0, nil
}
placeholders := make([]string, len(statuses))
args := []any{userID}
for i, s := range statuses {
placeholders[i] = fmt.Sprintf("$%d", i+2)
args = append(args, s)
}
var count int64
query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ","))
err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count)
return count, err
}
...@@ -28,7 +28,7 @@ import ( ...@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
// usageLogInsertArgTypes must stay in the same order as: // usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args // 1. prepareUsageLogInsert().args
...@@ -73,7 +73,6 @@ var usageLogInsertArgTypes = [...]string{ ...@@ -73,7 +73,6 @@ var usageLogInsertArgTypes = [...]string{
"text", // ip_address "text", // ip_address
"integer", // image_count "integer", // image_count
"text", // image_size "text", // image_size
"text", // media_type
"text", // service_tier "text", // service_tier
"text", // reasoning_effort "text", // reasoning_effort
"text", // inbound_endpoint "text", // inbound_endpoint
...@@ -352,7 +351,6 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -352,7 +351,6 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
ip_address, ip_address,
image_count, image_count,
image_size, image_size,
media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint, inbound_endpoint,
...@@ -369,7 +367,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -369,7 +367,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$10, $11, $12, $13, $10, $11, $12, $13,
$14, $15, $16, $17, $14, $15, $16, $17,
$18, $19, $20, $21, $22, $23, $18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -790,7 +788,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -790,7 +788,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address, ip_address,
image_count, image_count,
image_size, image_size,
media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint, inbound_endpoint,
...@@ -803,7 +800,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -803,7 +800,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(keys)*47) args := make([]any, 0, len(keys)*46)
argPos := 1 argPos := 1
for idx, key := range keys { for idx, key := range keys {
if idx > 0 { if idx > 0 {
...@@ -867,7 +864,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -867,7 +864,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address, ip_address,
image_count, image_count,
image_size, image_size,
media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint, inbound_endpoint,
...@@ -915,7 +911,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -915,7 +911,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address, ip_address,
image_count, image_count,
image_size, image_size,
media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint, inbound_endpoint,
...@@ -1003,7 +998,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1003,7 +998,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address, ip_address,
image_count, image_count,
image_size, image_size,
media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint, inbound_endpoint,
...@@ -1016,7 +1010,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1016,7 +1010,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(preparedList)*46) args := make([]any, 0, len(preparedList)*45)
argPos := 1 argPos := 1
for idx, prepared := range preparedList { for idx, prepared := range preparedList {
if idx > 0 { if idx > 0 {
...@@ -1077,7 +1071,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1077,7 +1071,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address, ip_address,
image_count, image_count,
image_size, image_size,
media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint, inbound_endpoint,
...@@ -1125,7 +1118,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1125,7 +1118,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address, ip_address,
image_count, image_count,
image_size, image_size,
media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint, inbound_endpoint,
...@@ -1181,7 +1173,6 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1181,7 +1173,6 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
ip_address, ip_address,
image_count, image_count,
image_size, image_size,
media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint, inbound_endpoint,
...@@ -1198,7 +1189,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1198,7 +1189,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$10, $11, $12, $13, $10, $11, $12, $13,
$14, $15, $16, $17, $14, $15, $16, $17,
$18, $19, $20, $21, $22, $23, $18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...) `, prepared.args...)
...@@ -1225,7 +1216,6 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1225,7 +1216,6 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
userAgent := nullString(log.UserAgent) userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress) ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize) imageSize := nullString(log.ImageSize)
mediaType := nullString(log.MediaType)
serviceTier := nullString(log.ServiceTier) serviceTier := nullString(log.ServiceTier)
reasoningEffort := nullString(log.ReasoningEffort) reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint) inboundEndpoint := nullString(log.InboundEndpoint)
...@@ -1286,7 +1276,6 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1286,7 +1276,6 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
ipAddress, ipAddress,
log.ImageCount, log.ImageCount,
imageSize, imageSize,
mediaType,
serviceTier, serviceTier,
reasoningEffort, reasoningEffort,
inboundEndpoint, inboundEndpoint,
...@@ -4051,7 +4040,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -4051,7 +4040,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
ipAddress sql.NullString ipAddress sql.NullString
imageCount int imageCount int
imageSize sql.NullString imageSize sql.NullString
mediaType sql.NullString
serviceTier sql.NullString serviceTier sql.NullString
reasoningEffort sql.NullString reasoningEffort sql.NullString
inboundEndpoint sql.NullString inboundEndpoint sql.NullString
...@@ -4101,7 +4089,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -4101,7 +4089,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&ipAddress, &ipAddress,
&imageCount, &imageCount,
&imageSize, &imageSize,
&mediaType,
&serviceTier, &serviceTier,
&reasoningEffort, &reasoningEffort,
&inboundEndpoint, &inboundEndpoint,
...@@ -4179,9 +4166,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -4179,9 +4166,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if imageSize.Valid { if imageSize.Valid {
log.ImageSize = &imageSize.String log.ImageSize = &imageSize.String
} }
if mediaType.Valid {
log.MediaType = &mediaType.String
}
if serviceTier.Valid { if serviceTier.Valid {
log.ServiceTier = &serviceTier.String log.ServiceTier = &serviceTier.String
} }
......
...@@ -76,7 +76,6 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { ...@@ -76,7 +76,6 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // ip_address sqlmock.AnyArg(), // ip_address
log.ImageCount, log.ImageCount,
sqlmock.AnyArg(), // image_size sqlmock.AnyArg(), // image_size
sqlmock.AnyArg(), // media_type
sqlmock.AnyArg(), // service_tier sqlmock.AnyArg(), // service_tier
sqlmock.AnyArg(), // reasoning_effort sqlmock.AnyArg(), // reasoning_effort
sqlmock.AnyArg(), // inbound_endpoint sqlmock.AnyArg(), // inbound_endpoint
...@@ -155,7 +154,6 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { ...@@ -155,7 +154,6 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(), sqlmock.AnyArg(),
log.ImageCount, log.ImageCount,
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(),
serviceTier, serviceTier,
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(), sqlmock.AnyArg(),
...@@ -471,7 +469,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -471,7 +469,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
0, 0,
sql.NullString{}, sql.NullString{},
sql.NullString{},
sql.NullString{Valid: true, String: "priority"}, sql.NullString{Valid: true, String: "priority"},
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
...@@ -519,7 +516,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -519,7 +516,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
0, 0,
sql.NullString{}, sql.NullString{},
sql.NullString{},
sql.NullString{Valid: true, String: "flex"}, sql.NullString{Valid: true, String: "flex"},
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
...@@ -567,7 +563,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -567,7 +563,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
0, 0,
sql.NullString{}, sql.NullString{},
sql.NullString{},
sql.NullString{Valid: true, String: "priority"}, sql.NullString{Valid: true, String: "priority"},
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
......
...@@ -62,7 +62,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error ...@@ -62,7 +62,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance). SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency). SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status). SetStatus(userIn.Status).
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists) return translatePersistenceError(err, nil, service.ErrEmailExists)
...@@ -145,8 +144,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error ...@@ -145,8 +144,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance). SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency). SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status). SetStatus(userIn.Status).
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
...@@ -376,65 +373,6 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount ...@@ -376,65 +373,6 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
return nil return nil
} }
// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。
func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
if deltaBytes <= 0 {
user, err := r.GetByID(ctx, userID)
if err != nil {
return 0, err
}
return user.SoraStorageUsedBytes, nil
}
var newUsed int64
err := scanSingleRow(ctx, r.sql, `
UPDATE users
SET sora_storage_used_bytes = sora_storage_used_bytes + $2
WHERE id = $1
AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
RETURNING sora_storage_used_bytes
`, []any{userID, deltaBytes, effectiveQuota}, &newUsed)
if err == nil {
return newUsed, nil
}
if errors.Is(err, sql.ErrNoRows) {
// 区分用户不存在和配额冲突
exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
if existsErr != nil {
return 0, existsErr
}
if !exists {
return 0, service.ErrUserNotFound
}
return 0, service.ErrSoraStorageQuotaExceeded
}
return 0, err
}
// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。
func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
if deltaBytes <= 0 {
user, err := r.GetByID(ctx, userID)
if err != nil {
return 0, err
}
return user.SoraStorageUsedBytes, nil
}
var newUsed int64
err := scanSingleRow(ctx, r.sql, `
UPDATE users
SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
WHERE id = $1
RETURNING sora_storage_used_bytes
`, []any{userID, deltaBytes}, &newUsed)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return 0, service.ErrUserNotFound
}
return 0, err
}
return newUsed, nil
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx) return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
} }
......
...@@ -53,7 +53,6 @@ var ProviderSet = wire.NewSet( ...@@ -53,7 +53,6 @@ var ProviderSet = wire.NewSet(
NewAPIKeyRepository, NewAPIKeyRepository,
NewGroupRepository, NewGroupRepository,
NewAccountRepository, NewAccountRepository,
NewSoraAccountRepository, // Sora 账号扩展表仓储
NewScheduledTestPlanRepository, // 定时测试计划仓储 NewScheduledTestPlanRepository, // 定时测试计划仓储
NewScheduledTestResultRepository, // 定时测试结果仓储 NewScheduledTestResultRepository, // 定时测试结果仓储
NewProxyRepository, NewProxyRepository,
......
...@@ -204,11 +204,6 @@ func TestAPIContracts(t *testing.T) { ...@@ -204,11 +204,6 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null, "image_price_1k": null,
"image_price_2k": null, "image_price_2k": null,
"image_price_4k": null, "image_price_4k": null,
"sora_image_price_360": null,
"sora_image_price_540": null,
"sora_storage_quota_bytes": 0,
"sora_video_price_per_request": null,
"sora_video_price_per_request_hd": null,
"claude_code_only": false, "claude_code_only": false,
"allow_messages_dispatch": false, "allow_messages_dispatch": false,
"fallback_group_id": null, "fallback_group_id": null,
...@@ -532,7 +527,6 @@ func TestAPIContracts(t *testing.T) { ...@@ -532,7 +527,6 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o", "fallback_model_openai": "gpt-4o",
"enable_identity_patch": true, "enable_identity_patch": true,
"identity_patch_prompt": "", "identity_patch_prompt": "",
"sora_client_enabled": false,
"invitation_code_enabled": false, "invitation_code_enabled": false,
"home_content": "", "home_content": "",
"hide_ccs_import_button": false, "hide_ccs_import_button": false,
...@@ -653,11 +647,11 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -653,11 +647,11 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) { jwtAuth := func(c *gin.Context) {
......
...@@ -94,7 +94,6 @@ func isAPIRoutePath(c *gin.Context) bool { ...@@ -94,7 +94,6 @@ func isAPIRoutePath(c *gin.Context) bool {
return strings.HasPrefix(path, "/v1/") || return strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") || strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/antigravity/") || strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/responses") strings.HasPrefix(path, "/responses")
} }
......
...@@ -109,7 +109,6 @@ func registerRoutes( ...@@ -109,7 +109,6 @@ func registerRoutes(
// 注册各模块路由 // 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService) routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService) routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
} }
...@@ -34,8 +34,6 @@ func RegisterAdminRoutes( ...@@ -34,8 +34,6 @@ func RegisterAdminRoutes(
// OpenAI OAuth // OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h) registerOpenAIOAuthRoutes(admin, h)
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
registerSoraOAuthRoutes(admin, h)
// Gemini OAuth // Gemini OAuth
registerGeminiOAuthRoutes(admin, h) registerGeminiOAuthRoutes(admin, h)
...@@ -321,19 +319,6 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -321,19 +319,6 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
} }
} }
func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
sora := admin.Group("/sora")
{
sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
}
}
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
gemini := admin.Group("/gemini") gemini := admin.Group("/gemini")
{ {
...@@ -422,15 +407,6 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -422,15 +407,6 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Beta 策略配置 // Beta 策略配置
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings) adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings) adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
// Sora S3 存储配置
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection)
adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles)
adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile)
adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile)
adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile)
adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile)
} }
} }
......
...@@ -23,11 +23,6 @@ func RegisterGatewayRoutes( ...@@ -23,11 +23,6 @@ func RegisterGatewayRoutes(
cfg *config.Config, cfg *config.Config,
) { ) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
soraMaxBodySize := cfg.Gateway.SoraMaxBodySize
if soraMaxBodySize <= 0 {
soraMaxBodySize = cfg.Gateway.MaxBodySize
}
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
clientRequestID := middleware.ClientRequestID() clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
endpointNorm := handler.InboundEndpointMiddleware() endpointNorm := handler.InboundEndpointMiddleware()
...@@ -163,28 +158,6 @@ func RegisterGatewayRoutes( ...@@ -163,28 +158,6 @@ func RegisterGatewayRoutes(
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
} }
// Sora 专用路由(强制使用 sora 平台)
soraV1 := r.Group("/sora/v1")
soraV1.Use(soraBodyLimit)
soraV1.Use(clientRequestID)
soraV1.Use(opsErrorLogger)
soraV1.Use(endpointNorm)
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
soraV1.Use(requireGroupAnthropic)
{
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
soraV1.GET("/models", h.Gateway.Models)
}
// Sora 媒体代理(可选 API Key 验证)
if cfg.Gateway.SoraMediaRequireAPIKey {
r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy)
} else {
r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy)
}
// Sora 媒体代理(签名 URL,无需 API Key)
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
} }
// getGroupPlatform extracts the group platform from the API Key stored in context. // getGroupPlatform extracts the group platform from the API Key stored in context.
......
...@@ -22,7 +22,6 @@ func newGatewayRoutesTestRouter() *gin.Engine { ...@@ -22,7 +22,6 @@ func newGatewayRoutesTestRouter() *gin.Engine {
&handler.Handlers{ &handler.Handlers{
Gateway: &handler.GatewayHandler{}, Gateway: &handler.GatewayHandler{},
OpenAIGateway: &handler.OpenAIGatewayHandler{}, OpenAIGateway: &handler.OpenAIGatewayHandler{},
SoraGateway: &handler.SoraGatewayHandler{},
}, },
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) { servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
c.Next() c.Next()
......
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。
func RegisterSoraClientRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
settingService *service.SettingService,
) {
if h.SoraClient == nil {
return
}
authenticated := v1.Group("/sora")
authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{
authenticated.POST("/generate", h.SoraClient.Generate)
authenticated.GET("/generations", h.SoraClient.ListGenerations)
authenticated.GET("/generations/:id", h.SoraClient.GetGeneration)
authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration)
authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration)
authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage)
authenticated.GET("/quota", h.SoraClient.GetQuota)
authenticated.GET("/models", h.SoraClient.GetModels)
authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus)
}
}
...@@ -28,8 +28,7 @@ type AccountRepository interface { ...@@ -28,8 +28,7 @@ type AccountRepository interface {
// GetByCRSAccountID finds an account previously synced from CRS. // GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found. // Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora') // FindByExtraField 根据 extra 字段中的键值对查找账号
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID // ListCRSAccountIDs returns a map of crs_account_id -> local account ID
// for all accounts that have been synced from CRS. // for all accounts that have been synced from CRS.
......
...@@ -13,18 +13,14 @@ import ( ...@@ -13,18 +13,14 @@ import (
"log" "log"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"regexp" "regexp"
"strings" "strings"
"sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
...@@ -37,11 +33,6 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`) ...@@ -37,11 +33,6 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
const ( const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine"
soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap"
soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check"
) )
// TestEvent represents a SSE event for account testing // TestEvent represents a SSE event for account testing
...@@ -71,13 +62,8 @@ type AccountTestService struct { ...@@ -71,13 +62,8 @@ type AccountTestService struct {
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
cfg *config.Config cfg *config.Config
tlsFPProfileService *TLSFingerprintProfileService tlsFPProfileService *TLSFingerprintProfileService
soraTestGuardMu sync.Mutex
soraTestLastRun map[int64]time.Time
soraTestCooldown time.Duration
} }
const defaultSoraTestCooldown = 10 * time.Second
// NewAccountTestService creates a new AccountTestService // NewAccountTestService creates a new AccountTestService
func NewAccountTestService( func NewAccountTestService(
accountRepo AccountRepository, accountRepo AccountRepository,
...@@ -94,8 +80,6 @@ func NewAccountTestService( ...@@ -94,8 +80,6 @@ func NewAccountTestService(
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
cfg: cfg, cfg: cfg,
tlsFPProfileService: tlsFPProfileService, tlsFPProfileService: tlsFPProfileService,
soraTestLastRun: make(map[int64]time.Time),
soraTestCooldown: defaultSoraTestCooldown,
} }
} }
...@@ -197,10 +181,6 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int ...@@ -197,10 +181,6 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
return s.routeAntigravityTest(c, account, modelID, prompt) return s.routeAntigravityTest(c, account, modelID, prompt)
} }
if account.Platform == PlatformSora {
return s.testSoraAccountConnection(c, account)
}
return s.testClaudeAccountConnection(c, account, modelID) return s.testClaudeAccountConnection(c, account, modelID)
} }
...@@ -634,698 +614,6 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account ...@@ -634,698 +614,6 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
return s.processGeminiStream(c, resp.Body) return s.processGeminiStream(c, resp.Body)
} }
type soraProbeStep struct {
Name string `json:"name"`
Status string `json:"status"`
HTTPStatus int `json:"http_status,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
Message string `json:"message,omitempty"`
}
type soraProbeSummary struct {
Status string `json:"status"`
Steps []soraProbeStep `json:"steps"`
}
type soraProbeRecorder struct {
steps []soraProbeStep
}
func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
r.steps = append(r.steps, soraProbeStep{
Name: name,
Status: status,
HTTPStatus: httpStatus,
ErrorCode: strings.TrimSpace(errorCode),
Message: strings.TrimSpace(message),
})
}
func (r *soraProbeRecorder) finalize() soraProbeSummary {
meSuccess := false
partial := false
for _, step := range r.steps {
if step.Name == "me" {
meSuccess = strings.EqualFold(step.Status, "success")
continue
}
if strings.EqualFold(step.Status, "failed") {
partial = true
}
}
status := "success"
if !meSuccess {
status = "failed"
} else if partial {
status = "partial_success"
}
return soraProbeSummary{
Status: status,
Steps: append([]soraProbeStep(nil), r.steps...),
}
}
func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
if rec == nil {
return
}
summary := rec.finalize()
code := ""
for _, step := range summary.Steps {
if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
code = step.ErrorCode
break
}
}
s.sendEvent(c, TestEvent{
Type: "sora_test_result",
Status: summary.Status,
Code: code,
Data: summary,
})
}
func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
if accountID <= 0 {
return 0, true
}
s.soraTestGuardMu.Lock()
defer s.soraTestGuardMu.Unlock()
if s.soraTestLastRun == nil {
s.soraTestLastRun = make(map[int64]time.Time)
}
cooldown := s.soraTestCooldown
if cooldown <= 0 {
cooldown = defaultSoraTestCooldown
}
now := time.Now()
if lastRun, ok := s.soraTestLastRun[accountID]; ok {
elapsed := now.Sub(lastRun)
if elapsed < cooldown {
return cooldown - elapsed, false
}
}
s.soraTestLastRun[accountID] = now
return 0, true
}
func ceilSeconds(d time.Duration) int {
if d <= 0 {
return 1
}
sec := int(d / time.Second)
if d%time.Second != 0 {
sec++
}
if sec < 1 {
sec = 1
}
return sec
}
// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。
// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。
func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error {
ctx := c.Request.Context()
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证")
}
baseURL := account.GetBaseURL()
if baseURL == "" {
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url")
}
// 验证 base_url 格式
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error()))
}
upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions"
// 设置 SSE 头
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
return s.sendErrorAndEnd(c, msg)
}
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"})
// 构建轻量级 prompt-enhance 请求作为连通性测试
testPayload := map[string]any{
"model": "prompt-enhance-short-10s",
"messages": []map[string]string{{"role": "user", "content": "test"}},
"stream": false,
}
payloadBytes, _ := json.Marshal(testPayload)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "构建测试请求失败")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
// 获取代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
if resp.StatusCode == http.StatusOK {
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode))
}
// 其他错误但能连通(如 400 参数错误)也算连通性测试通过
if resp.StatusCode == http.StatusBadRequest {
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256)))
}
// testSoraAccountConnection 测试 Sora 账号的连接
// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
// apikey 类型走独立测试流程
if account.Type == AccountTypeAPIKey {
return s.testSoraAPIKeyAccountConnection(c, account)
}
ctx := c.Request.Context()
recorder := &soraProbeRecorder{}
authToken := account.GetCredential("access_token")
if authToken == "" {
recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "No access token available")
}
// Set SSE headers
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, msg)
}
// Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
if err != nil {
recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Failed to create request")
}
// 使用 Sora 客户端标准请求头
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
// Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
soraTLSProfile := s.resolveSoraTLSProfile()
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, soraTLSProfile)
if err != nil {
recorder.addStep("me", "failed", 0, "network_error", err.Error())
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
s.emitSoraProbeSummary(c, recorder)
s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
switch {
case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
case strings.EqualFold(upstreamCode, "unsupported_country_code"):
recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
case strings.TrimSpace(upstreamMessage) != "":
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
default:
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
}
}
recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
// 解析 /me 响应,提取用户信息
var meResp map[string]any
if err := json.Unmarshal(body, &meResp); err != nil {
// 能收到 200 就说明 token 有效
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora connection OK (token valid)"})
} else {
// 尝试提取用户名或邮箱信息
info := "Sora connection OK"
if name, ok := meResp["name"].(string); ok && name != "" {
info = fmt.Sprintf("Sora connection OK - User: %s", name)
} else if email, ok := meResp["email"].(string); ok && email != "" {
info = fmt.Sprintf("Sora connection OK - Email: %s", email)
}
s.sendEvent(c, TestEvent{Type: "content", Text: info})
}
// 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
if err == nil {
subReq.Header.Set("Authorization", "Bearer "+authToken)
subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
subReq.Header.Set("Accept", "application/json")
subReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
subReq.Header.Set("Origin", "https://sora.chatgpt.com")
subReq.Header.Set("Referer", "https://sora.chatgpt.com/")
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, soraTLSProfile)
if subErr != nil {
recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
} else {
subBody, _ := io.ReadAll(subResp.Body)
_ = subResp.Body.Close()
if subResp.StatusCode == http.StatusOK {
recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
}
} else {
if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
} else {
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
}
}
}
}
// 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, soraTLSProfile, recorder)
s.emitSoraProbeSummary(c, recorder)
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
func (s *AccountTestService) testSora2Capabilities(
c *gin.Context,
ctx context.Context,
account *Account,
authToken string,
proxyURL string,
tlsProfile *tlsfingerprint.Profile,
recorder *soraProbeRecorder,
) {
inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraInviteMineURL,
proxyURL,
tlsProfile,
)
if err != nil {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
return
}
if inviteStatus == http.StatusUnauthorized {
bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraBootstrapURL,
proxyURL,
tlsProfile,
)
if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
if recorder != nil {
recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
}
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraInviteMineURL,
proxyURL,
tlsProfile,
)
if err != nil {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
return
}
} else if recorder != nil {
code := ""
msg := ""
if bootstrapErr != nil {
code = "network_error"
msg = bootstrapErr.Error()
}
recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
}
}
if inviteStatus != http.StatusOK {
if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
}
s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
return
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
if recorder != nil {
recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
return
}
if recorder != nil {
recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
}
if summary := parseSoraInviteSummary(inviteBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"})
}
remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraRemainingURL,
proxyURL,
tlsProfile,
)
if remainingErr != nil {
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
return
}
if remainingStatus != http.StatusOK {
if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
}
s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
return
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
return
}
if recorder != nil {
recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
}
if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"})
}
}
func (s *AccountTestService) fetchSoraTestEndpoint(
ctx context.Context,
account *Account,
authToken string,
url string,
proxyURL string,
tlsProfile *tlsfingerprint.Profile,
) (int, http.Header, []byte, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return 0, nil, nil, err
}
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
if err != nil {
return 0, nil, nil, err
}
defer func() { _ = resp.Body.Close() }()
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return resp.StatusCode, resp.Header, nil, readErr
}
return resp.StatusCode, resp.Header, body, nil
}
func parseSoraSubscriptionSummary(body []byte) string {
var subResp struct {
Data []struct {
Plan struct {
ID string `json:"id"`
Title string `json:"title"`
} `json:"plan"`
EndTS string `json:"end_ts"`
} `json:"data"`
}
if err := json.Unmarshal(body, &subResp); err != nil {
return ""
}
if len(subResp.Data) == 0 {
return ""
}
first := subResp.Data[0]
parts := make([]string, 0, 3)
if first.Plan.Title != "" {
parts = append(parts, first.Plan.Title)
}
if first.Plan.ID != "" {
parts = append(parts, first.Plan.ID)
}
if first.EndTS != "" {
parts = append(parts, "end="+first.EndTS)
}
if len(parts) == 0 {
return ""
}
return "Subscription: " + strings.Join(parts, " | ")
}
func parseSoraInviteSummary(body []byte) string {
var inviteResp struct {
InviteCode string `json:"invite_code"`
RedeemedCount int64 `json:"redeemed_count"`
TotalCount int64 `json:"total_count"`
}
if err := json.Unmarshal(body, &inviteResp); err != nil {
return ""
}
parts := []string{"Sora2: supported"}
if inviteResp.InviteCode != "" {
parts = append(parts, "invite="+inviteResp.InviteCode)
}
if inviteResp.TotalCount > 0 {
parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount))
}
return strings.Join(parts, " | ")
}
func parseSoraRemainingSummary(body []byte) string {
var remainingResp struct {
RateLimitAndCreditBalance struct {
EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"`
RateLimitReached bool `json:"rate_limit_reached"`
AccessResetsInSeconds int64 `json:"access_resets_in_seconds"`
} `json:"rate_limit_and_credit_balance"`
}
if err := json.Unmarshal(body, &remainingResp); err != nil {
return ""
}
info := remainingResp.RateLimitAndCreditBalance
parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)}
if info.RateLimitReached {
parts = append(parts, "rate_limited=true")
}
if info.AccessResetsInSeconds > 0 {
parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds))
}
return strings.Join(parts, " | ")
}
func (s *AccountTestService) resolveSoraTLSProfile() *tlsfingerprint.Profile {
if s == nil || s.cfg == nil || !s.cfg.Sora.Client.DisableTLSFingerprint {
// Sora TLS fingerprint enabled — use built-in default profile
return &tlsfingerprint.Profile{Name: "Built-in Default (Sora)"}
}
return nil // disabled
}
func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractCloudflareRayID(headers http.Header, body []byte) string {
return soraerror.ExtractCloudflareRayID(headers, body)
}
func extractSoraEgressIPHint(headers http.Header) string {
if headers == nil {
return "unknown"
}
candidates := []string{
"x-openai-public-ip",
"x-envoy-external-address",
"cf-connecting-ip",
"x-forwarded-for",
}
for _, key := range candidates {
if value := strings.TrimSpace(headers.Get(key)); value != "" {
return value
}
}
return "unknown"
}
func sanitizeProxyURLForLog(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
u, err := url.Parse(raw)
if err != nil {
return "<invalid_proxy_url>"
}
if u.User != nil {
u.User = nil
}
return u.String()
}
func endpointPathForLog(endpoint string) string {
parsed, err := url.Parse(strings.TrimSpace(endpoint))
if err != nil || parsed.Path == "" {
return endpoint
}
return parsed.Path
}
func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) {
accountID := int64(0)
platform := ""
proxyID := "none"
if account != nil {
accountID = account.ID
platform = account.Platform
if account.ProxyID != nil {
proxyID = fmt.Sprintf("%d", *account.ProxyID)
}
}
cfRay := extractCloudflareRayID(headers, body)
if cfRay == "" {
cfRay = "unknown"
}
log.Printf(
"[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s",
accountID,
platform,
endpoint,
endpointPathForLog(endpoint),
proxyID,
sanitizeProxyURLForLog(proxyURL),
cfRay,
extractSoraEgressIPHint(headers),
)
}
func truncateSoraErrorBody(body []byte, max int) string {
return soraerror.TruncateBody(body, max)
}
// routeAntigravityTest 路由 Antigravity 账号的测试请求。 // routeAntigravityTest 路由 Antigravity 账号的测试请求。
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。 // APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error { func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error {
......
...@@ -42,7 +42,7 @@ func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) { ...@@ -42,7 +42,7 @@ func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
t.Parallel() t.Parallel()
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
ctx, recorder := newSoraTestContext() ctx, recorder := newTestContext()
svc := &AccountTestService{} svc := &AccountTestService{}
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n") stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
......
...@@ -4,16 +4,61 @@ package service ...@@ -4,16 +4,61 @@ package service
import ( import (
"context" "context"
"fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
) )
// --- shared test helpers ---
type queuedHTTPUpstream struct {
responses []*http.Response
requests []*http.Request
tlsFlags []bool
}
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return nil, fmt.Errorf("unexpected Do call")
}
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) {
u.requests = append(u.requests, req)
u.tlsFlags = append(u.tlsFlags, profile != nil)
if len(u.responses) == 0 {
return nil, fmt.Errorf("no mocked response")
}
resp := u.responses[0]
u.responses = u.responses[1:]
return resp, nil
}
func newJSONResponse(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
// --- test functions ---
func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
return c, rec
}
type openAIAccountTestRepo struct { type openAIAccountTestRepo struct {
mockAccountRepoForGemini mockAccountRepoForGemini
updatedExtra map[string]any updatedExtra map[string]any
...@@ -34,7 +79,7 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese ...@@ -34,7 +79,7 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) { func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
ctx, recorder := newSoraTestContext() ctx, recorder := newTestContext()
resp := newJSONResponse(http.StatusOK, "") resp := newJSONResponse(http.StatusOK, "")
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"} resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
...@@ -68,7 +113,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing. ...@@ -68,7 +113,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) { func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
ctx, _ := newSoraTestContext() ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`) resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
resp.Header.Set("x-codex-primary-used-percent", "100") resp.Header.Set("x-codex-primary-used-percent", "100")
......
package service
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type queuedHTTPUpstream struct {
responses []*http.Response
requests []*http.Request
tlsFlags []bool
}
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return nil, fmt.Errorf("unexpected Do call")
}
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) {
u.requests = append(u.requests, req)
u.tlsFlags = append(u.tlsFlags, profile != nil)
if len(u.responses) == 0 {
return nil, fmt.Errorf("no mocked response")
}
resp := u.responses[0]
u.responses = u.responses[1:]
return resp, nil
}
func newJSONResponse(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
resp := newJSONResponse(status, body)
resp.Header.Set(key, value)
return resp
}
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
return c, rec
}
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`),
newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`),
},
}
svc := &AccountTestService{
httpUpstream: upstream,
cfg: &config.Config{
Gateway: config.GatewayConfig{
TLSFingerprint: config.TLSFingerprintConfig{
Enabled: true,
},
},
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
DisableTLSFingerprint: false,
},
},
},
}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
require.Len(t, upstream.requests, 4)
require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String())
require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String())
require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags)
body := rec.Body.String()
require.Contains(t, body, `"type":"test_start"`)
require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`),
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
require.Len(t, upstream.requests, 4)
body := rec.Body.String()
require.Contains(t, body, "Sora connection OK - User: demo-user")
require.Contains(t, body, "Subscription check returned 403")
require.Contains(t, body, "Sora2 invite check returned 401")
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"partial_success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponseWithHeader(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`, "cf-ray", "9cff2d62d83bb98d"),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "Cloudflare challenge")
require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
body := rec.Body.String()
require.Contains(t, body, `"type":"error"`)
require.Contains(t, body, "Cloudflare challenge")
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
}
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponseWithHeader(http.StatusTooManyRequests, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body></body></html>`, "cf-mitigated", "challenge"),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "Cloudflare challenge")
require.Contains(t, err.Error(), "HTTP 429")
body := rec.Body.String()
require.Contains(t, body, "Cloudflare challenge")
}
func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "token_invalidated")
body := rec.Body.String()
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"failed"`)
require.Contains(t, body, "token_invalidated")
require.NotContains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
},
}
svc := &AccountTestService{
httpUpstream: upstream,
soraTestCooldown: time.Hour,
}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c1, _ := newSoraTestContext()
err := svc.testSoraAccountConnection(c1, account)
require.NoError(t, err)
c2, rec2 := newSoraTestContext()
err = svc.testSoraAccountConnection(c2, account)
require.Error(t, err)
require.Contains(t, err.Error(), "测试过于频繁")
body := rec2.Body.String()
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"code":"test_rate_limited"`)
require.Contains(t, body, `"status":"failed"`)
require.NotContains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
body := rec.Body.String()
require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)")
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestSanitizeProxyURLForLog(t *testing.T) {
require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080"))
require.Equal(t, "", sanitizeProxyURLForLog(""))
require.Equal(t, "<invalid_proxy_url>", sanitizeProxyURLForLog("://invalid"))
}
func TestExtractSoraEgressIPHint(t *testing.T) {
h := make(http.Header)
h.Set("x-openai-public-ip", "203.0.113.10")
require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h))
h2 := make(http.Header)
h2.Set("x-envoy-external-address", "198.51.100.9")
require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2))
require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
}
...@@ -15,7 +15,7 @@ import ( ...@@ -15,7 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror" "github.com/Wei-Shaw/sub2api/internal/util/httputil"
) )
// AdminService interface defines admin management operations // AdminService interface defines admin management operations
...@@ -104,14 +104,13 @@ type AdminService interface { ...@@ -104,14 +104,13 @@ type AdminService interface {
// CreateUserInput represents input for creating a new user via admin operations. // CreateUserInput represents input for creating a new user via admin operations.
type CreateUserInput struct { type CreateUserInput struct {
Email string Email string
Password string Password string
Username string Username string
Notes string Notes string
Balance float64 Balance float64
Concurrency int Concurrency int
AllowedGroups []int64 AllowedGroups []int64
SoraStorageQuotaBytes int64
} }
type UpdateUserInput struct { type UpdateUserInput struct {
...@@ -125,8 +124,7 @@ type UpdateUserInput struct { ...@@ -125,8 +124,7 @@ type UpdateUserInput struct {
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置 // GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率 // map[groupID]*rate,nil 表示删除该分组的专属倍率
GroupRates map[int64]*float64 GroupRates map[int64]*float64
SoraStorageQuotaBytes *int64
} }
type CreateGroupInput struct { type CreateGroupInput struct {
...@@ -140,16 +138,11 @@ type CreateGroupInput struct { ...@@ -140,16 +138,11 @@ type CreateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD) WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用) // 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64 ImagePrice1K *float64
ImagePrice2K *float64 ImagePrice2K *float64
ImagePrice4K *float64 ImagePrice4K *float64
// Sora 按次计费配置 ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
SoraImagePrice360 *float64 FallbackGroupID *int64 // 降级分组 ID
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用) // 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64 FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
...@@ -158,8 +151,6 @@ type CreateGroupInput struct { ...@@ -158,8 +151,6 @@ type CreateGroupInput struct {
MCPXMLInject *bool MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string SupportedModelScopes []string
// Sora 存储配额
SoraStorageQuotaBytes int64
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool AllowMessagesDispatch bool
DefaultMappedModel string DefaultMappedModel string
...@@ -181,16 +172,11 @@ type UpdateGroupInput struct { ...@@ -181,16 +172,11 @@ type UpdateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD) WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用) // 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64 ImagePrice1K *float64
ImagePrice2K *float64 ImagePrice2K *float64
ImagePrice4K *float64 ImagePrice4K *float64
// Sora 按次计费配置 ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
SoraImagePrice360 *float64 FallbackGroupID *int64 // 降级分组 ID
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用) // 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64 FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
...@@ -199,8 +185,6 @@ type UpdateGroupInput struct { ...@@ -199,8 +185,6 @@ type UpdateGroupInput struct {
MCPXMLInject *bool MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string SupportedModelScopes *[]string
// Sora 存储配额
SoraStorageQuotaBytes *int64
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool AllowMessagesDispatch *bool
DefaultMappedModel *string DefaultMappedModel *string
...@@ -426,14 +410,6 @@ var proxyQualityTargets = []proxyQualityTarget{ ...@@ -426,14 +410,6 @@ var proxyQualityTargets = []proxyQualityTarget{
http.StatusOK: {}, http.StatusOK: {},
}, },
}, },
{
Target: "sora",
URL: "https://sora.chatgpt.com/backend/me",
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
http.StatusUnauthorized: {},
},
},
} }
const ( const (
...@@ -448,7 +424,6 @@ type adminServiceImpl struct { ...@@ -448,7 +424,6 @@ type adminServiceImpl struct {
userRepo UserRepository userRepo UserRepository
groupRepo GroupRepository groupRepo GroupRepository
accountRepo AccountRepository accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
proxyRepo ProxyRepository proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository redeemCodeRepo RedeemCodeRepository
...@@ -473,7 +448,6 @@ func NewAdminService( ...@@ -473,7 +448,6 @@ func NewAdminService(
userRepo UserRepository, userRepo UserRepository,
groupRepo GroupRepository, groupRepo GroupRepository,
accountRepo AccountRepository, accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository,
proxyRepo ProxyRepository, proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository, apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository, redeemCodeRepo RedeemCodeRepository,
...@@ -492,7 +466,6 @@ func NewAdminService( ...@@ -492,7 +466,6 @@ func NewAdminService(
userRepo: userRepo, userRepo: userRepo,
groupRepo: groupRepo, groupRepo: groupRepo,
accountRepo: accountRepo, accountRepo: accountRepo,
soraAccountRepo: soraAccountRepo,
proxyRepo: proxyRepo, proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo, redeemCodeRepo: redeemCodeRepo,
...@@ -574,15 +547,14 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) ...@@ -574,15 +547,14 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
user := &User{ user := &User{
Email: input.Email, Email: input.Email,
Username: input.Username, Username: input.Username,
Notes: input.Notes, Notes: input.Notes,
Role: RoleUser, // Always create as regular user, never admin Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance, Balance: input.Balance,
Concurrency: input.Concurrency, Concurrency: input.Concurrency,
Status: StatusActive, Status: StatusActive,
AllowedGroups: input.AllowedGroups, AllowedGroups: input.AllowedGroups,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
} }
if err := user.SetPassword(input.Password); err != nil { if err := user.SetPassword(input.Password); err != nil {
return nil, err return nil, err
...@@ -654,10 +626,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -654,10 +626,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.AllowedGroups = *input.AllowedGroups user.AllowedGroups = *input.AllowedGroups
} }
if input.SoraStorageQuotaBytes != nil {
user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
if err := s.userRepo.Update(ctx, user); err != nil { if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err return nil, err
} }
...@@ -860,10 +828,6 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -860,10 +828,6 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
imagePrice1K := normalizePrice(input.ImagePrice1K) imagePrice1K := normalizePrice(input.ImagePrice1K)
imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K) imagePrice4K := normalizePrice(input.ImagePrice4K)
soraImagePrice360 := normalizePrice(input.SoraImagePrice360)
soraImagePrice540 := normalizePrice(input.SoraImagePrice540)
soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest)
soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD)
// 校验降级分组 // 校验降级分组
if input.FallbackGroupID != nil { if input.FallbackGroupID != nil {
...@@ -934,17 +898,12 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -934,17 +898,12 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice1K: imagePrice1K, ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K, ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K, ImagePrice4K: imagePrice4K,
SoraImagePrice360: soraImagePrice360,
SoraImagePrice540: soraImagePrice540,
SoraVideoPricePerRequest: soraVideoPrice,
SoraVideoPricePerRequestHD: soraVideoPriceHD,
ClaudeCodeOnly: input.ClaudeCodeOnly, ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID, FallbackGroupID: input.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
ModelRouting: input.ModelRouting, ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject, MCPXMLInject: mcpXMLInject,
SupportedModelScopes: input.SupportedModelScopes, SupportedModelScopes: input.SupportedModelScopes,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
AllowMessagesDispatch: input.AllowMessagesDispatch, AllowMessagesDispatch: input.AllowMessagesDispatch,
RequireOAuthOnly: input.RequireOAuthOnly, RequireOAuthOnly: input.RequireOAuthOnly,
RequirePrivacySet: input.RequirePrivacySet, RequirePrivacySet: input.RequirePrivacySet,
...@@ -1115,21 +1074,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -1115,21 +1074,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.ImagePrice4K != nil { if input.ImagePrice4K != nil {
group.ImagePrice4K = normalizePrice(input.ImagePrice4K) group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
} }
if input.SoraImagePrice360 != nil {
group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360)
}
if input.SoraImagePrice540 != nil {
group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540)
}
if input.SoraVideoPricePerRequest != nil {
group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest)
}
if input.SoraVideoPricePerRequestHD != nil {
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
}
if input.SoraStorageQuotaBytes != nil {
group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
// Claude Code 客户端限制 // Claude Code 客户端限制
if input.ClaudeCodeOnly != nil { if input.ClaudeCodeOnly != nil {
...@@ -1566,18 +1510,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou ...@@ -1566,18 +1510,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
} }
} }
// Sora apikey 账号的 base_url 必填校验
if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
baseURL, _ := input.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
account := &Account{ account := &Account{
Name: input.Name, Name: input.Name,
Notes: normalizeAccountNotes(input.Notes), Notes: normalizeAccountNotes(input.Notes),
...@@ -1623,18 +1555,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou ...@@ -1623,18 +1555,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
return nil, err return nil, err
} }
// 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录
if account.Platform == PlatformSora && s.soraAccountRepo != nil {
soraUpdates := map[string]any{
"access_token": account.GetCredential("access_token"),
"refresh_token": account.GetCredential("refresh_token"),
}
if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil {
// 只记录警告日志,不阻塞账号创建
logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err)
}
}
// 绑定分组 // 绑定分组
if len(groupIDs) > 0 { if len(groupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
...@@ -1763,18 +1683,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -1763,18 +1683,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.AutoPauseOnExpired = *input.AutoPauseOnExpired account.AutoPauseOnExpired = *input.AutoPauseOnExpired
} }
// Sora apikey 账号的 base_url 必填校验
if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
baseURL, _ := account.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
// 先验证分组是否存在(在任何写操作之前) // 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil { if input.GroupIDs != nil {
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
...@@ -2377,10 +2285,11 @@ func runProxyQualityTarget(ctx context.Context, client *http.Client, target prox ...@@ -2377,10 +2285,11 @@ func runProxyQualityTarget(ctx context.Context, client *http.Client, target prox
body = body[:proxyQualityMaxBodyBytes] body = body[:proxyQualityMaxBodyBytes]
} }
if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { // Cloudflare challenge 检测
if httputil.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
item.Status = "challenge" item.Status = "challenge"
item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body) item.CFRay = httputil.ExtractCloudflareRayID(resp.Header, body)
item.Message = "Sora 命中 Cloudflare challenge" item.Message = "命中 Cloudflare challenge"
return item return item
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment