Commit eb2dce92 authored by 陈曦's avatar 陈曦
Browse files

升级v1.0.8 解决冲突

parents 7b83d6e7 339d906e
......@@ -158,30 +158,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
}
// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID == openai.SoraClientID {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
return
}
w.WriteHeader(http.StatusBadRequest)
}))
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
require.NoError(s.T(), err, "RefreshTokenWithClientID")
require.Equal(s.T(), "at-sora", resp.AccessToken)
require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
const customClientID = "custom-client-id"
var seenClientIDs []string
......@@ -276,7 +252,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
wantClientID := openai.SoraClientID
wantClientID := "custom-exchange-client-id"
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
......
package repository
import (
"context"
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraAccountRepository 实现 service.SoraAccountRepository 接口。
// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
//
// 设计说明:
// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除
type soraAccountRepository struct {
sql *sql.DB
}
// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository {
return &soraAccountRepository{sql: sqlDB}
}
// Upsert 创建或更新 Sora 账号扩展信息
// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
accessToken, accessOK := updates["access_token"].(string)
refreshToken, refreshOK := updates["refresh_token"].(string)
sessionToken, sessionOK := updates["session_token"].(string)
if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" {
if !sessionOK {
return errors.New("缺少 access_token/refresh_token,且未提供可更新字段")
}
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_accounts
SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
updated_at = NOW()
WHERE account_id = $1
`, accountID, sessionToken)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return errors.New("sora_accounts 记录不存在,无法仅更新 session_token")
}
return nil
}
_, err := r.sql.ExecContext(ctx, `
INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
ON CONFLICT (account_id) DO UPDATE SET
access_token = EXCLUDED.access_token,
refresh_token = EXCLUDED.refresh_token,
session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
updated_at = NOW()
`, accountID, accessToken, refreshToken, sessionToken)
return err
}
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
FROM sora_accounts
WHERE account_id = $1
`, accountID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return nil, nil // 记录不存在
}
var sa service.SoraAccount
if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil {
return nil, err
}
return &sa, nil
}
// Delete 删除 Sora 账号扩展信息
func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error {
_, err := r.sql.ExecContext(ctx, `
DELETE FROM sora_accounts WHERE account_id = $1
`, accountID)
return err
}
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
// 使用原生 SQL 操作 sora_generations 表。
type soraGenerationRepository struct {
sql *sql.DB
}
// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
return &soraGenerationRepository{sql: sqlDB}
}
func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
err := r.sql.QueryRowContext(ctx, `
INSERT INTO sora_generations (
user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, created_at
`,
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
).Scan(&gen.ID, &gen.CreatedAt)
return err
}
// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
func (r *soraGenerationRepository) CreatePendingWithLimit(
ctx context.Context,
gen *service.SoraGeneration,
activeStatuses []string,
maxActive int64,
) error {
if gen == nil {
return fmt.Errorf("generation is nil")
}
if maxActive <= 0 {
return r.Create(ctx, gen)
}
if len(activeStatuses) == 0 {
activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
}
tx, err := r.sql.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
// 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
return err
}
placeholders := make([]string, len(activeStatuses))
args := make([]any, 0, 1+len(activeStatuses))
args = append(args, gen.UserID)
for i, s := range activeStatuses {
placeholders[i] = fmt.Sprintf("$%d", i+2)
args = append(args, s)
}
countQuery := fmt.Sprintf(
`SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
strings.Join(placeholders, ","),
)
var activeCount int64
if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
return err
}
if activeCount >= maxActive {
return service.ErrSoraGenerationConcurrencyLimit
}
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
if err := tx.QueryRowContext(ctx, `
INSERT INTO sora_generations (
user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, created_at
`,
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
).Scan(&gen.ID, &gen.CreatedAt); err != nil {
return err
}
return tx.Commit()
}
func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
gen := &service.SoraGeneration{}
var mediaURLsJSON, s3KeysJSON []byte
var completedAt sql.NullTime
var apiKeyID sql.NullInt64
err := r.sql.QueryRowContext(ctx, `
SELECT id, user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message,
created_at, completed_at
FROM sora_generations WHERE id = $1
`, id).Scan(
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
&gen.CreatedAt, &completedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("生成记录不存在")
}
return nil, err
}
if apiKeyID.Valid {
gen.APIKeyID = &apiKeyID.Int64
}
if completedAt.Valid {
gen.CompletedAt = &completedAt.Time
}
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
return gen, nil
}
func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
var completedAt *time.Time
if gen.CompletedAt != nil {
completedAt = gen.CompletedAt
}
_, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations SET
status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
error_message = $9, completed_at = $10
WHERE id = $1
`,
gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
gen.ErrorMessage, completedAt,
)
return err
}
// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2, upstream_task_id = $3
WHERE id = $1 AND status = $4
`,
id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
func (r *soraGenerationRepository) UpdateCompletedIfActive(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
completedAt time.Time,
) (bool, error) {
mediaURLsJSON, _ := json.Marshal(mediaURLs)
s3KeysJSON, _ := json.Marshal(s3Keys)
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2,
media_url = $3,
media_urls = $4,
file_size_bytes = $5,
storage_type = $6,
s3_object_keys = $7,
error_message = '',
completed_at = $8
WHERE id = $1 AND status IN ($9, $10)
`,
id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
func (r *soraGenerationRepository) UpdateFailedIfActive(
ctx context.Context,
id int64,
errMsg string,
completedAt time.Time,
) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2,
error_message = $3,
completed_at = $4
WHERE id = $1 AND status IN ($5, $6)
`,
id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2, completed_at = $3
WHERE id = $1 AND status IN ($4, $5)
`,
id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。
func (r *soraGenerationRepository) UpdateStorageIfCompleted(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
) (bool, error) {
mediaURLsJSON, _ := json.Marshal(mediaURLs)
s3KeysJSON, _ := json.Marshal(s3Keys)
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET media_url = $2,
media_urls = $3,
file_size_bytes = $4,
storage_type = $5,
s3_object_keys = $6
WHERE id = $1 AND status = $7
`,
id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id)
return err
}
func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
// 构建 WHERE 条件
conditions := []string{"user_id = $1"}
args := []any{params.UserID}
argIdx := 2
if params.Status != "" {
// 支持逗号分隔的多状态
statuses := strings.Split(params.Status, ",")
placeholders := make([]string, len(statuses))
for i, s := range statuses {
placeholders[i] = fmt.Sprintf("$%d", argIdx)
args = append(args, strings.TrimSpace(s))
argIdx++
}
conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
}
if params.StorageType != "" {
storageTypes := strings.Split(params.StorageType, ",")
placeholders := make([]string, len(storageTypes))
for i, s := range storageTypes {
placeholders[i] = fmt.Sprintf("$%d", argIdx)
args = append(args, strings.TrimSpace(s))
argIdx++
}
conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
}
if params.MediaType != "" {
conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
args = append(args, params.MediaType)
argIdx++
}
whereClause := "WHERE " + strings.Join(conditions, " AND ")
// 计数
var total int64
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
return nil, 0, err
}
// 分页查询
offset := (params.Page - 1) * params.PageSize
listQuery := fmt.Sprintf(`
SELECT id, user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message,
created_at, completed_at
FROM sora_generations %s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d
`, whereClause, argIdx, argIdx+1)
args = append(args, params.PageSize, offset)
rows, err := r.sql.QueryContext(ctx, listQuery, args...)
if err != nil {
return nil, 0, err
}
defer func() {
_ = rows.Close()
}()
var results []*service.SoraGeneration
for rows.Next() {
gen := &service.SoraGeneration{}
var mediaURLsJSON, s3KeysJSON []byte
var completedAt sql.NullTime
var apiKeyID sql.NullInt64
if err := rows.Scan(
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
&gen.CreatedAt, &completedAt,
); err != nil {
return nil, 0, err
}
if apiKeyID.Valid {
gen.APIKeyID = &apiKeyID.Int64
}
if completedAt.Valid {
gen.CompletedAt = &completedAt.Time
}
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
results = append(results, gen)
}
return results, total, rows.Err()
}
func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
if len(statuses) == 0 {
return 0, nil
}
placeholders := make([]string, len(statuses))
args := []any{userID}
for i, s := range statuses {
placeholders[i] = fmt.Sprintf("$%d", i+2)
args = append(args, s)
}
var count int64
query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ","))
err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count)
return count, err
}
......@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
......@@ -73,7 +73,6 @@ var usageLogInsertArgTypes = [...]string{
"text", // ip_address
"integer", // image_count
"text", // image_size
"text", // media_type
"text", // service_tier
"text", // reasoning_effort
"text", // inbound_endpoint
......@@ -352,7 +351,6 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
......@@ -369,7 +367,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
......@@ -790,7 +788,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
......@@ -803,7 +800,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
created_at
) AS (VALUES `)
args := make([]any, 0, len(keys)*47)
args := make([]any, 0, len(keys)*46)
argPos := 1
for idx, key := range keys {
if idx > 0 {
......@@ -867,7 +864,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
......@@ -915,7 +911,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
......@@ -1003,7 +998,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
......@@ -1016,7 +1010,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*46)
args := make([]any, 0, len(preparedList)*45)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
......@@ -1077,7 +1071,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
......@@ -1125,7 +1118,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
......@@ -1181,7 +1173,6 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
......@@ -1198,7 +1189,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
......@@ -1225,7 +1216,6 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize)
mediaType := nullString(log.MediaType)
serviceTier := nullString(log.ServiceTier)
reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
......@@ -1286,7 +1276,6 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
ipAddress,
log.ImageCount,
imageSize,
mediaType,
serviceTier,
reasoningEffort,
inboundEndpoint,
......@@ -4051,7 +4040,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
ipAddress sql.NullString
imageCount int
imageSize sql.NullString
mediaType sql.NullString
serviceTier sql.NullString
reasoningEffort sql.NullString
inboundEndpoint sql.NullString
......@@ -4101,7 +4089,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&ipAddress,
&imageCount,
&imageSize,
&mediaType,
&serviceTier,
&reasoningEffort,
&inboundEndpoint,
......@@ -4179,9 +4166,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if imageSize.Valid {
log.ImageSize = &imageSize.String
}
if mediaType.Valid {
log.MediaType = &mediaType.String
}
if serviceTier.Valid {
log.ServiceTier = &serviceTier.String
}
......
......@@ -76,7 +76,6 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // ip_address
log.ImageCount,
sqlmock.AnyArg(), // image_size
sqlmock.AnyArg(), // media_type
sqlmock.AnyArg(), // service_tier
sqlmock.AnyArg(), // reasoning_effort
sqlmock.AnyArg(), // inbound_endpoint
......@@ -155,7 +154,6 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(),
log.ImageCount,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
serviceTier,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
......@@ -471,7 +469,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
sql.NullString{},
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
......@@ -519,7 +516,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
sql.NullString{},
sql.NullString{Valid: true, String: "flex"},
sql.NullString{},
sql.NullString{},
......@@ -567,7 +563,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
sql.NullString{},
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
......
......@@ -62,7 +62,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
......@@ -145,8 +144,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
......@@ -376,65 +373,6 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
return nil
}
// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。
func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
if deltaBytes <= 0 {
user, err := r.GetByID(ctx, userID)
if err != nil {
return 0, err
}
return user.SoraStorageUsedBytes, nil
}
var newUsed int64
err := scanSingleRow(ctx, r.sql, `
UPDATE users
SET sora_storage_used_bytes = sora_storage_used_bytes + $2
WHERE id = $1
AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
RETURNING sora_storage_used_bytes
`, []any{userID, deltaBytes, effectiveQuota}, &newUsed)
if err == nil {
return newUsed, nil
}
if errors.Is(err, sql.ErrNoRows) {
// 区分用户不存在和配额冲突
exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
if existsErr != nil {
return 0, existsErr
}
if !exists {
return 0, service.ErrUserNotFound
}
return 0, service.ErrSoraStorageQuotaExceeded
}
return 0, err
}
// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。
func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
if deltaBytes <= 0 {
user, err := r.GetByID(ctx, userID)
if err != nil {
return 0, err
}
return user.SoraStorageUsedBytes, nil
}
var newUsed int64
err := scanSingleRow(ctx, r.sql, `
UPDATE users
SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
WHERE id = $1
RETURNING sora_storage_used_bytes
`, []any{userID, deltaBytes}, &newUsed)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return 0, service.ErrUserNotFound
}
return 0, err
}
return newUsed, nil
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
}
......
......@@ -53,7 +53,6 @@ var ProviderSet = wire.NewSet(
NewAPIKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewSoraAccountRepository, // Sora 账号扩展表仓储
NewScheduledTestPlanRepository, // 定时测试计划仓储
NewScheduledTestResultRepository, // 定时测试结果仓储
NewProxyRepository,
......
......@@ -204,11 +204,6 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"sora_image_price_360": null,
"sora_image_price_540": null,
"sora_storage_quota_bytes": 0,
"sora_video_price_per_request": null,
"sora_video_price_per_request_hd": null,
"claude_code_only": false,
"allow_messages_dispatch": false,
"fallback_group_id": null,
......@@ -532,7 +527,6 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": "",
"sora_client_enabled": false,
"invitation_code_enabled": false,
"home_content": "",
"hide_ccs_import_button": false,
......@@ -653,11 +647,11 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
......
......@@ -94,7 +94,6 @@ func isAPIRoutePath(c *gin.Context) bool {
return strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/responses")
}
......
......@@ -109,7 +109,6 @@ func registerRoutes(
// 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
}
......@@ -34,8 +34,6 @@ func RegisterAdminRoutes(
// OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h)
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
registerSoraOAuthRoutes(admin, h)
// Gemini OAuth
registerGeminiOAuthRoutes(admin, h)
......@@ -321,19 +319,6 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
sora := admin.Group("/sora")
{
sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
}
}
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
gemini := admin.Group("/gemini")
{
......@@ -422,15 +407,6 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Beta 策略配置
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
// Sora S3 存储配置
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection)
adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles)
adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile)
adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile)
adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile)
adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile)
}
}
......
......@@ -23,11 +23,6 @@ func RegisterGatewayRoutes(
cfg *config.Config,
) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
soraMaxBodySize := cfg.Gateway.SoraMaxBodySize
if soraMaxBodySize <= 0 {
soraMaxBodySize = cfg.Gateway.MaxBodySize
}
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
endpointNorm := handler.InboundEndpointMiddleware()
......@@ -163,28 +158,6 @@ func RegisterGatewayRoutes(
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
}
// Sora 专用路由(强制使用 sora 平台)
soraV1 := r.Group("/sora/v1")
soraV1.Use(soraBodyLimit)
soraV1.Use(clientRequestID)
soraV1.Use(opsErrorLogger)
soraV1.Use(endpointNorm)
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
soraV1.Use(requireGroupAnthropic)
{
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
soraV1.GET("/models", h.Gateway.Models)
}
// Sora 媒体代理(可选 API Key 验证)
if cfg.Gateway.SoraMediaRequireAPIKey {
r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy)
} else {
r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy)
}
// Sora 媒体代理(签名 URL,无需 API Key)
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
}
// getGroupPlatform extracts the group platform from the API Key stored in context.
......
......@@ -22,7 +22,6 @@ func newGatewayRoutesTestRouter() *gin.Engine {
&handler.Handlers{
Gateway: &handler.GatewayHandler{},
OpenAIGateway: &handler.OpenAIGatewayHandler{},
SoraGateway: &handler.SoraGatewayHandler{},
},
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
c.Next()
......
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。
func RegisterSoraClientRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
settingService *service.SettingService,
) {
if h.SoraClient == nil {
return
}
authenticated := v1.Group("/sora")
authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{
authenticated.POST("/generate", h.SoraClient.Generate)
authenticated.GET("/generations", h.SoraClient.ListGenerations)
authenticated.GET("/generations/:id", h.SoraClient.GetGeneration)
authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration)
authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration)
authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage)
authenticated.GET("/quota", h.SoraClient.GetQuota)
authenticated.GET("/models", h.SoraClient.GetModels)
authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus)
}
}
......@@ -28,8 +28,7 @@ type AccountRepository interface {
// GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora')
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
// FindByExtraField 根据 extra 字段中的键值对查找账号
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
// for all accounts that have been synced from CRS.
......
......@@ -13,18 +13,14 @@ import (
"log"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
......@@ -37,11 +33,6 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine"
soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap"
soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check"
)
// TestEvent represents a SSE event for account testing
......@@ -71,13 +62,8 @@ type AccountTestService struct {
httpUpstream HTTPUpstream
cfg *config.Config
tlsFPProfileService *TLSFingerprintProfileService
soraTestGuardMu sync.Mutex
soraTestLastRun map[int64]time.Time
soraTestCooldown time.Duration
}
const defaultSoraTestCooldown = 10 * time.Second
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(
accountRepo AccountRepository,
......@@ -94,8 +80,6 @@ func NewAccountTestService(
httpUpstream: httpUpstream,
cfg: cfg,
tlsFPProfileService: tlsFPProfileService,
soraTestLastRun: make(map[int64]time.Time),
soraTestCooldown: defaultSoraTestCooldown,
}
}
......@@ -197,10 +181,6 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
return s.routeAntigravityTest(c, account, modelID, prompt)
}
if account.Platform == PlatformSora {
return s.testSoraAccountConnection(c, account)
}
return s.testClaudeAccountConnection(c, account, modelID)
}
......@@ -634,698 +614,6 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
return s.processGeminiStream(c, resp.Body)
}
type soraProbeStep struct {
Name string `json:"name"`
Status string `json:"status"`
HTTPStatus int `json:"http_status,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
Message string `json:"message,omitempty"`
}
type soraProbeSummary struct {
Status string `json:"status"`
Steps []soraProbeStep `json:"steps"`
}
type soraProbeRecorder struct {
steps []soraProbeStep
}
func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
r.steps = append(r.steps, soraProbeStep{
Name: name,
Status: status,
HTTPStatus: httpStatus,
ErrorCode: strings.TrimSpace(errorCode),
Message: strings.TrimSpace(message),
})
}
func (r *soraProbeRecorder) finalize() soraProbeSummary {
meSuccess := false
partial := false
for _, step := range r.steps {
if step.Name == "me" {
meSuccess = strings.EqualFold(step.Status, "success")
continue
}
if strings.EqualFold(step.Status, "failed") {
partial = true
}
}
status := "success"
if !meSuccess {
status = "failed"
} else if partial {
status = "partial_success"
}
return soraProbeSummary{
Status: status,
Steps: append([]soraProbeStep(nil), r.steps...),
}
}
func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
if rec == nil {
return
}
summary := rec.finalize()
code := ""
for _, step := range summary.Steps {
if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
code = step.ErrorCode
break
}
}
s.sendEvent(c, TestEvent{
Type: "sora_test_result",
Status: summary.Status,
Code: code,
Data: summary,
})
}
func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
if accountID <= 0 {
return 0, true
}
s.soraTestGuardMu.Lock()
defer s.soraTestGuardMu.Unlock()
if s.soraTestLastRun == nil {
s.soraTestLastRun = make(map[int64]time.Time)
}
cooldown := s.soraTestCooldown
if cooldown <= 0 {
cooldown = defaultSoraTestCooldown
}
now := time.Now()
if lastRun, ok := s.soraTestLastRun[accountID]; ok {
elapsed := now.Sub(lastRun)
if elapsed < cooldown {
return cooldown - elapsed, false
}
}
s.soraTestLastRun[accountID] = now
return 0, true
}
func ceilSeconds(d time.Duration) int {
if d <= 0 {
return 1
}
sec := int(d / time.Second)
if d%time.Second != 0 {
sec++
}
if sec < 1 {
sec = 1
}
return sec
}
// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。
// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。
func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error {
ctx := c.Request.Context()
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证")
}
baseURL := account.GetBaseURL()
if baseURL == "" {
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url")
}
// 验证 base_url 格式
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error()))
}
upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions"
// 设置 SSE 头
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
return s.sendErrorAndEnd(c, msg)
}
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"})
// 构建轻量级 prompt-enhance 请求作为连通性测试
testPayload := map[string]any{
"model": "prompt-enhance-short-10s",
"messages": []map[string]string{{"role": "user", "content": "test"}},
"stream": false,
}
payloadBytes, _ := json.Marshal(testPayload)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "构建测试请求失败")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
// 获取代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
if resp.StatusCode == http.StatusOK {
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode))
}
// 其他错误但能连通(如 400 参数错误)也算连通性测试通过
if resp.StatusCode == http.StatusBadRequest {
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256)))
}
// testSoraAccountConnection 测试 Sora 账号的连接
// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
// apikey 类型走独立测试流程
if account.Type == AccountTypeAPIKey {
return s.testSoraAPIKeyAccountConnection(c, account)
}
ctx := c.Request.Context()
recorder := &soraProbeRecorder{}
authToken := account.GetCredential("access_token")
if authToken == "" {
recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "No access token available")
}
// Set SSE headers
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, msg)
}
// Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
if err != nil {
recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Failed to create request")
}
// 使用 Sora 客户端标准请求头
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
// Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
soraTLSProfile := s.resolveSoraTLSProfile()
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, soraTLSProfile)
if err != nil {
recorder.addStep("me", "failed", 0, "network_error", err.Error())
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
s.emitSoraProbeSummary(c, recorder)
s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
switch {
case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
case strings.EqualFold(upstreamCode, "unsupported_country_code"):
recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
case strings.TrimSpace(upstreamMessage) != "":
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
default:
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
}
}
recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
// 解析 /me 响应,提取用户信息
var meResp map[string]any
if err := json.Unmarshal(body, &meResp); err != nil {
// 能收到 200 就说明 token 有效
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora connection OK (token valid)"})
} else {
// 尝试提取用户名或邮箱信息
info := "Sora connection OK"
if name, ok := meResp["name"].(string); ok && name != "" {
info = fmt.Sprintf("Sora connection OK - User: %s", name)
} else if email, ok := meResp["email"].(string); ok && email != "" {
info = fmt.Sprintf("Sora connection OK - Email: %s", email)
}
s.sendEvent(c, TestEvent{Type: "content", Text: info})
}
// 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
if err == nil {
subReq.Header.Set("Authorization", "Bearer "+authToken)
subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
subReq.Header.Set("Accept", "application/json")
subReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
subReq.Header.Set("Origin", "https://sora.chatgpt.com")
subReq.Header.Set("Referer", "https://sora.chatgpt.com/")
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, soraTLSProfile)
if subErr != nil {
recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
} else {
subBody, _ := io.ReadAll(subResp.Body)
_ = subResp.Body.Close()
if subResp.StatusCode == http.StatusOK {
recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
}
} else {
if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
} else {
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
}
}
}
}
// 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, soraTLSProfile, recorder)
s.emitSoraProbeSummary(c, recorder)
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
func (s *AccountTestService) testSora2Capabilities(
c *gin.Context,
ctx context.Context,
account *Account,
authToken string,
proxyURL string,
tlsProfile *tlsfingerprint.Profile,
recorder *soraProbeRecorder,
) {
inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraInviteMineURL,
proxyURL,
tlsProfile,
)
if err != nil {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
return
}
if inviteStatus == http.StatusUnauthorized {
bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraBootstrapURL,
proxyURL,
tlsProfile,
)
if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
if recorder != nil {
recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
}
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraInviteMineURL,
proxyURL,
tlsProfile,
)
if err != nil {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
return
}
} else if recorder != nil {
code := ""
msg := ""
if bootstrapErr != nil {
code = "network_error"
msg = bootstrapErr.Error()
}
recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
}
}
if inviteStatus != http.StatusOK {
if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
}
s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
return
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
if recorder != nil {
recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
return
}
if recorder != nil {
recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
}
if summary := parseSoraInviteSummary(inviteBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"})
}
remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint(
ctx,
account,
authToken,
soraRemainingURL,
proxyURL,
tlsProfile,
)
if remainingErr != nil {
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
return
}
if remainingStatus != http.StatusOK {
if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
}
s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
return
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
return
}
if recorder != nil {
recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
}
if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"})
}
}
func (s *AccountTestService) fetchSoraTestEndpoint(
ctx context.Context,
account *Account,
authToken string,
url string,
proxyURL string,
tlsProfile *tlsfingerprint.Profile,
) (int, http.Header, []byte, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return 0, nil, nil, err
}
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
if err != nil {
return 0, nil, nil, err
}
defer func() { _ = resp.Body.Close() }()
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return resp.StatusCode, resp.Header, nil, readErr
}
return resp.StatusCode, resp.Header, body, nil
}
func parseSoraSubscriptionSummary(body []byte) string {
var subResp struct {
Data []struct {
Plan struct {
ID string `json:"id"`
Title string `json:"title"`
} `json:"plan"`
EndTS string `json:"end_ts"`
} `json:"data"`
}
if err := json.Unmarshal(body, &subResp); err != nil {
return ""
}
if len(subResp.Data) == 0 {
return ""
}
first := subResp.Data[0]
parts := make([]string, 0, 3)
if first.Plan.Title != "" {
parts = append(parts, first.Plan.Title)
}
if first.Plan.ID != "" {
parts = append(parts, first.Plan.ID)
}
if first.EndTS != "" {
parts = append(parts, "end="+first.EndTS)
}
if len(parts) == 0 {
return ""
}
return "Subscription: " + strings.Join(parts, " | ")
}
func parseSoraInviteSummary(body []byte) string {
var inviteResp struct {
InviteCode string `json:"invite_code"`
RedeemedCount int64 `json:"redeemed_count"`
TotalCount int64 `json:"total_count"`
}
if err := json.Unmarshal(body, &inviteResp); err != nil {
return ""
}
parts := []string{"Sora2: supported"}
if inviteResp.InviteCode != "" {
parts = append(parts, "invite="+inviteResp.InviteCode)
}
if inviteResp.TotalCount > 0 {
parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount))
}
return strings.Join(parts, " | ")
}
func parseSoraRemainingSummary(body []byte) string {
var remainingResp struct {
RateLimitAndCreditBalance struct {
EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"`
RateLimitReached bool `json:"rate_limit_reached"`
AccessResetsInSeconds int64 `json:"access_resets_in_seconds"`
} `json:"rate_limit_and_credit_balance"`
}
if err := json.Unmarshal(body, &remainingResp); err != nil {
return ""
}
info := remainingResp.RateLimitAndCreditBalance
parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)}
if info.RateLimitReached {
parts = append(parts, "rate_limited=true")
}
if info.AccessResetsInSeconds > 0 {
parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds))
}
return strings.Join(parts, " | ")
}
func (s *AccountTestService) resolveSoraTLSProfile() *tlsfingerprint.Profile {
if s == nil || s.cfg == nil || !s.cfg.Sora.Client.DisableTLSFingerprint {
// Sora TLS fingerprint enabled — use built-in default profile
return &tlsfingerprint.Profile{Name: "Built-in Default (Sora)"}
}
return nil // disabled
}
func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractCloudflareRayID(headers http.Header, body []byte) string {
return soraerror.ExtractCloudflareRayID(headers, body)
}
func extractSoraEgressIPHint(headers http.Header) string {
if headers == nil {
return "unknown"
}
candidates := []string{
"x-openai-public-ip",
"x-envoy-external-address",
"cf-connecting-ip",
"x-forwarded-for",
}
for _, key := range candidates {
if value := strings.TrimSpace(headers.Get(key)); value != "" {
return value
}
}
return "unknown"
}
func sanitizeProxyURLForLog(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
u, err := url.Parse(raw)
if err != nil {
return "<invalid_proxy_url>"
}
if u.User != nil {
u.User = nil
}
return u.String()
}
func endpointPathForLog(endpoint string) string {
parsed, err := url.Parse(strings.TrimSpace(endpoint))
if err != nil || parsed.Path == "" {
return endpoint
}
return parsed.Path
}
func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) {
accountID := int64(0)
platform := ""
proxyID := "none"
if account != nil {
accountID = account.ID
platform = account.Platform
if account.ProxyID != nil {
proxyID = fmt.Sprintf("%d", *account.ProxyID)
}
}
cfRay := extractCloudflareRayID(headers, body)
if cfRay == "" {
cfRay = "unknown"
}
log.Printf(
"[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s",
accountID,
platform,
endpoint,
endpointPathForLog(endpoint),
proxyID,
sanitizeProxyURLForLog(proxyURL),
cfRay,
extractSoraEgressIPHint(headers),
)
}
func truncateSoraErrorBody(body []byte, max int) string {
return soraerror.TruncateBody(body, max)
}
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error {
......
......@@ -42,7 +42,7 @@ func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
ctx, recorder := newSoraTestContext()
ctx, recorder := newTestContext()
svc := &AccountTestService{}
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
......
......@@ -4,16 +4,61 @@ package service
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
)
// --- shared test helpers ---
type queuedHTTPUpstream struct {
responses []*http.Response
requests []*http.Request
tlsFlags []bool
}
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return nil, fmt.Errorf("unexpected Do call")
}
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) {
u.requests = append(u.requests, req)
u.tlsFlags = append(u.tlsFlags, profile != nil)
if len(u.responses) == 0 {
return nil, fmt.Errorf("no mocked response")
}
resp := u.responses[0]
u.responses = u.responses[1:]
return resp, nil
}
func newJSONResponse(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
// --- test functions ---
func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
return c, rec
}
type openAIAccountTestRepo struct {
mockAccountRepoForGemini
updatedExtra map[string]any
......@@ -34,7 +79,7 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newSoraTestContext()
ctx, recorder := newTestContext()
resp := newJSONResponse(http.StatusOK, "")
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
......@@ -68,7 +113,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newSoraTestContext()
ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
resp.Header.Set("x-codex-primary-used-percent", "100")
......
package service
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type queuedHTTPUpstream struct {
responses []*http.Response
requests []*http.Request
tlsFlags []bool
}
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return nil, fmt.Errorf("unexpected Do call")
}
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) {
u.requests = append(u.requests, req)
u.tlsFlags = append(u.tlsFlags, profile != nil)
if len(u.responses) == 0 {
return nil, fmt.Errorf("no mocked response")
}
resp := u.responses[0]
u.responses = u.responses[1:]
return resp, nil
}
func newJSONResponse(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
resp := newJSONResponse(status, body)
resp.Header.Set(key, value)
return resp
}
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
return c, rec
}
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`),
newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`),
},
}
svc := &AccountTestService{
httpUpstream: upstream,
cfg: &config.Config{
Gateway: config.GatewayConfig{
TLSFingerprint: config.TLSFingerprintConfig{
Enabled: true,
},
},
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
DisableTLSFingerprint: false,
},
},
},
}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
require.Len(t, upstream.requests, 4)
require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String())
require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String())
require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags)
body := rec.Body.String()
require.Contains(t, body, `"type":"test_start"`)
require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`),
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
require.Len(t, upstream.requests, 4)
body := rec.Body.String()
require.Contains(t, body, "Sora connection OK - User: demo-user")
require.Contains(t, body, "Subscription check returned 403")
require.Contains(t, body, "Sora2 invite check returned 401")
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"partial_success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponseWithHeader(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`, "cf-ray", "9cff2d62d83bb98d"),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "Cloudflare challenge")
require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
body := rec.Body.String()
require.Contains(t, body, `"type":"error"`)
require.Contains(t, body, "Cloudflare challenge")
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
}
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponseWithHeader(http.StatusTooManyRequests, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body></body></html>`, "cf-mitigated", "challenge"),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "Cloudflare challenge")
require.Contains(t, err.Error(), "HTTP 429")
body := rec.Body.String()
require.Contains(t, body, "Cloudflare challenge")
}
func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "token_invalidated")
body := rec.Body.String()
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"failed"`)
require.Contains(t, body, "token_invalidated")
require.NotContains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
},
}
svc := &AccountTestService{
httpUpstream: upstream,
soraTestCooldown: time.Hour,
}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c1, _ := newSoraTestContext()
err := svc.testSoraAccountConnection(c1, account)
require.NoError(t, err)
c2, rec2 := newSoraTestContext()
err = svc.testSoraAccountConnection(c2, account)
require.Error(t, err)
require.Contains(t, err.Error(), "测试过于频繁")
body := rec2.Body.String()
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"code":"test_rate_limited"`)
require.Contains(t, body, `"status":"failed"`)
require.NotContains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
body := rec.Body.String()
require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)")
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestSanitizeProxyURLForLog(t *testing.T) {
require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080"))
require.Equal(t, "", sanitizeProxyURLForLog(""))
require.Equal(t, "<invalid_proxy_url>", sanitizeProxyURLForLog("://invalid"))
}
func TestExtractSoraEgressIPHint(t *testing.T) {
h := make(http.Header)
h.Set("x-openai-public-ip", "203.0.113.10")
require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h))
h2 := make(http.Header)
h2.Set("x-envoy-external-address", "198.51.100.9")
require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2))
require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
}
......@@ -15,7 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/Wei-Shaw/sub2api/internal/util/httputil"
)
// AdminService interface defines admin management operations
......@@ -111,7 +111,6 @@ type CreateUserInput struct {
Balance float64
Concurrency int
AllowedGroups []int64
SoraStorageQuotaBytes int64
}
type UpdateUserInput struct {
......@@ -126,7 +125,6 @@ type UpdateUserInput struct {
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
GroupRates map[int64]*float64
SoraStorageQuotaBytes *int64
}
type CreateGroupInput struct {
......@@ -143,11 +141,6 @@ type CreateGroupInput struct {
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
// Sora 按次计费配置
SoraImagePrice360 *float64
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
......@@ -158,8 +151,6 @@ type CreateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
// Sora 存储配额
SoraStorageQuotaBytes int64
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool
DefaultMappedModel string
......@@ -184,11 +175,6 @@ type UpdateGroupInput struct {
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
// Sora 按次计费配置
SoraImagePrice360 *float64
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
......@@ -199,8 +185,6 @@ type UpdateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
// Sora 存储配额
SoraStorageQuotaBytes *int64
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool
DefaultMappedModel *string
......@@ -426,14 +410,6 @@ var proxyQualityTargets = []proxyQualityTarget{
http.StatusOK: {},
},
},
{
Target: "sora",
URL: "https://sora.chatgpt.com/backend/me",
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
http.StatusUnauthorized: {},
},
},
}
const (
......@@ -448,7 +424,6 @@ type adminServiceImpl struct {
userRepo UserRepository
groupRepo GroupRepository
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
......@@ -473,7 +448,6 @@ func NewAdminService(
userRepo UserRepository,
groupRepo GroupRepository,
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository,
proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
......@@ -492,7 +466,6 @@ func NewAdminService(
userRepo: userRepo,
groupRepo: groupRepo,
accountRepo: accountRepo,
soraAccountRepo: soraAccountRepo,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
......@@ -582,7 +555,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
Concurrency: input.Concurrency,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
}
if err := user.SetPassword(input.Password); err != nil {
return nil, err
......@@ -654,10 +626,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.AllowedGroups = *input.AllowedGroups
}
if input.SoraStorageQuotaBytes != nil {
user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
......@@ -860,10 +828,6 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
imagePrice1K := normalizePrice(input.ImagePrice1K)
imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K)
soraImagePrice360 := normalizePrice(input.SoraImagePrice360)
soraImagePrice540 := normalizePrice(input.SoraImagePrice540)
soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest)
soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD)
// 校验降级分组
if input.FallbackGroupID != nil {
......@@ -934,17 +898,12 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K,
SoraImagePrice360: soraImagePrice360,
SoraImagePrice540: soraImagePrice540,
SoraVideoPricePerRequest: soraVideoPrice,
SoraVideoPricePerRequestHD: soraVideoPriceHD,
ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject,
SupportedModelScopes: input.SupportedModelScopes,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
AllowMessagesDispatch: input.AllowMessagesDispatch,
RequireOAuthOnly: input.RequireOAuthOnly,
RequirePrivacySet: input.RequirePrivacySet,
......@@ -1115,21 +1074,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.ImagePrice4K != nil {
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
}
if input.SoraImagePrice360 != nil {
group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360)
}
if input.SoraImagePrice540 != nil {
group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540)
}
if input.SoraVideoPricePerRequest != nil {
group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest)
}
if input.SoraVideoPricePerRequestHD != nil {
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
}
if input.SoraStorageQuotaBytes != nil {
group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
// Claude Code 客户端限制
if input.ClaudeCodeOnly != nil {
......@@ -1566,18 +1510,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
}
// Sora apikey 账号的 base_url 必填校验
if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
baseURL, _ := input.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
account := &Account{
Name: input.Name,
Notes: normalizeAccountNotes(input.Notes),
......@@ -1623,18 +1555,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
return nil, err
}
// 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录
if account.Platform == PlatformSora && s.soraAccountRepo != nil {
soraUpdates := map[string]any{
"access_token": account.GetCredential("access_token"),
"refresh_token": account.GetCredential("refresh_token"),
}
if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil {
// 只记录警告日志,不阻塞账号创建
logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err)
}
}
// 绑定分组
if len(groupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
......@@ -1763,18 +1683,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
}
// Sora apikey 账号的 base_url 必填校验
if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
baseURL, _ := account.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
// 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil {
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
......@@ -2377,10 +2285,11 @@ func runProxyQualityTarget(ctx context.Context, client *http.Client, target prox
body = body[:proxyQualityMaxBodyBytes]
}
if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
// Cloudflare challenge 检测
if httputil.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
item.Status = "challenge"
item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body)
item.Message = "Sora 命中 Cloudflare challenge"
item.CFRay = httputil.ExtractCloudflareRayID(resp.Header, body)
item.Message = "命中 Cloudflare challenge"
return item
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment