Commit 62e80c60 authored by erio's avatar erio
Browse files

revert: completely remove all Sora functionality

parent dbb248df
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -129,56 +129,3 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
close(done)
})
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("task not executed")
}
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
h := &SoraGatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
called.Store(true)
})
require.True(t, called.Load())
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &SoraGatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
})
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
h := &SoraGatewayHandler{}
var called atomic.Bool
require.NotPanics(t, func() {
h.submitUsageRecordTask(func(ctx context.Context) {
panic("usage task panic")
})
})
h.submitUsageRecordTask(func(ctx context.Context) {
called.Store(true)
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
}
......@@ -86,8 +86,6 @@ func ProvideHandlers(
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
soraGatewayHandler *SoraGatewayHandler,
soraClientHandler *SoraClientHandler,
settingHandler *SettingHandler,
totpHandler *TotpHandler,
_ *service.IdempotencyCoordinator,
......@@ -104,8 +102,6 @@ func ProvideHandlers(
Admin: adminHandlers,
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
SoraGateway: soraGatewayHandler,
SoraClient: soraClientHandler,
Setting: settingHandler,
Totp: totpHandler,
}
......@@ -123,7 +119,6 @@ var ProviderSet = wire.NewSet(
NewAnnouncementHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
NewSoraGatewayHandler,
NewTotpHandler,
ProvideSettingHandler,
......
......@@ -17,8 +17,6 @@ import (
const (
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
......@@ -39,8 +37,6 @@ const (
const (
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
OAuthPlatformOpenAI = "openai"
// OAuthPlatformSora uses Sora OAuth client.
OAuthPlatformSora = "sora"
)
// OAuthSession stores OAuth flow state for OpenAI
......@@ -211,15 +207,8 @@ func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platfor
}
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
switch strings.ToLower(strings.TrimSpace(platform)) {
case OAuthPlatformSora:
return ClientID, false
default:
return ClientID, true
}
}
// TokenRequest represents the token exchange request body
......
......@@ -60,23 +60,3 @@ func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
}
}
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
// 但不启用 codex_cli_simplified_flow。
func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) {
authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora)
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Parse URL failed: %v", err)
}
q := parsed.Query()
if got := q.Get("client_id"); got != ClientID {
t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID)
}
if got := q.Get("codex_cli_simplified_flow"); got != "" {
t.Fatalf("codex flow should be empty for sora, got=%q", got)
}
if got := q.Get("id_token_add_organizations"); got != "true" {
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
}
}
......@@ -1692,20 +1692,13 @@ func itoa(v int) string {
}
// FindByExtraField 根据 extra 字段中的键值对查找账号。
// 该方法限定 platform='sora',避免误查询其他平台的账号。
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
//
// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
//
// FindByExtraField finds accounts by key-value pairs in the extra field.
// Limited to platform='sora' to avoid querying accounts from other platforms.
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
//
// Use case: Finding Sora accounts linked via linked_openai_account_id.
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
dbaccount.DeletedAtIsNil(),
func(s *entsql.Selector) {
path := sqljson.Path(key)
......
......@@ -155,10 +155,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice1k,
group.FieldImagePrice2k,
group.FieldImagePrice4k,
group.FieldSoraImagePrice360,
group.FieldSoraImagePrice540,
group.FieldSoraVideoPricePerRequest,
group.FieldSoraVideoPricePerRequestHd,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
group.FieldFallbackGroupIDOnInvalidRequest,
......@@ -617,8 +613,6 @@ func userEntityToService(u *dbent.User) *service.User {
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
......@@ -647,11 +641,6 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
SoraImagePrice360: g.SoraImagePrice360,
SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
......
......@@ -49,17 +49,12 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
......@@ -122,15 +117,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
......
......@@ -158,30 +158,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
}
// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID == openai.SoraClientID {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
return
}
w.WriteHeader(http.StatusBadRequest)
}))
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
require.NoError(s.T(), err, "RefreshTokenWithClientID")
require.Equal(s.T(), "at-sora", resp.AccessToken)
require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
const customClientID = "custom-client-id"
var seenClientIDs []string
......@@ -276,7 +252,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
wantClientID := openai.SoraClientID
wantClientID := "custom-exchange-client-id"
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
......
package repository
import (
"context"
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraAccountRepository 实现 service.SoraAccountRepository 接口。
// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
//
// 设计说明:
// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除
type soraAccountRepository struct {
sql *sql.DB
}
// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository {
return &soraAccountRepository{sql: sqlDB}
}
// Upsert 创建或更新 Sora 账号扩展信息
// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
accessToken, accessOK := updates["access_token"].(string)
refreshToken, refreshOK := updates["refresh_token"].(string)
sessionToken, sessionOK := updates["session_token"].(string)
if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" {
if !sessionOK {
return errors.New("缺少 access_token/refresh_token,且未提供可更新字段")
}
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_accounts
SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
updated_at = NOW()
WHERE account_id = $1
`, accountID, sessionToken)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return errors.New("sora_accounts 记录不存在,无法仅更新 session_token")
}
return nil
}
_, err := r.sql.ExecContext(ctx, `
INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
ON CONFLICT (account_id) DO UPDATE SET
access_token = EXCLUDED.access_token,
refresh_token = EXCLUDED.refresh_token,
session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
updated_at = NOW()
`, accountID, accessToken, refreshToken, sessionToken)
return err
}
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
FROM sora_accounts
WHERE account_id = $1
`, accountID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return nil, nil // 记录不存在
}
var sa service.SoraAccount
if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil {
return nil, err
}
return &sa, nil
}
// Delete 删除 Sora 账号扩展信息
func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error {
_, err := r.sql.ExecContext(ctx, `
DELETE FROM sora_accounts WHERE account_id = $1
`, accountID)
return err
}
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
// 使用原生 SQL 操作 sora_generations 表。
type soraGenerationRepository struct {
sql *sql.DB
}
// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
return &soraGenerationRepository{sql: sqlDB}
}
func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
err := r.sql.QueryRowContext(ctx, `
INSERT INTO sora_generations (
user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, created_at
`,
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
).Scan(&gen.ID, &gen.CreatedAt)
return err
}
// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
func (r *soraGenerationRepository) CreatePendingWithLimit(
ctx context.Context,
gen *service.SoraGeneration,
activeStatuses []string,
maxActive int64,
) error {
if gen == nil {
return fmt.Errorf("generation is nil")
}
if maxActive <= 0 {
return r.Create(ctx, gen)
}
if len(activeStatuses) == 0 {
activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
}
tx, err := r.sql.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
// 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
return err
}
placeholders := make([]string, len(activeStatuses))
args := make([]any, 0, 1+len(activeStatuses))
args = append(args, gen.UserID)
for i, s := range activeStatuses {
placeholders[i] = fmt.Sprintf("$%d", i+2)
args = append(args, s)
}
countQuery := fmt.Sprintf(
`SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
strings.Join(placeholders, ","),
)
var activeCount int64
if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
return err
}
if activeCount >= maxActive {
return service.ErrSoraGenerationConcurrencyLimit
}
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
if err := tx.QueryRowContext(ctx, `
INSERT INTO sora_generations (
user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, created_at
`,
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
).Scan(&gen.ID, &gen.CreatedAt); err != nil {
return err
}
return tx.Commit()
}
func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
gen := &service.SoraGeneration{}
var mediaURLsJSON, s3KeysJSON []byte
var completedAt sql.NullTime
var apiKeyID sql.NullInt64
err := r.sql.QueryRowContext(ctx, `
SELECT id, user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message,
created_at, completed_at
FROM sora_generations WHERE id = $1
`, id).Scan(
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
&gen.CreatedAt, &completedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("生成记录不存在")
}
return nil, err
}
if apiKeyID.Valid {
gen.APIKeyID = &apiKeyID.Int64
}
if completedAt.Valid {
gen.CompletedAt = &completedAt.Time
}
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
return gen, nil
}
func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
var completedAt *time.Time
if gen.CompletedAt != nil {
completedAt = gen.CompletedAt
}
_, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations SET
status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
error_message = $9, completed_at = $10
WHERE id = $1
`,
gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
gen.ErrorMessage, completedAt,
)
return err
}
// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2, upstream_task_id = $3
WHERE id = $1 AND status = $4
`,
id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
func (r *soraGenerationRepository) UpdateCompletedIfActive(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
completedAt time.Time,
) (bool, error) {
mediaURLsJSON, _ := json.Marshal(mediaURLs)
s3KeysJSON, _ := json.Marshal(s3Keys)
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2,
media_url = $3,
media_urls = $4,
file_size_bytes = $5,
storage_type = $6,
s3_object_keys = $7,
error_message = '',
completed_at = $8
WHERE id = $1 AND status IN ($9, $10)
`,
id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
func (r *soraGenerationRepository) UpdateFailedIfActive(
ctx context.Context,
id int64,
errMsg string,
completedAt time.Time,
) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2,
error_message = $3,
completed_at = $4
WHERE id = $1 AND status IN ($5, $6)
`,
id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2, completed_at = $3
WHERE id = $1 AND status IN ($4, $5)
`,
id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。
func (r *soraGenerationRepository) UpdateStorageIfCompleted(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
) (bool, error) {
mediaURLsJSON, _ := json.Marshal(mediaURLs)
s3KeysJSON, _ := json.Marshal(s3Keys)
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET media_url = $2,
media_urls = $3,
file_size_bytes = $4,
storage_type = $5,
s3_object_keys = $6
WHERE id = $1 AND status = $7
`,
id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id)
return err
}
func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
// 构建 WHERE 条件
conditions := []string{"user_id = $1"}
args := []any{params.UserID}
argIdx := 2
if params.Status != "" {
// 支持逗号分隔的多状态
statuses := strings.Split(params.Status, ",")
placeholders := make([]string, len(statuses))
for i, s := range statuses {
placeholders[i] = fmt.Sprintf("$%d", argIdx)
args = append(args, strings.TrimSpace(s))
argIdx++
}
conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
}
if params.StorageType != "" {
storageTypes := strings.Split(params.StorageType, ",")
placeholders := make([]string, len(storageTypes))
for i, s := range storageTypes {
placeholders[i] = fmt.Sprintf("$%d", argIdx)
args = append(args, strings.TrimSpace(s))
argIdx++
}
conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
}
if params.MediaType != "" {
conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
args = append(args, params.MediaType)
argIdx++
}
whereClause := "WHERE " + strings.Join(conditions, " AND ")
// 计数
var total int64
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
return nil, 0, err
}
// 分页查询
offset := (params.Page - 1) * params.PageSize
listQuery := fmt.Sprintf(`
SELECT id, user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message,
created_at, completed_at
FROM sora_generations %s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d
`, whereClause, argIdx, argIdx+1)
args = append(args, params.PageSize, offset)
rows, err := r.sql.QueryContext(ctx, listQuery, args...)
if err != nil {
return nil, 0, err
}
defer func() {
_ = rows.Close()
}()
var results []*service.SoraGeneration
for rows.Next() {
gen := &service.SoraGeneration{}
var mediaURLsJSON, s3KeysJSON []byte
var completedAt sql.NullTime
var apiKeyID sql.NullInt64
if err := rows.Scan(
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
&gen.CreatedAt, &completedAt,
); err != nil {
return nil, 0, err
}
if apiKeyID.Valid {
gen.APIKeyID = &apiKeyID.Int64
}
if completedAt.Valid {
gen.CompletedAt = &completedAt.Time
}
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
results = append(results, gen)
}
return results, total, rows.Err()
}
func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
if len(statuses) == 0 {
return 0, nil
}
placeholders := make([]string, len(statuses))
args := []any{userID}
for i, s := range statuses {
placeholders[i] = fmt.Sprintf("$%d", i+2)
args = append(args, s)
}
var count int64
query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ","))
err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count)
return count, err
}
......@@ -62,7 +62,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
......@@ -145,8 +144,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
......@@ -376,65 +373,6 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
return nil
}
// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。
func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
if deltaBytes <= 0 {
user, err := r.GetByID(ctx, userID)
if err != nil {
return 0, err
}
return user.SoraStorageUsedBytes, nil
}
var newUsed int64
err := scanSingleRow(ctx, r.sql, `
UPDATE users
SET sora_storage_used_bytes = sora_storage_used_bytes + $2
WHERE id = $1
AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
RETURNING sora_storage_used_bytes
`, []any{userID, deltaBytes, effectiveQuota}, &newUsed)
if err == nil {
return newUsed, nil
}
if errors.Is(err, sql.ErrNoRows) {
// 区分用户不存在和配额冲突
exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
if existsErr != nil {
return 0, existsErr
}
if !exists {
return 0, service.ErrUserNotFound
}
return 0, service.ErrSoraStorageQuotaExceeded
}
return 0, err
}
// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。
func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
if deltaBytes <= 0 {
user, err := r.GetByID(ctx, userID)
if err != nil {
return 0, err
}
return user.SoraStorageUsedBytes, nil
}
var newUsed int64
err := scanSingleRow(ctx, r.sql, `
UPDATE users
SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
WHERE id = $1
RETURNING sora_storage_used_bytes
`, []any{userID, deltaBytes}, &newUsed)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return 0, service.ErrUserNotFound
}
return 0, err
}
return newUsed, nil
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
}
......
......@@ -53,7 +53,6 @@ var ProviderSet = wire.NewSet(
NewAPIKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewSoraAccountRepository, // Sora 账号扩展表仓储
NewScheduledTestPlanRepository, // 定时测试计划仓储
NewScheduledTestResultRepository, // 定时测试结果仓储
NewProxyRepository,
......
......@@ -94,7 +94,6 @@ func isAPIRoutePath(c *gin.Context) bool {
return strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/responses")
}
......
......@@ -109,7 +109,6 @@ func registerRoutes(
// 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
}
......@@ -34,8 +34,6 @@ func RegisterAdminRoutes(
// OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h)
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
registerSoraOAuthRoutes(admin, h)
// Gemini OAuth
registerGeminiOAuthRoutes(admin, h)
......@@ -321,19 +319,6 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
sora := admin.Group("/sora")
{
sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
}
}
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
gemini := admin.Group("/gemini")
{
......
......@@ -23,11 +23,6 @@ func RegisterGatewayRoutes(
cfg *config.Config,
) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
soraMaxBodySize := cfg.Gateway.SoraMaxBodySize
if soraMaxBodySize <= 0 {
soraMaxBodySize = cfg.Gateway.MaxBodySize
}
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
endpointNorm := handler.InboundEndpointMiddleware()
......@@ -163,28 +158,6 @@ func RegisterGatewayRoutes(
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
}
// Sora 专用路由(强制使用 sora 平台)
soraV1 := r.Group("/sora/v1")
soraV1.Use(soraBodyLimit)
soraV1.Use(clientRequestID)
soraV1.Use(opsErrorLogger)
soraV1.Use(endpointNorm)
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
soraV1.Use(requireGroupAnthropic)
{
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
soraV1.GET("/models", h.Gateway.Models)
}
// Sora 媒体代理(可选 API Key 验证)
if cfg.Gateway.SoraMediaRequireAPIKey {
r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy)
} else {
r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy)
}
// Sora 媒体代理(签名 URL,无需 API Key)
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
}
// getGroupPlatform extracts the group platform from the API Key stored in context.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment