Commit 13262a56 authored by yangjianbo's avatar yangjianbo
Browse files

feat(sora): 新增 Sora 平台支持并修复高危安全和性能问题



新增功能:
- 新增 Sora 账号管理和 OAuth 认证
- 新增 Sora 视频/图片生成 API 网关
- 新增 Sora 任务调度和缓存机制
- 新增 Sora 使用统计和计费支持
- 前端增加 Sora 平台配置界面

安全修复(代码审核):
- [SEC-001] 限制媒体下载响应体大小(图片 20MB、视频 200MB),防止 DoS 攻击
- [SEC-002] 限制 SDK API 响应大小(1MB),防止内存耗尽
- [SEC-003] 修复 SSRF 风险,添加 URL 验证并强制使用代理配置

BUG 修复(代码审核):
- [BUG-001] 修复 for 循环内 defer 累积导致的资源泄漏
- [BUG-002] 修复图片并发槽位获取失败时已持有锁未释放的永久泄漏

性能优化(代码审核):
- [PERF-001] 添加 Sentinel Token 缓存(3 分钟有效期),减少 PoW 计算开销

技术细节:
- 使用 io.LimitReader 限制所有外部输入的大小
- 添加 urlvalidator 验证防止 SSRF 攻击
- 使用 sync.Map 实现线程安全的包级缓存
- 优化并发槽位管理,添加 releaseAll 模式防止泄漏

影响范围:
- 后端:新增 Sora 相关数据模型、服务、网关和管理接口
- 前端:新增 Sora 平台配置、账号管理和监控界面
- 配置:新增 Sora 相关配置项和环境变量
Co-Authored-By: default avatarClaude Sonnet 4.5 <noreply@anthropic.com>
parent bece1b52
package repository
import (
"context"
"database/sql"
"errors"
"time"
"github.com/Wei-Shaw/sub2api/ent"
dbsoraaccount "github.com/Wei-Shaw/sub2api/ent/soraaccount"
dbsoracachefile "github.com/Wei-Shaw/sub2api/ent/soracachefile"
dbsoratask "github.com/Wei-Shaw/sub2api/ent/soratask"
dbsorausagestat "github.com/Wei-Shaw/sub2api/ent/sorausagestat"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
)
// SoraAccount
type soraAccountRepository struct {
client *ent.Client
}
func NewSoraAccountRepository(client *ent.Client) service.SoraAccountRepository {
return &soraAccountRepository{client: client}
}
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
if accountID <= 0 {
return nil, nil
}
acc, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDEQ(accountID)).Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return mapSoraAccount(acc), nil
}
func (r *soraAccountRepository) GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*service.SoraAccount, error) {
if len(accountIDs) == 0 {
return map[int64]*service.SoraAccount{}, nil
}
records, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDIn(accountIDs...)).All(ctx)
if err != nil {
return nil, err
}
result := make(map[int64]*service.SoraAccount, len(records))
for _, record := range records {
if record == nil {
continue
}
result[record.AccountID] = mapSoraAccount(record)
}
return result, nil
}
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
if accountID <= 0 {
return errors.New("invalid account_id")
}
acc, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDEQ(accountID)).Only(ctx)
if err != nil && !ent.IsNotFound(err) {
return err
}
if acc == nil {
builder := r.client.SoraAccount.Create().SetAccountID(accountID)
applySoraAccountUpdates(builder.Mutation(), updates)
return builder.Exec(ctx)
}
updater := r.client.SoraAccount.UpdateOneID(acc.ID)
applySoraAccountUpdates(updater.Mutation(), updates)
return updater.Exec(ctx)
}
func applySoraAccountUpdates(m *ent.SoraAccountMutation, updates map[string]any) {
if updates == nil {
return
}
for key, val := range updates {
switch key {
case "access_token":
if v, ok := val.(string); ok {
m.SetAccessToken(v)
}
case "session_token":
if v, ok := val.(string); ok {
m.SetSessionToken(v)
}
case "refresh_token":
if v, ok := val.(string); ok {
m.SetRefreshToken(v)
}
case "client_id":
if v, ok := val.(string); ok {
m.SetClientID(v)
}
case "email":
if v, ok := val.(string); ok {
m.SetEmail(v)
}
case "username":
if v, ok := val.(string); ok {
m.SetUsername(v)
}
case "remark":
if v, ok := val.(string); ok {
m.SetRemark(v)
}
case "plan_type":
if v, ok := val.(string); ok {
m.SetPlanType(v)
}
case "plan_title":
if v, ok := val.(string); ok {
m.SetPlanTitle(v)
}
case "subscription_end":
if v, ok := val.(time.Time); ok {
m.SetSubscriptionEnd(v)
}
if v, ok := val.(*time.Time); ok && v != nil {
m.SetSubscriptionEnd(*v)
}
case "sora_supported":
if v, ok := val.(bool); ok {
m.SetSoraSupported(v)
}
case "sora_invite_code":
if v, ok := val.(string); ok {
m.SetSoraInviteCode(v)
}
case "sora_redeemed_count":
if v, ok := val.(int); ok {
m.SetSoraRedeemedCount(v)
}
case "sora_remaining_count":
if v, ok := val.(int); ok {
m.SetSoraRemainingCount(v)
}
case "sora_total_count":
if v, ok := val.(int); ok {
m.SetSoraTotalCount(v)
}
case "sora_cooldown_until":
if v, ok := val.(time.Time); ok {
m.SetSoraCooldownUntil(v)
}
if v, ok := val.(*time.Time); ok && v != nil {
m.SetSoraCooldownUntil(*v)
}
case "cooled_until":
if v, ok := val.(time.Time); ok {
m.SetCooledUntil(v)
}
if v, ok := val.(*time.Time); ok && v != nil {
m.SetCooledUntil(*v)
}
case "image_enabled":
if v, ok := val.(bool); ok {
m.SetImageEnabled(v)
}
case "video_enabled":
if v, ok := val.(bool); ok {
m.SetVideoEnabled(v)
}
case "image_concurrency":
if v, ok := val.(int); ok {
m.SetImageConcurrency(v)
}
case "video_concurrency":
if v, ok := val.(int); ok {
m.SetVideoConcurrency(v)
}
case "is_expired":
if v, ok := val.(bool); ok {
m.SetIsExpired(v)
}
}
}
}
func mapSoraAccount(acc *ent.SoraAccount) *service.SoraAccount {
if acc == nil {
return nil
}
return &service.SoraAccount{
AccountID: acc.AccountID,
AccessToken: derefString(acc.AccessToken),
SessionToken: derefString(acc.SessionToken),
RefreshToken: derefString(acc.RefreshToken),
ClientID: derefString(acc.ClientID),
Email: derefString(acc.Email),
Username: derefString(acc.Username),
Remark: derefString(acc.Remark),
UseCount: acc.UseCount,
PlanType: derefString(acc.PlanType),
PlanTitle: derefString(acc.PlanTitle),
SubscriptionEnd: acc.SubscriptionEnd,
SoraSupported: acc.SoraSupported,
SoraInviteCode: derefString(acc.SoraInviteCode),
SoraRedeemedCount: acc.SoraRedeemedCount,
SoraRemainingCount: acc.SoraRemainingCount,
SoraTotalCount: acc.SoraTotalCount,
SoraCooldownUntil: acc.SoraCooldownUntil,
CooledUntil: acc.CooledUntil,
ImageEnabled: acc.ImageEnabled,
VideoEnabled: acc.VideoEnabled,
ImageConcurrency: acc.ImageConcurrency,
VideoConcurrency: acc.VideoConcurrency,
IsExpired: acc.IsExpired,
CreatedAt: acc.CreatedAt,
UpdatedAt: acc.UpdatedAt,
}
}
func mapSoraUsageStat(stat *ent.SoraUsageStat) *service.SoraUsageStat {
if stat == nil {
return nil
}
return &service.SoraUsageStat{
AccountID: stat.AccountID,
ImageCount: stat.ImageCount,
VideoCount: stat.VideoCount,
ErrorCount: stat.ErrorCount,
LastErrorAt: stat.LastErrorAt,
TodayImageCount: stat.TodayImageCount,
TodayVideoCount: stat.TodayVideoCount,
TodayErrorCount: stat.TodayErrorCount,
TodayDate: stat.TodayDate,
ConsecutiveErrorCount: stat.ConsecutiveErrorCount,
CreatedAt: stat.CreatedAt,
UpdatedAt: stat.UpdatedAt,
}
}
func mapSoraCacheFile(file *ent.SoraCacheFile) *service.SoraCacheFile {
if file == nil {
return nil
}
return &service.SoraCacheFile{
ID: int64(file.ID),
TaskID: derefString(file.TaskID),
AccountID: file.AccountID,
UserID: file.UserID,
MediaType: file.MediaType,
OriginalURL: file.OriginalURL,
CachePath: file.CachePath,
CacheURL: file.CacheURL,
SizeBytes: file.SizeBytes,
CreatedAt: file.CreatedAt,
}
}
// SoraUsageStat
type soraUsageStatRepository struct {
client *ent.Client
sql sqlExecutor
}
func NewSoraUsageStatRepository(client *ent.Client, sqlDB *sql.DB) service.SoraUsageStatRepository {
return &soraUsageStatRepository{client: client, sql: sqlDB}
}
func (r *soraUsageStatRepository) RecordSuccess(ctx context.Context, accountID int64, isVideo bool) error {
if accountID <= 0 {
return nil
}
field := "image_count"
todayField := "today_image_count"
if isVideo {
field = "video_count"
todayField = "today_video_count"
}
today := time.Now().UTC().Truncate(24 * time.Hour)
query := "INSERT INTO sora_usage_stats (account_id, " + field + ", " + todayField + ", today_date, consecutive_error_count, created_at, updated_at) " +
"VALUES ($1, 1, 1, $2, 0, NOW(), NOW()) " +
"ON CONFLICT (account_id) DO UPDATE SET " +
field + " = sora_usage_stats." + field + " + 1, " +
todayField + " = CASE WHEN sora_usage_stats.today_date = $2 THEN sora_usage_stats." + todayField + " + 1 ELSE 1 END, " +
"today_date = $2, consecutive_error_count = 0, updated_at = NOW()"
_, err := r.sql.ExecContext(ctx, query, accountID, today)
return err
}
func (r *soraUsageStatRepository) RecordError(ctx context.Context, accountID int64) (int, error) {
if accountID <= 0 {
return 0, nil
}
today := time.Now().UTC().Truncate(24 * time.Hour)
query := "INSERT INTO sora_usage_stats (account_id, error_count, today_error_count, today_date, consecutive_error_count, last_error_at, created_at, updated_at) " +
"VALUES ($1, 1, 1, $2, 1, NOW(), NOW(), NOW()) " +
"ON CONFLICT (account_id) DO UPDATE SET " +
"error_count = sora_usage_stats.error_count + 1, " +
"today_error_count = CASE WHEN sora_usage_stats.today_date = $2 THEN sora_usage_stats.today_error_count + 1 ELSE 1 END, " +
"today_date = $2, consecutive_error_count = sora_usage_stats.consecutive_error_count + 1, last_error_at = NOW(), updated_at = NOW() " +
"RETURNING consecutive_error_count"
var consecutive int
err := scanSingleRow(ctx, r.sql, query, []any{accountID, today}, &consecutive)
if err != nil {
return 0, err
}
return consecutive, nil
}
func (r *soraUsageStatRepository) ResetConsecutiveErrors(ctx context.Context, accountID int64) error {
if accountID <= 0 {
return nil
}
err := r.client.SoraUsageStat.Update().Where(dbsorausagestat.AccountIDEQ(accountID)).
SetConsecutiveErrorCount(0).
Exec(ctx)
if ent.IsNotFound(err) {
return nil
}
return err
}
func (r *soraUsageStatRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraUsageStat, error) {
if accountID <= 0 {
return nil, nil
}
stat, err := r.client.SoraUsageStat.Query().Where(dbsorausagestat.AccountIDEQ(accountID)).Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return mapSoraUsageStat(stat), nil
}
func (r *soraUsageStatRepository) GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*service.SoraUsageStat, error) {
if len(accountIDs) == 0 {
return map[int64]*service.SoraUsageStat{}, nil
}
stats, err := r.client.SoraUsageStat.Query().Where(dbsorausagestat.AccountIDIn(accountIDs...)).All(ctx)
if err != nil {
return nil, err
}
result := make(map[int64]*service.SoraUsageStat, len(stats))
for _, stat := range stats {
if stat == nil {
continue
}
result[stat.AccountID] = mapSoraUsageStat(stat)
}
return result, nil
}
func (r *soraUsageStatRepository) List(ctx context.Context, params pagination.PaginationParams) ([]*service.SoraUsageStat, *pagination.PaginationResult, error) {
query := r.client.SoraUsageStat.Query()
total, err := query.Count(ctx)
if err != nil {
return nil, nil, err
}
stats, err := query.Order(ent.Desc(dbsorausagestat.FieldUpdatedAt)).
Limit(params.Limit()).
Offset(params.Offset()).
All(ctx)
if err != nil {
return nil, nil, err
}
result := make([]*service.SoraUsageStat, 0, len(stats))
for _, stat := range stats {
result = append(result, mapSoraUsageStat(stat))
}
return result, paginationResultFromTotal(int64(total), params), nil
}
// SoraTask
type soraTaskRepository struct {
client *ent.Client
}
func NewSoraTaskRepository(client *ent.Client) service.SoraTaskRepository {
return &soraTaskRepository{client: client}
}
func (r *soraTaskRepository) Create(ctx context.Context, task *service.SoraTask) error {
if task == nil {
return nil
}
builder := r.client.SoraTask.Create().
SetTaskID(task.TaskID).
SetAccountID(task.AccountID).
SetModel(task.Model).
SetPrompt(task.Prompt).
SetStatus(task.Status).
SetProgress(task.Progress).
SetRetryCount(task.RetryCount)
if task.ResultURLs != "" {
builder.SetResultUrls(task.ResultURLs)
}
if task.ErrorMessage != "" {
builder.SetErrorMessage(task.ErrorMessage)
}
if task.CreatedAt.IsZero() {
builder.SetCreatedAt(time.Now())
} else {
builder.SetCreatedAt(task.CreatedAt)
}
if task.CompletedAt != nil {
builder.SetCompletedAt(*task.CompletedAt)
}
return builder.Exec(ctx)
}
func (r *soraTaskRepository) UpdateStatus(ctx context.Context, taskID string, status string, progress float64, resultURLs string, errorMessage string, completedAt *time.Time) error {
if taskID == "" {
return nil
}
builder := r.client.SoraTask.Update().Where(dbsoratask.TaskIDEQ(taskID)).
SetStatus(status).
SetProgress(progress)
if resultURLs != "" {
builder.SetResultUrls(resultURLs)
}
if errorMessage != "" {
builder.SetErrorMessage(errorMessage)
}
if completedAt != nil {
builder.SetCompletedAt(*completedAt)
}
_, err := builder.Save(ctx)
if ent.IsNotFound(err) {
return nil
}
return err
}
// SoraCacheFile
type soraCacheFileRepository struct {
client *ent.Client
}
func NewSoraCacheFileRepository(client *ent.Client) service.SoraCacheFileRepository {
return &soraCacheFileRepository{client: client}
}
func (r *soraCacheFileRepository) Create(ctx context.Context, file *service.SoraCacheFile) error {
if file == nil {
return nil
}
builder := r.client.SoraCacheFile.Create().
SetAccountID(file.AccountID).
SetUserID(file.UserID).
SetMediaType(file.MediaType).
SetOriginalURL(file.OriginalURL).
SetCachePath(file.CachePath).
SetCacheURL(file.CacheURL).
SetSizeBytes(file.SizeBytes)
if file.TaskID != "" {
builder.SetTaskID(file.TaskID)
}
if file.CreatedAt.IsZero() {
builder.SetCreatedAt(time.Now())
} else {
builder.SetCreatedAt(file.CreatedAt)
}
return builder.Exec(ctx)
}
func (r *soraCacheFileRepository) ListOldest(ctx context.Context, limit int) ([]*service.SoraCacheFile, error) {
if limit <= 0 {
return []*service.SoraCacheFile{}, nil
}
records, err := r.client.SoraCacheFile.Query().
Order(dbsoracachefile.ByCreatedAt(entsql.OrderAsc())).
Limit(limit).
All(ctx)
if err != nil {
return nil, err
}
result := make([]*service.SoraCacheFile, 0, len(records))
for _, record := range records {
if record == nil {
continue
}
result = append(result, mapSoraCacheFile(record))
}
return result, nil
}
func (r *soraCacheFileRepository) DeleteByIDs(ctx context.Context, ids []int64) error {
if len(ids) == 0 {
return nil
}
_, err := r.client.SoraCacheFile.Delete().Where(dbsoracachefile.IDIn(ids...)).Exec(ctx)
return err
}
...@@ -64,6 +64,10 @@ var ProviderSet = wire.NewSet( ...@@ -64,6 +64,10 @@ var ProviderSet = wire.NewSet(
NewUserSubscriptionRepository, NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository, NewUserAttributeDefinitionRepository,
NewUserAttributeValueRepository, NewUserAttributeValueRepository,
NewSoraAccountRepository,
NewSoraUsageStatRepository,
NewSoraTaskRepository,
NewSoraCacheFileRepository,
// Cache implementations // Cache implementations
NewGatewayCache, NewGatewayCache,
......
package server package server
import ( import (
"context"
"log" "log"
"path/filepath"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
...@@ -46,6 +49,22 @@ func SetupRouter( ...@@ -46,6 +49,22 @@ func SetupRouter(
} }
} }
// Serve Sora cached videos when enabled
cacheVideoDir := ""
cacheEnabled := false
if settingService != nil {
soraCfg := settingService.GetSoraConfig(context.Background())
cacheEnabled = soraCfg.Cache.Enabled
cacheVideoDir = strings.TrimSpace(soraCfg.Cache.VideoDir)
} else if cfg != nil {
cacheEnabled = cfg.Sora.Cache.Enabled
cacheVideoDir = strings.TrimSpace(cfg.Sora.Cache.VideoDir)
}
if cacheEnabled && cacheVideoDir != "" {
videoDir := filepath.Clean(cacheVideoDir)
r.Static("/data/video", videoDir)
}
// 注册路由 // 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient) registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient)
......
...@@ -29,6 +29,9 @@ func RegisterAdminRoutes( ...@@ -29,6 +29,9 @@ func RegisterAdminRoutes(
// 账号管理 // 账号管理
registerAccountRoutes(admin, h) registerAccountRoutes(admin, h)
// Sora 账号扩展
registerSoraRoutes(admin, h)
// OpenAI OAuth // OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h) registerOpenAIOAuthRoutes(admin, h)
...@@ -229,6 +232,17 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -229,6 +232,17 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
} }
} }
func registerSoraRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
sora := admin.Group("/sora")
{
sora.GET("/accounts", h.Admin.SoraAccount.List)
sora.GET("/accounts/:id", h.Admin.SoraAccount.Get)
sora.PUT("/accounts/:id", h.Admin.SoraAccount.Upsert)
sora.POST("/accounts/import", h.Admin.SoraAccount.BatchUpsert)
sora.GET("/usage", h.Admin.SoraAccount.ListUsage)
}
}
func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
openai := admin.Group("/openai") openai := admin.Group("/openai")
{ {
......
...@@ -33,6 +33,7 @@ func RegisterGatewayRoutes( ...@@ -33,6 +33,7 @@ func RegisterGatewayRoutes(
gateway.POST("/messages", h.Gateway.Messages) gateway.POST("/messages", h.Gateway.Messages)
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens) gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
gateway.GET("/models", h.Gateway.Models) gateway.GET("/models", h.Gateway.Models)
gateway.POST("/chat/completions", h.SoraGateway.ChatCompletions)
gateway.GET("/usage", h.Gateway.Usage) gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API // OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses) gateway.POST("/responses", h.OpenAIGateway.Responses)
......
...@@ -22,6 +22,7 @@ const ( ...@@ -22,6 +22,7 @@ const (
PlatformOpenAI = "openai" PlatformOpenAI = "openai"
PlatformGemini = "gemini" PlatformGemini = "gemini"
PlatformAntigravity = "antigravity" PlatformAntigravity = "antigravity"
PlatformSora = "sora"
) )
// Account type constants // Account type constants
...@@ -124,6 +125,28 @@ const ( ...@@ -124,6 +125,28 @@ const (
SettingKeyEnableIdentityPatch = "enable_identity_patch" SettingKeyEnableIdentityPatch = "enable_identity_patch"
SettingKeyIdentityPatchPrompt = "identity_patch_prompt" SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
// =========================
// Sora Settings
// =========================
SettingKeySoraBaseURL = "sora_base_url"
SettingKeySoraTimeout = "sora_timeout"
SettingKeySoraMaxRetries = "sora_max_retries"
SettingKeySoraPollInterval = "sora_poll_interval"
SettingKeySoraCallLogicMode = "sora_call_logic_mode"
SettingKeySoraCacheEnabled = "sora_cache_enabled"
SettingKeySoraCacheBaseDir = "sora_cache_base_dir"
SettingKeySoraCacheVideoDir = "sora_cache_video_dir"
SettingKeySoraCacheMaxBytes = "sora_cache_max_bytes"
SettingKeySoraCacheAllowedHosts = "sora_cache_allowed_hosts"
SettingKeySoraCacheUserDirEnabled = "sora_cache_user_dir_enabled"
SettingKeySoraWatermarkFreeEnabled = "sora_watermark_free_enabled"
SettingKeySoraWatermarkFreeParseMethod = "sora_watermark_free_parse_method"
SettingKeySoraWatermarkFreeCustomParseURL = "sora_watermark_free_custom_parse_url"
SettingKeySoraWatermarkFreeCustomParseToken = "sora_watermark_free_custom_parse_token"
SettingKeySoraWatermarkFreeFallbackOnFailure = "sora_watermark_free_fallback_on_failure"
SettingKeySoraTokenRefreshEnabled = "sora_token_refresh_enabled"
// ========================= // =========================
// Ops Monitoring (vNext) // Ops Monitoring (vNext)
// ========================= // =========================
......
...@@ -378,7 +378,7 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI ...@@ -378,7 +378,7 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI
if len(groupIDs) == 0 { if len(groupIDs) == 0 {
return nil return nil
} }
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity} platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformSora, PlatformAntigravity}
var firstErr error var firstErr error
for _, platform := range platforms { for _, platform := range platforms {
if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason); err != nil && firstErr == nil { if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason); err != nil && firstErr == nil {
...@@ -661,7 +661,7 @@ func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration { ...@@ -661,7 +661,7 @@ func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration {
func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) { func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) {
buckets := make([]SchedulerBucket, 0) buckets := make([]SchedulerBucket, 0)
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity} platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformSora, PlatformAntigravity}
for _, platform := range platforms { for _, platform := range platforms {
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle}) buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle})
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced}) buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced})
......
...@@ -219,6 +219,29 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -219,6 +219,29 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEnableIdentityPatch] = strconv.FormatBool(settings.EnableIdentityPatch) updates[SettingKeyEnableIdentityPatch] = strconv.FormatBool(settings.EnableIdentityPatch)
updates[SettingKeyIdentityPatchPrompt] = settings.IdentityPatchPrompt updates[SettingKeyIdentityPatchPrompt] = settings.IdentityPatchPrompt
// Sora settings
updates[SettingKeySoraBaseURL] = strings.TrimSpace(settings.SoraBaseURL)
updates[SettingKeySoraTimeout] = strconv.Itoa(settings.SoraTimeout)
updates[SettingKeySoraMaxRetries] = strconv.Itoa(settings.SoraMaxRetries)
updates[SettingKeySoraPollInterval] = strconv.FormatFloat(settings.SoraPollInterval, 'f', -1, 64)
updates[SettingKeySoraCallLogicMode] = settings.SoraCallLogicMode
updates[SettingKeySoraCacheEnabled] = strconv.FormatBool(settings.SoraCacheEnabled)
updates[SettingKeySoraCacheBaseDir] = settings.SoraCacheBaseDir
updates[SettingKeySoraCacheVideoDir] = settings.SoraCacheVideoDir
updates[SettingKeySoraCacheMaxBytes] = strconv.FormatInt(settings.SoraCacheMaxBytes, 10)
allowedHostsRaw, err := marshalStringSliceSetting(settings.SoraCacheAllowedHosts)
if err != nil {
return fmt.Errorf("marshal sora cache allowed hosts: %w", err)
}
updates[SettingKeySoraCacheAllowedHosts] = allowedHostsRaw
updates[SettingKeySoraCacheUserDirEnabled] = strconv.FormatBool(settings.SoraCacheUserDirEnabled)
updates[SettingKeySoraWatermarkFreeEnabled] = strconv.FormatBool(settings.SoraWatermarkFreeEnabled)
updates[SettingKeySoraWatermarkFreeParseMethod] = settings.SoraWatermarkFreeParseMethod
updates[SettingKeySoraWatermarkFreeCustomParseURL] = strings.TrimSpace(settings.SoraWatermarkFreeCustomParseURL)
updates[SettingKeySoraWatermarkFreeCustomParseToken] = settings.SoraWatermarkFreeCustomParseToken
updates[SettingKeySoraWatermarkFreeFallbackOnFailure] = strconv.FormatBool(settings.SoraWatermarkFreeFallbackOnFailure)
updates[SettingKeySoraTokenRefreshEnabled] = strconv.FormatBool(settings.SoraTokenRefreshEnabled)
// Ops monitoring (vNext) // Ops monitoring (vNext)
updates[SettingKeyOpsMonitoringEnabled] = strconv.FormatBool(settings.OpsMonitoringEnabled) updates[SettingKeyOpsMonitoringEnabled] = strconv.FormatBool(settings.OpsMonitoringEnabled)
updates[SettingKeyOpsRealtimeMonitoringEnabled] = strconv.FormatBool(settings.OpsRealtimeMonitoringEnabled) updates[SettingKeyOpsRealtimeMonitoringEnabled] = strconv.FormatBool(settings.OpsRealtimeMonitoringEnabled)
...@@ -227,7 +250,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -227,7 +250,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds) updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds)
} }
err := s.settingRepo.SetMultiple(ctx, updates) err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil && s.onUpdate != nil { if err == nil && s.onUpdate != nil {
s.onUpdate() // Invalidate cache after settings update s.onUpdate() // Invalidate cache after settings update
} }
...@@ -295,6 +318,41 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 { ...@@ -295,6 +318,41 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
return s.cfg.Default.UserBalance return s.cfg.Default.UserBalance
} }
// GetSoraConfig 获取 Sora 配置(优先读取 DB 设置,回退 config.yaml)
func (s *SettingService) GetSoraConfig(ctx context.Context) config.SoraConfig {
base := config.SoraConfig{}
if s.cfg != nil {
base = s.cfg.Sora
}
if s.settingRepo == nil {
return base
}
keys := []string{
SettingKeySoraBaseURL,
SettingKeySoraTimeout,
SettingKeySoraMaxRetries,
SettingKeySoraPollInterval,
SettingKeySoraCallLogicMode,
SettingKeySoraCacheEnabled,
SettingKeySoraCacheBaseDir,
SettingKeySoraCacheVideoDir,
SettingKeySoraCacheMaxBytes,
SettingKeySoraCacheAllowedHosts,
SettingKeySoraCacheUserDirEnabled,
SettingKeySoraWatermarkFreeEnabled,
SettingKeySoraWatermarkFreeParseMethod,
SettingKeySoraWatermarkFreeCustomParseURL,
SettingKeySoraWatermarkFreeCustomParseToken,
SettingKeySoraWatermarkFreeFallbackOnFailure,
SettingKeySoraTokenRefreshEnabled,
}
values, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return base
}
return mergeSoraConfig(base, values)
}
// InitializeDefaultSettings 初始化默认设置 // InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置 // 检查是否已有设置
...@@ -308,6 +366,12 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -308,6 +366,12 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
} }
// 初始化默认设置 // 初始化默认设置
soraCfg := config.SoraConfig{}
if s.cfg != nil {
soraCfg = s.cfg.Sora
}
allowedHostsRaw, _ := marshalStringSliceSetting(soraCfg.Cache.AllowedHosts)
defaults := map[string]string{ defaults := map[string]string{
SettingKeyRegistrationEnabled: "true", SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "false", SettingKeyEmailVerifyEnabled: "false",
...@@ -328,6 +392,25 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -328,6 +392,25 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyEnableIdentityPatch: "true", SettingKeyEnableIdentityPatch: "true",
SettingKeyIdentityPatchPrompt: "", SettingKeyIdentityPatchPrompt: "",
// Sora defaults
SettingKeySoraBaseURL: soraCfg.BaseURL,
SettingKeySoraTimeout: strconv.Itoa(soraCfg.Timeout),
SettingKeySoraMaxRetries: strconv.Itoa(soraCfg.MaxRetries),
SettingKeySoraPollInterval: strconv.FormatFloat(soraCfg.PollInterval, 'f', -1, 64),
SettingKeySoraCallLogicMode: soraCfg.CallLogicMode,
SettingKeySoraCacheEnabled: strconv.FormatBool(soraCfg.Cache.Enabled),
SettingKeySoraCacheBaseDir: soraCfg.Cache.BaseDir,
SettingKeySoraCacheVideoDir: soraCfg.Cache.VideoDir,
SettingKeySoraCacheMaxBytes: strconv.FormatInt(soraCfg.Cache.MaxBytes, 10),
SettingKeySoraCacheAllowedHosts: allowedHostsRaw,
SettingKeySoraCacheUserDirEnabled: strconv.FormatBool(soraCfg.Cache.UserDirEnabled),
SettingKeySoraWatermarkFreeEnabled: strconv.FormatBool(soraCfg.WatermarkFree.Enabled),
SettingKeySoraWatermarkFreeParseMethod: soraCfg.WatermarkFree.ParseMethod,
SettingKeySoraWatermarkFreeCustomParseURL: soraCfg.WatermarkFree.CustomParseURL,
SettingKeySoraWatermarkFreeCustomParseToken: soraCfg.WatermarkFree.CustomParseToken,
SettingKeySoraWatermarkFreeFallbackOnFailure: strconv.FormatBool(soraCfg.WatermarkFree.FallbackOnFailure),
SettingKeySoraTokenRefreshEnabled: strconv.FormatBool(soraCfg.TokenRefresh.Enabled),
// Ops monitoring defaults (vNext) // Ops monitoring defaults (vNext)
SettingKeyOpsMonitoringEnabled: "true", SettingKeyOpsMonitoringEnabled: "true",
SettingKeyOpsRealtimeMonitoringEnabled: "true", SettingKeyOpsRealtimeMonitoringEnabled: "true",
...@@ -434,6 +517,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -434,6 +517,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} }
result.IdentityPatchPrompt = settings[SettingKeyIdentityPatchPrompt] result.IdentityPatchPrompt = settings[SettingKeyIdentityPatchPrompt]
// Sora settings
soraCfg := s.parseSoraConfig(settings)
result.SoraBaseURL = soraCfg.BaseURL
result.SoraTimeout = soraCfg.Timeout
result.SoraMaxRetries = soraCfg.MaxRetries
result.SoraPollInterval = soraCfg.PollInterval
result.SoraCallLogicMode = soraCfg.CallLogicMode
result.SoraCacheEnabled = soraCfg.Cache.Enabled
result.SoraCacheBaseDir = soraCfg.Cache.BaseDir
result.SoraCacheVideoDir = soraCfg.Cache.VideoDir
result.SoraCacheMaxBytes = soraCfg.Cache.MaxBytes
result.SoraCacheAllowedHosts = soraCfg.Cache.AllowedHosts
result.SoraCacheUserDirEnabled = soraCfg.Cache.UserDirEnabled
result.SoraWatermarkFreeEnabled = soraCfg.WatermarkFree.Enabled
result.SoraWatermarkFreeParseMethod = soraCfg.WatermarkFree.ParseMethod
result.SoraWatermarkFreeCustomParseURL = soraCfg.WatermarkFree.CustomParseURL
result.SoraWatermarkFreeCustomParseToken = soraCfg.WatermarkFree.CustomParseToken
result.SoraWatermarkFreeFallbackOnFailure = soraCfg.WatermarkFree.FallbackOnFailure
result.SoraTokenRefreshEnabled = soraCfg.TokenRefresh.Enabled
// Ops monitoring settings (default: enabled, fail-open) // Ops monitoring settings (default: enabled, fail-open)
result.OpsMonitoringEnabled = !isFalseSettingValue(settings[SettingKeyOpsMonitoringEnabled]) result.OpsMonitoringEnabled = !isFalseSettingValue(settings[SettingKeyOpsMonitoringEnabled])
result.OpsRealtimeMonitoringEnabled = !isFalseSettingValue(settings[SettingKeyOpsRealtimeMonitoringEnabled]) result.OpsRealtimeMonitoringEnabled = !isFalseSettingValue(settings[SettingKeyOpsRealtimeMonitoringEnabled])
...@@ -471,6 +574,131 @@ func (s *SettingService) getStringOrDefault(settings map[string]string, key, def ...@@ -471,6 +574,131 @@ func (s *SettingService) getStringOrDefault(settings map[string]string, key, def
return defaultValue return defaultValue
} }
func (s *SettingService) parseSoraConfig(settings map[string]string) config.SoraConfig {
base := config.SoraConfig{}
if s.cfg != nil {
base = s.cfg.Sora
}
return mergeSoraConfig(base, settings)
}
func mergeSoraConfig(base config.SoraConfig, settings map[string]string) config.SoraConfig {
cfg := base
if settings == nil {
return cfg
}
if raw, ok := settings[SettingKeySoraBaseURL]; ok {
if trimmed := strings.TrimSpace(raw); trimmed != "" {
cfg.BaseURL = trimmed
}
}
if raw, ok := settings[SettingKeySoraTimeout]; ok {
if v, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil && v > 0 {
cfg.Timeout = v
}
}
if raw, ok := settings[SettingKeySoraMaxRetries]; ok {
if v, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil && v >= 0 {
cfg.MaxRetries = v
}
}
if raw, ok := settings[SettingKeySoraPollInterval]; ok {
if v, err := strconv.ParseFloat(strings.TrimSpace(raw), 64); err == nil && v > 0 {
cfg.PollInterval = v
}
}
if raw, ok := settings[SettingKeySoraCallLogicMode]; ok && strings.TrimSpace(raw) != "" {
cfg.CallLogicMode = strings.TrimSpace(raw)
}
if raw, ok := settings[SettingKeySoraCacheEnabled]; ok {
cfg.Cache.Enabled = parseBoolSetting(raw, cfg.Cache.Enabled)
}
if raw, ok := settings[SettingKeySoraCacheBaseDir]; ok && strings.TrimSpace(raw) != "" {
cfg.Cache.BaseDir = strings.TrimSpace(raw)
}
if raw, ok := settings[SettingKeySoraCacheVideoDir]; ok && strings.TrimSpace(raw) != "" {
cfg.Cache.VideoDir = strings.TrimSpace(raw)
}
if raw, ok := settings[SettingKeySoraCacheMaxBytes]; ok {
if v, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 64); err == nil && v >= 0 {
cfg.Cache.MaxBytes = v
}
}
if raw, ok := settings[SettingKeySoraCacheAllowedHosts]; ok {
cfg.Cache.AllowedHosts = parseStringSliceSetting(raw)
}
if raw, ok := settings[SettingKeySoraCacheUserDirEnabled]; ok {
cfg.Cache.UserDirEnabled = parseBoolSetting(raw, cfg.Cache.UserDirEnabled)
}
if raw, ok := settings[SettingKeySoraWatermarkFreeEnabled]; ok {
cfg.WatermarkFree.Enabled = parseBoolSetting(raw, cfg.WatermarkFree.Enabled)
}
if raw, ok := settings[SettingKeySoraWatermarkFreeParseMethod]; ok && strings.TrimSpace(raw) != "" {
cfg.WatermarkFree.ParseMethod = strings.TrimSpace(raw)
}
if raw, ok := settings[SettingKeySoraWatermarkFreeCustomParseURL]; ok && strings.TrimSpace(raw) != "" {
cfg.WatermarkFree.CustomParseURL = strings.TrimSpace(raw)
}
if raw, ok := settings[SettingKeySoraWatermarkFreeCustomParseToken]; ok {
cfg.WatermarkFree.CustomParseToken = raw
}
if raw, ok := settings[SettingKeySoraWatermarkFreeFallbackOnFailure]; ok {
cfg.WatermarkFree.FallbackOnFailure = parseBoolSetting(raw, cfg.WatermarkFree.FallbackOnFailure)
}
if raw, ok := settings[SettingKeySoraTokenRefreshEnabled]; ok {
cfg.TokenRefresh.Enabled = parseBoolSetting(raw, cfg.TokenRefresh.Enabled)
}
return cfg
}
func parseBoolSetting(raw string, fallback bool) bool {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return fallback
}
if v, err := strconv.ParseBool(trimmed); err == nil {
return v
}
return fallback
}
func parseStringSliceSetting(raw string) []string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return []string{}
}
var values []string
if err := json.Unmarshal([]byte(trimmed), &values); err == nil {
return normalizeStringSlice(values)
}
parts := strings.FieldsFunc(trimmed, func(r rune) bool {
return r == ',' || r == '\n' || r == ';'
})
return normalizeStringSlice(parts)
}
func marshalStringSliceSetting(values []string) (string, error) {
normalized := normalizeStringSlice(values)
data, err := json.Marshal(normalized)
if err != nil {
return "", err
}
return string(data), nil
}
func normalizeStringSlice(values []string) []string {
if len(values) == 0 {
return []string{}
}
normalized := make([]string, 0, len(values))
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
normalized = append(normalized, trimmed)
}
}
return normalized
}
// IsTurnstileEnabled 检查是否启用 Turnstile 验证 // IsTurnstileEnabled 检查是否启用 Turnstile 验证
func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool { func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyTurnstileEnabled) value, err := s.settingRepo.GetValue(ctx, SettingKeyTurnstileEnabled)
......
...@@ -49,6 +49,25 @@ type SystemSettings struct { ...@@ -49,6 +49,25 @@ type SystemSettings struct {
EnableIdentityPatch bool `json:"enable_identity_patch"` EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"` IdentityPatchPrompt string `json:"identity_patch_prompt"`
// Sora configuration
SoraBaseURL string
SoraTimeout int
SoraMaxRetries int
SoraPollInterval float64
SoraCallLogicMode string
SoraCacheEnabled bool
SoraCacheBaseDir string
SoraCacheVideoDir string
SoraCacheMaxBytes int64
SoraCacheAllowedHosts []string
SoraCacheUserDirEnabled bool
SoraWatermarkFreeEnabled bool
SoraWatermarkFreeParseMethod string
SoraWatermarkFreeCustomParseURL string
SoraWatermarkFreeCustomParseToken string
SoraWatermarkFreeFallbackOnFailure bool
SoraTokenRefreshEnabled bool
// Ops monitoring (vNext) // Ops monitoring (vNext)
OpsMonitoringEnabled bool OpsMonitoringEnabled bool
OpsRealtimeMonitoringEnabled bool OpsRealtimeMonitoringEnabled bool
......
package service
import (
"context"
"log"
"os"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
const (
soraCacheCleanupInterval = time.Hour
soraCacheCleanupBatch = 200
)
// SoraCacheCleanupService 负责清理 Sora 视频缓存文件。
type SoraCacheCleanupService struct {
cacheRepo SoraCacheFileRepository
settingService *SettingService
cfg *config.Config
stopCh chan struct{}
stopOnce sync.Once
}
func NewSoraCacheCleanupService(cacheRepo SoraCacheFileRepository, settingService *SettingService, cfg *config.Config) *SoraCacheCleanupService {
return &SoraCacheCleanupService{
cacheRepo: cacheRepo,
settingService: settingService,
cfg: cfg,
stopCh: make(chan struct{}),
}
}
func (s *SoraCacheCleanupService) Start() {
if s == nil || s.cacheRepo == nil {
return
}
go s.cleanupLoop()
}
func (s *SoraCacheCleanupService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
close(s.stopCh)
})
}
func (s *SoraCacheCleanupService) cleanupLoop() {
ticker := time.NewTicker(soraCacheCleanupInterval)
defer ticker.Stop()
s.cleanupOnce()
for {
select {
case <-ticker.C:
s.cleanupOnce()
case <-s.stopCh:
return
}
}
}
func (s *SoraCacheCleanupService) cleanupOnce() {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
defer cancel()
if s.cacheRepo == nil {
return
}
cfg := s.getSoraConfig(ctx)
videoDir := strings.TrimSpace(cfg.Cache.VideoDir)
if videoDir == "" {
return
}
maxBytes := cfg.Cache.MaxBytes
if maxBytes <= 0 {
return
}
size, err := dirSize(videoDir)
if err != nil {
log.Printf("[SoraCacheCleanup] 计算目录大小失败: %v", err)
return
}
if size <= maxBytes {
return
}
for size > maxBytes {
entries, err := s.cacheRepo.ListOldest(ctx, soraCacheCleanupBatch)
if err != nil {
log.Printf("[SoraCacheCleanup] 读取缓存记录失败: %v", err)
return
}
if len(entries) == 0 {
log.Printf("[SoraCacheCleanup] 无缓存记录但目录仍超限: size=%d max=%d", size, maxBytes)
return
}
ids := make([]int64, 0, len(entries))
for _, entry := range entries {
if entry == nil {
continue
}
removedSize := entry.SizeBytes
if entry.CachePath != "" {
if info, err := os.Stat(entry.CachePath); err == nil {
if removedSize <= 0 {
removedSize = info.Size()
}
}
if err := os.Remove(entry.CachePath); err != nil && !os.IsNotExist(err) {
log.Printf("[SoraCacheCleanup] 删除缓存文件失败: path=%s err=%v", entry.CachePath, err)
}
}
if entry.ID > 0 {
ids = append(ids, entry.ID)
}
if removedSize > 0 {
size -= removedSize
if size < 0 {
size = 0
}
}
}
if len(ids) > 0 {
if err := s.cacheRepo.DeleteByIDs(ctx, ids); err != nil {
log.Printf("[SoraCacheCleanup] 删除缓存记录失败: %v", err)
}
}
if size > maxBytes {
if refreshed, err := dirSize(videoDir); err == nil {
size = refreshed
}
}
}
}
func (s *SoraCacheCleanupService) getSoraConfig(ctx context.Context) config.SoraConfig {
if s.settingService != nil {
return s.settingService.GetSoraConfig(ctx)
}
if s.cfg != nil {
return s.cfg.Sora
}
return config.SoraConfig{}
}
package service
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/uuidv7"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
// SoraCacheService 提供 Sora 视频缓存能力。
type SoraCacheService struct {
cfg *config.Config
cacheRepo SoraCacheFileRepository
settingService *SettingService
accountRepo AccountRepository
httpUpstream HTTPUpstream
}
// NewSoraCacheService 创建 SoraCacheService。
func NewSoraCacheService(cfg *config.Config, cacheRepo SoraCacheFileRepository, settingService *SettingService, accountRepo AccountRepository, httpUpstream HTTPUpstream) *SoraCacheService {
return &SoraCacheService{
cfg: cfg,
cacheRepo: cacheRepo,
settingService: settingService,
accountRepo: accountRepo,
httpUpstream: httpUpstream,
}
}
func (s *SoraCacheService) CacheVideo(ctx context.Context, accountID, userID int64, taskID, mediaURL string) (*SoraCacheFile, error) {
cfg := s.getSoraConfig(ctx)
if !cfg.Cache.Enabled {
return nil, nil
}
trimmed := strings.TrimSpace(mediaURL)
if trimmed == "" {
return nil, nil
}
allowedHosts := cfg.Cache.AllowedHosts
useAllowlist := true
if len(allowedHosts) == 0 {
if s.cfg != nil {
allowedHosts = s.cfg.Security.URLAllowlist.UpstreamHosts
useAllowlist = s.cfg.Security.URLAllowlist.Enabled
} else {
useAllowlist = false
}
}
if useAllowlist {
if _, err := urlvalidator.ValidateHTTPSURL(trimmed, urlvalidator.ValidationOptions{
AllowedHosts: allowedHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg != nil && s.cfg.Security.URLAllowlist.AllowPrivateHosts,
}); err != nil {
return nil, fmt.Errorf("缓存下载地址不合法: %w", err)
}
} else {
allowInsecure := false
if s.cfg != nil {
allowInsecure = s.cfg.Security.URLAllowlist.AllowInsecureHTTP
}
if _, err := urlvalidator.ValidateURLFormat(trimmed, allowInsecure); err != nil {
return nil, fmt.Errorf("缓存下载地址不合法: %w", err)
}
}
videoDir := strings.TrimSpace(cfg.Cache.VideoDir)
if videoDir == "" {
return nil, nil
}
if cfg.Cache.MaxBytes > 0 {
size, err := dirSize(videoDir)
if err != nil {
return nil, err
}
if size >= cfg.Cache.MaxBytes {
return nil, nil
}
}
relativeDir := ""
if cfg.Cache.UserDirEnabled && userID > 0 {
relativeDir = fmt.Sprintf("u_%d", userID)
}
targetDir := filepath.Join(videoDir, relativeDir)
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return nil, err
}
uuid, err := uuidv7.New()
if err != nil {
return nil, err
}
name := deriveFileName(trimmed)
if name == "" {
name = "video.mp4"
}
name = sanitizeFileName(name)
filename := uuid + "_" + name
cachePath := filepath.Join(targetDir, filename)
resp, err := s.downloadMedia(ctx, accountID, trimmed, time.Duration(cfg.Timeout)*time.Second)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("缓存下载失败: %d", resp.StatusCode)
}
out, err := os.Create(cachePath)
if err != nil {
return nil, err
}
defer out.Close()
written, err := io.Copy(out, resp.Body)
if err != nil {
return nil, err
}
cacheURL := buildCacheURL(relativeDir, filename)
record := &SoraCacheFile{
TaskID: taskID,
AccountID: accountID,
UserID: userID,
MediaType: "video",
OriginalURL: trimmed,
CachePath: cachePath,
CacheURL: cacheURL,
SizeBytes: written,
CreatedAt: time.Now(),
}
if s.cacheRepo != nil {
if err := s.cacheRepo.Create(ctx, record); err != nil {
return nil, err
}
}
return record, nil
}
func buildCacheURL(relativeDir, filename string) string {
base := "/data/video"
if relativeDir != "" {
return path.Join(base, relativeDir, filename)
}
return path.Join(base, filename)
}
func (s *SoraCacheService) getSoraConfig(ctx context.Context) config.SoraConfig {
if s.settingService != nil {
return s.settingService.GetSoraConfig(ctx)
}
if s.cfg != nil {
return s.cfg.Sora
}
return config.SoraConfig{}
}
func (s *SoraCacheService) downloadMedia(ctx context.Context, accountID int64, mediaURL string, timeout time.Duration) (*http.Response, error) {
if timeout <= 0 {
timeout = 120 * time.Second
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", mediaURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
if s.httpUpstream == nil {
client := &http.Client{Timeout: timeout}
return client.Do(req)
}
var accountConcurrency int
proxyURL := ""
if s.accountRepo != nil && accountID > 0 {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account != nil {
accountConcurrency = account.Concurrency
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
}
}
enableTLS := false
if s.cfg != nil {
enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled
}
return s.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS)
}
func deriveFileName(rawURL string) string {
parsed, err := url.Parse(rawURL)
if err != nil {
return ""
}
name := path.Base(parsed.Path)
if name == "/" || name == "." {
return ""
}
return name
}
func sanitizeFileName(name string) string {
name = strings.TrimSpace(name)
if name == "" {
return ""
}
sanitized := strings.Map(func(r rune) rune {
switch {
case r >= 'a' && r <= 'z':
return r
case r >= 'A' && r <= 'Z':
return r
case r >= '0' && r <= '9':
return r
case r == '-' || r == '_' || r == '.':
return r
case r == ' ': // 空格替换为下划线
return '_'
default:
return -1
}
}, name)
return strings.TrimLeft(sanitized, ".")
}
package service
import (
"os"
"path/filepath"
)
func dirSize(root string) (int64, error) {
var size int64
err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
info, err := d.Info()
if err != nil {
return err
}
size += info.Size()
return nil
})
if err != nil && os.IsNotExist(err) {
return 0, nil
}
return size, err
}
package service
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/sora"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
const (
soraErrorDisableThreshold = 5
maxImageDownloadSize = 20 * 1024 * 1024 // 20MB
maxVideoDownloadSize = 200 * 1024 * 1024 // 200MB
)
var (
ErrSoraAccountMissingToken = errors.New("sora account missing access token")
ErrSoraAccountNotEligible = errors.New("sora account not eligible")
)
// SoraGenerationRequest 表示 Sora 生成请求。
type SoraGenerationRequest struct {
Model string
Prompt string
Image string
Video string
RemixTargetID string
Stream bool
UserID int64
}
// SoraGenerationResult 表示 Sora 生成结果。
type SoraGenerationResult struct {
Content string
MediaType string
ResultURLs []string
TaskID string
}
// SoraGatewayService 处理 Sora 生成流程。
type SoraGatewayService struct {
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository
usageRepo SoraUsageStatRepository
taskRepo SoraTaskRepository
cacheService *SoraCacheService
settingService *SettingService
concurrency *ConcurrencyService
cfg *config.Config
httpUpstream HTTPUpstream
}
// NewSoraGatewayService 创建 SoraGatewayService。
func NewSoraGatewayService(
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository,
usageRepo SoraUsageStatRepository,
taskRepo SoraTaskRepository,
cacheService *SoraCacheService,
settingService *SettingService,
concurrencyService *ConcurrencyService,
cfg *config.Config,
httpUpstream HTTPUpstream,
) *SoraGatewayService {
return &SoraGatewayService{
accountRepo: accountRepo,
soraAccountRepo: soraAccountRepo,
usageRepo: usageRepo,
taskRepo: taskRepo,
cacheService: cacheService,
settingService: settingService,
concurrency: concurrencyService,
cfg: cfg,
httpUpstream: httpUpstream,
}
}
// ListModels 返回 Sora 模型列表。
func (s *SoraGatewayService) ListModels() []sora.ModelListItem {
return sora.ListModels()
}
// Generate 执行 Sora 生成流程。
func (s *SoraGatewayService) Generate(ctx context.Context, account *Account, req SoraGenerationRequest) (*SoraGenerationResult, error) {
client, cfg := s.getClient(ctx)
if client == nil {
return nil, errors.New("sora client is not configured")
}
modelCfg, ok := sora.ModelConfigs[req.Model]
if !ok {
return nil, fmt.Errorf("unsupported model: %s", req.Model)
}
accessToken, soraAcc, err := s.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
if soraAcc != nil && soraAcc.SoraCooldownUntil != nil && time.Now().Before(*soraAcc.SoraCooldownUntil) {
return nil, ErrSoraAccountNotEligible
}
if modelCfg.RequirePro && !isSoraProAccount(soraAcc) {
return nil, ErrSoraAccountNotEligible
}
if modelCfg.Type == "video" && soraAcc != nil {
if !soraAcc.VideoEnabled || !soraAcc.SoraSupported || soraAcc.IsExpired {
return nil, ErrSoraAccountNotEligible
}
}
if modelCfg.Type == "image" && soraAcc != nil {
if !soraAcc.ImageEnabled || soraAcc.IsExpired {
return nil, ErrSoraAccountNotEligible
}
}
opts := sora.RequestOptions{
AccountID: account.ID,
AccountConcurrency: account.Concurrency,
AccessToken: accessToken,
}
if account.Proxy != nil {
opts.ProxyURL = account.Proxy.URL()
}
releaseFunc, err := s.acquireSoraSlots(ctx, account, soraAcc, modelCfg.Type == "video")
if err != nil {
return nil, err
}
if releaseFunc != nil {
defer releaseFunc()
}
if modelCfg.Type == "prompt_enhance" {
content, err := client.EnhancePrompt(ctx, opts, req.Prompt, modelCfg.ExpansionLevel, modelCfg.DurationS)
if err != nil {
return nil, err
}
return &SoraGenerationResult{Content: content, MediaType: "text"}, nil
}
var mediaID string
if req.Image != "" {
data, err := s.loadImageBytes(ctx, opts, req.Image)
if err != nil {
return nil, err
}
mediaID, err = client.UploadImage(ctx, opts, data, "image.png")
if err != nil {
return nil, err
}
}
if req.Video != "" && modelCfg.Type != "video" {
return nil, errors.New("视频输入仅支持视频模型")
}
if req.Video != "" && req.Image != "" {
return nil, errors.New("不能同时传入 image 与 video")
}
var cleanupCharacter func()
if req.Video != "" && req.RemixTargetID == "" {
username, characterID, err := s.createCharacter(ctx, client, opts, req.Video)
if err != nil {
return nil, err
}
if strings.TrimSpace(req.Prompt) == "" {
return &SoraGenerationResult{
Content: fmt.Sprintf("角色创建成功,角色名@%s", username),
MediaType: "text",
}, nil
}
if username != "" {
req.Prompt = fmt.Sprintf("@%s %s", username, strings.TrimSpace(req.Prompt))
}
if characterID != "" {
cleanupCharacter = func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
_ = client.DeleteCharacter(ctx, opts, characterID)
}
}
}
if cleanupCharacter != nil {
defer cleanupCharacter()
}
var taskID string
if modelCfg.Type == "image" {
taskID, err = client.GenerateImage(ctx, opts, req.Prompt, modelCfg.Width, modelCfg.Height, mediaID)
} else {
orientation := modelCfg.Orientation
if orientation == "" {
orientation = "landscape"
}
modelName := modelCfg.Model
if modelName == "" {
modelName = "sy_8"
}
size := modelCfg.Size
if size == "" {
size = "small"
}
if req.RemixTargetID != "" {
taskID, err = client.RemixVideo(ctx, opts, req.RemixTargetID, req.Prompt, orientation, modelCfg.NFrames, "")
} else if sora.IsStoryboardPrompt(req.Prompt) {
formatted := sora.FormatStoryboardPrompt(req.Prompt)
taskID, err = client.GenerateStoryboard(ctx, opts, formatted, orientation, modelCfg.NFrames, mediaID, "")
} else {
taskID, err = client.GenerateVideo(ctx, opts, req.Prompt, orientation, modelCfg.NFrames, mediaID, "", modelName, size)
}
}
if err != nil {
return nil, err
}
if s.taskRepo != nil {
_ = s.taskRepo.Create(ctx, &SoraTask{
TaskID: taskID,
AccountID: account.ID,
Model: req.Model,
Prompt: req.Prompt,
Status: "processing",
Progress: 0,
CreatedAt: time.Now(),
})
}
result, err := s.pollResult(ctx, client, cfg, opts, taskID, modelCfg.Type == "video", req)
if err != nil {
if s.taskRepo != nil {
_ = s.taskRepo.UpdateStatus(ctx, taskID, "failed", 0, "", err.Error(), timePtr(time.Now()))
}
consecutive := 0
if s.usageRepo != nil {
consecutive, _ = s.usageRepo.RecordError(ctx, account.ID)
}
if consecutive >= soraErrorDisableThreshold {
_ = s.accountRepo.SetError(ctx, account.ID, "Sora 连续错误次数过多,已自动禁用")
}
return nil, err
}
if s.taskRepo != nil {
payload, _ := json.Marshal(result.ResultURLs)
_ = s.taskRepo.UpdateStatus(ctx, taskID, "completed", 100, string(payload), "", timePtr(time.Now()))
}
if s.usageRepo != nil {
_ = s.usageRepo.RecordSuccess(ctx, account.ID, modelCfg.Type == "video")
}
return result, nil
}
func (s *SoraGatewayService) pollResult(ctx context.Context, client *sora.Client, cfg config.SoraConfig, opts sora.RequestOptions, taskID string, isVideo bool, req SoraGenerationRequest) (*SoraGenerationResult, error) {
if taskID == "" {
return nil, errors.New("missing task id")
}
pollInterval := 2 * time.Second
if cfg.PollInterval > 0 {
pollInterval = time.Duration(cfg.PollInterval*1000) * time.Millisecond
}
timeout := 300 * time.Second
if cfg.Timeout > 0 {
timeout = time.Duration(cfg.Timeout) * time.Second
}
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
if isVideo {
pending, err := client.GetPendingTasks(ctx, opts)
if err == nil {
for _, task := range pending {
if stringFromMap(task, "id") == taskID {
continue
}
}
}
drafts, err := client.GetVideoDrafts(ctx, opts)
if err != nil {
return nil, err
}
items, _ := drafts["items"].([]any)
for _, item := range items {
entry, ok := item.(map[string]any)
if !ok {
continue
}
if stringFromMap(entry, "task_id") != taskID {
continue
}
url := firstNonEmpty(stringFromMap(entry, "downloadable_url"), stringFromMap(entry, "url"))
reason := stringFromMap(entry, "reason_str")
if url == "" {
if reason == "" {
reason = "视频生成失败"
}
return nil, errors.New(reason)
}
finalURL, err := s.handleWatermark(ctx, client, cfg, opts, url, entry, req, opts.AccountID, taskID)
if err != nil {
return nil, err
}
return &SoraGenerationResult{
Content: buildVideoMarkdown(finalURL),
MediaType: "video",
ResultURLs: []string{finalURL},
TaskID: taskID,
}, nil
}
} else {
resp, err := client.GetImageTasks(ctx, opts)
if err != nil {
return nil, err
}
tasks, _ := resp["task_responses"].([]any)
for _, item := range tasks {
entry, ok := item.(map[string]any)
if !ok {
continue
}
if stringFromMap(entry, "id") != taskID {
continue
}
status := stringFromMap(entry, "status")
switch status {
case "succeeded":
urls := extractImageURLs(entry)
if len(urls) == 0 {
return nil, errors.New("image urls empty")
}
content := buildImageMarkdown(urls)
return &SoraGenerationResult{
Content: content,
MediaType: "image",
ResultURLs: urls,
TaskID: taskID,
}, nil
case "failed":
message := stringFromMap(entry, "error_message")
if message == "" {
message = "image generation failed"
}
return nil, errors.New(message)
}
}
}
time.Sleep(pollInterval)
}
return nil, errors.New("generation timeout")
}
func (s *SoraGatewayService) handleWatermark(ctx context.Context, client *sora.Client, cfg config.SoraConfig, opts sora.RequestOptions, url string, entry map[string]any, req SoraGenerationRequest, accountID int64, taskID string) (string, error) {
if !cfg.WatermarkFree.Enabled {
return s.cacheVideo(ctx, url, req, accountID, taskID), nil
}
generationID := stringFromMap(entry, "id")
if generationID == "" {
return s.cacheVideo(ctx, url, req, accountID, taskID), nil
}
postID, err := client.PostVideoForWatermarkFree(ctx, opts, generationID)
if err != nil {
if cfg.WatermarkFree.FallbackOnFailure {
return s.cacheVideo(ctx, url, req, accountID, taskID), nil
}
return "", err
}
if postID == "" {
if cfg.WatermarkFree.FallbackOnFailure {
return s.cacheVideo(ctx, url, req, accountID, taskID), nil
}
return "", errors.New("watermark-free post id empty")
}
var parsedURL string
if cfg.WatermarkFree.ParseMethod == "custom" {
if cfg.WatermarkFree.CustomParseURL == "" || cfg.WatermarkFree.CustomParseToken == "" {
return "", errors.New("custom parse 未配置")
}
parsedURL, err = s.fetchCustomWatermarkURL(ctx, cfg.WatermarkFree.CustomParseURL, cfg.WatermarkFree.CustomParseToken, postID)
if err != nil {
if cfg.WatermarkFree.FallbackOnFailure {
return s.cacheVideo(ctx, url, req, accountID, taskID), nil
}
return "", err
}
} else {
parsedURL = fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID)
}
cached := s.cacheVideo(ctx, parsedURL, req, accountID, taskID)
_ = client.DeletePost(ctx, opts, postID)
return cached, nil
}
func (s *SoraGatewayService) cacheVideo(ctx context.Context, url string, req SoraGenerationRequest, accountID int64, taskID string) string {
if s.cacheService == nil {
return url
}
file, err := s.cacheService.CacheVideo(ctx, accountID, req.UserID, taskID, url)
if err != nil || file == nil {
return url
}
return file.CacheURL
}
func (s *SoraGatewayService) getAccessToken(ctx context.Context, account *Account) (string, *SoraAccount, error) {
if account == nil {
return "", nil, errors.New("account is nil")
}
var soraAcc *SoraAccount
if s.soraAccountRepo != nil {
soraAcc, _ = s.soraAccountRepo.GetByAccountID(ctx, account.ID)
}
if soraAcc != nil && soraAcc.AccessToken != "" {
return soraAcc.AccessToken, soraAcc, nil
}
if account.Credentials != nil {
if v, ok := account.Credentials["access_token"].(string); ok && v != "" {
return v, soraAcc, nil
}
if v, ok := account.Credentials["token"].(string); ok && v != "" {
return v, soraAcc, nil
}
}
return "", soraAcc, ErrSoraAccountMissingToken
}
func (s *SoraGatewayService) getClient(ctx context.Context) (*sora.Client, config.SoraConfig) {
cfg := s.getSoraConfig(ctx)
if s.httpUpstream == nil {
return nil, cfg
}
baseURL := strings.TrimSpace(cfg.BaseURL)
if baseURL == "" {
return nil, cfg
}
timeout := time.Duration(cfg.Timeout) * time.Second
if cfg.Timeout <= 0 {
timeout = 120 * time.Second
}
enableTLS := false
if s.cfg != nil {
enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled
}
return sora.NewClient(baseURL, timeout, s.httpUpstream, enableTLS), cfg
}
func decodeBase64(raw string) ([]byte, error) {
data := raw
if idx := strings.Index(raw, "base64,"); idx != -1 {
data = raw[idx+7:]
}
return base64.StdEncoding.DecodeString(data)
}
func extractImageURLs(entry map[string]any) []string {
generations, _ := entry["generations"].([]any)
urls := make([]string, 0, len(generations))
for _, gen := range generations {
m, ok := gen.(map[string]any)
if !ok {
continue
}
if url, ok := m["url"].(string); ok && url != "" {
urls = append(urls, url)
}
}
return urls
}
func buildImageMarkdown(urls []string) string {
parts := make([]string, 0, len(urls))
for _, u := range urls {
parts = append(parts, fmt.Sprintf("![Generated Image](%s)", u))
}
return strings.Join(parts, "\n")
}
func buildVideoMarkdown(url string) string {
return fmt.Sprintf("```html\n<video src='%s' controls></video>\n```", url)
}
func stringFromMap(m map[string]any, key string) string {
if m == nil {
return ""
}
if v, ok := m[key].(string); ok {
return v
}
return ""
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
if strings.TrimSpace(v) != "" {
return v
}
}
return ""
}
func isSoraProAccount(acc *SoraAccount) bool {
if acc == nil {
return false
}
return strings.EqualFold(acc.PlanType, "chatgpt_pro")
}
func timePtr(t time.Time) *time.Time {
return &t
}
// fetchCustomWatermarkURL 使用自定义解析服务获取无水印视频 URL
func (s *SoraGatewayService) fetchCustomWatermarkURL(ctx context.Context, parseURL, parseToken, postID string) (string, error) {
// 使用项目的 URL 校验器验证 parseURL 格式,防止 SSRF 攻击
if _, err := urlvalidator.ValidateHTTPSURL(parseURL, urlvalidator.ValidationOptions{}); err != nil {
return "", fmt.Errorf("无效的解析服务地址: %w", err)
}
payload := map[string]any{
"url": fmt.Sprintf("https://sora.chatgpt.com/p/%s", postID),
"token": parseToken,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, "POST", strings.TrimRight(parseURL, "/")+"/get-sora-link", strings.NewReader(string(body)))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
// 复用 httpUpstream,遵守代理和 TLS 配置
enableTLS := false
if s.cfg != nil {
enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled
}
resp, err := s.httpUpstream.DoWithTLS(req, "", 0, 1, enableTLS)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", fmt.Errorf("custom parse failed: %d", resp.StatusCode)
}
var parsed map[string]any
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
return "", err
}
if errMsg, ok := parsed["error"].(string); ok && errMsg != "" {
return "", errors.New(errMsg)
}
if link, ok := parsed["download_link"].(string); ok {
return link, nil
}
return "", errors.New("custom parse response missing download_link")
}
const (
soraSlotImageLock int64 = 1
soraSlotImageLimit int64 = 2
soraSlotVideoLimit int64 = 3
soraDefaultUsername = "character"
)
func (s *SoraGatewayService) CallLogicMode(ctx context.Context) string {
return strings.TrimSpace(s.getSoraConfig(ctx).CallLogicMode)
}
func (s *SoraGatewayService) getSoraConfig(ctx context.Context) config.SoraConfig {
if s.settingService != nil {
return s.settingService.GetSoraConfig(ctx)
}
if s.cfg != nil {
return s.cfg.Sora
}
return config.SoraConfig{}
}
func (s *SoraGatewayService) acquireSoraSlots(ctx context.Context, account *Account, soraAcc *SoraAccount, isVideo bool) (func(), error) {
if s.concurrency == nil || account == nil || soraAcc == nil {
return nil, nil
}
releases := make([]func(), 0, 2)
appendRelease := func(release func()) {
if release != nil {
releases = append(releases, release)
}
}
// 错误时释放所有已获取的槽位
releaseAll := func() {
for _, r := range releases {
r()
}
}
if isVideo {
if soraAcc.VideoConcurrency > 0 {
release, err := s.acquireSoraSlot(ctx, account.ID, soraAcc.VideoConcurrency, soraSlotVideoLimit)
if err != nil {
releaseAll()
return nil, err
}
appendRelease(release)
}
} else {
release, err := s.acquireSoraSlot(ctx, account.ID, 1, soraSlotImageLock)
if err != nil {
releaseAll()
return nil, err
}
appendRelease(release)
if soraAcc.ImageConcurrency > 0 {
release, err := s.acquireSoraSlot(ctx, account.ID, soraAcc.ImageConcurrency, soraSlotImageLimit)
if err != nil {
releaseAll() // 释放已获取的 soraSlotImageLock
return nil, err
}
appendRelease(release)
}
}
if len(releases) == 0 {
return nil, nil
}
return func() {
for _, release := range releases {
release()
}
}, nil
}
func (s *SoraGatewayService) acquireSoraSlot(ctx context.Context, accountID int64, maxConcurrency int, slotType int64) (func(), error) {
if s.concurrency == nil || maxConcurrency <= 0 {
return nil, nil
}
derivedID := soraConcurrencyAccountID(accountID, slotType)
result, err := s.concurrency.AcquireAccountSlot(ctx, derivedID, maxConcurrency)
if err != nil {
return nil, err
}
if !result.Acquired {
return nil, ErrSoraAccountNotEligible
}
return result.ReleaseFunc, nil
}
func soraConcurrencyAccountID(accountID int64, slotType int64) int64 {
if accountID < 0 {
accountID = -accountID
}
return -(accountID*10 + slotType)
}
func (s *SoraGatewayService) createCharacter(ctx context.Context, client *sora.Client, opts sora.RequestOptions, rawVideo string) (string, string, error) {
videoBytes, err := s.loadVideoBytes(ctx, opts, rawVideo)
if err != nil {
return "", "", err
}
cameoID, err := client.UploadCharacterVideo(ctx, opts, videoBytes)
if err != nil {
return "", "", err
}
status, err := s.pollCameoStatus(ctx, client, opts, cameoID)
if err != nil {
return "", "", err
}
username := processCharacterUsername(stringFromMap(status, "username_hint"))
if username == "" {
username = soraDefaultUsername
}
displayName := stringFromMap(status, "display_name_hint")
if displayName == "" {
displayName = "Character"
}
profileURL := stringFromMap(status, "profile_asset_url")
if profileURL == "" {
return "", "", errors.New("profile asset url missing")
}
avatarData, err := client.DownloadCharacterImage(ctx, opts, profileURL)
if err != nil {
return "", "", err
}
assetPointer, err := client.UploadCharacterImage(ctx, opts, avatarData)
if err != nil {
return "", "", err
}
characterID, err := client.FinalizeCharacter(ctx, opts, cameoID, username, displayName, assetPointer)
if err != nil {
return "", "", err
}
if err := client.SetCharacterPublic(ctx, opts, cameoID); err != nil {
return "", "", err
}
return username, characterID, nil
}
func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, client *sora.Client, opts sora.RequestOptions, cameoID string) (map[string]any, error) {
if cameoID == "" {
return nil, errors.New("cameo id empty")
}
timeout := 600 * time.Second
pollInterval := 5 * time.Second
deadline := time.Now().Add(timeout)
consecutiveErrors := 0
maxConsecutiveErrors := 3
for time.Now().Before(deadline) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
time.Sleep(pollInterval)
status, err := client.GetCameoStatus(ctx, opts, cameoID)
if err != nil {
consecutiveErrors++
if consecutiveErrors >= maxConsecutiveErrors {
return nil, err
}
continue
}
consecutiveErrors = 0
statusValue := stringFromMap(status, "status")
statusMessage := stringFromMap(status, "status_message")
if statusValue == "failed" {
if statusMessage == "" {
statusMessage = "角色创建失败"
}
return nil, fmt.Errorf("角色创建失败: %s", statusMessage)
}
if strings.EqualFold(statusMessage, "Completed") || strings.EqualFold(statusValue, "finalized") {
return status, nil
}
}
return nil, errors.New("角色创建超时")
}
func (s *SoraGatewayService) loadVideoBytes(ctx context.Context, opts sora.RequestOptions, rawVideo string) ([]byte, error) {
trimmed := strings.TrimSpace(rawVideo)
if trimmed == "" {
return nil, errors.New("video data is empty")
}
if looksLikeURL(trimmed) {
if err := s.validateMediaURL(trimmed); err != nil {
return nil, err
}
return s.downloadMedia(ctx, opts, trimmed, maxVideoDownloadSize)
}
return decodeBase64(trimmed)
}
func (s *SoraGatewayService) loadImageBytes(ctx context.Context, opts sora.RequestOptions, rawImage string) ([]byte, error) {
trimmed := strings.TrimSpace(rawImage)
if trimmed == "" {
return nil, errors.New("image data is empty")
}
if looksLikeURL(trimmed) {
if err := s.validateMediaURL(trimmed); err != nil {
return nil, err
}
return s.downloadMedia(ctx, opts, trimmed, maxImageDownloadSize)
}
return decodeBase64(trimmed)
}
func (s *SoraGatewayService) validateMediaURL(rawURL string) error {
cfg := s.cfg
if cfg == nil {
return nil
}
if cfg.Security.URLAllowlist.Enabled {
_, err := urlvalidator.ValidateHTTPSURL(rawURL, urlvalidator.ValidationOptions{
AllowedHosts: cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return fmt.Errorf("媒体地址不合法: %w", err)
}
return nil
}
if _, err := urlvalidator.ValidateURLFormat(rawURL, cfg.Security.URLAllowlist.AllowInsecureHTTP); err != nil {
return fmt.Errorf("媒体地址不合法: %w", err)
}
return nil
}
func (s *SoraGatewayService) downloadMedia(ctx context.Context, opts sora.RequestOptions, mediaURL string, maxSize int64) ([]byte, error) {
if s.httpUpstream == nil {
return nil, errors.New("upstream is nil")
}
req, err := http.NewRequestWithContext(ctx, "GET", mediaURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
enableTLS := false
if s.cfg != nil {
enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled
}
resp, err := s.httpUpstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, opts.AccountConcurrency, enableTLS)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("下载失败: %d", resp.StatusCode)
}
// 使用 LimitReader 限制最大读取大小,防止 DoS 攻击
limitedReader := io.LimitReader(resp.Body, maxSize+1)
data, err := io.ReadAll(limitedReader)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
// 检查是否超过大小限制
if int64(len(data)) > maxSize {
return nil, fmt.Errorf("媒体文件过大 (最大 %d 字节, 实际 %d 字节)", maxSize, len(data))
}
return data, nil
}
func processCharacterUsername(usernameHint string) string {
trimmed := strings.TrimSpace(usernameHint)
if trimmed == "" {
return ""
}
base := trimmed
if idx := strings.LastIndex(trimmed, "."); idx != -1 && idx+1 < len(trimmed) {
base = trimmed[idx+1:]
}
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
return fmt.Sprintf("%s%d", base, rng.Intn(900)+100)
}
func looksLikeURL(value string) bool {
trimmed := strings.ToLower(strings.TrimSpace(value))
return strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://")
}
package service
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// SoraAccount 表示 Sora 账号扩展信息。
type SoraAccount struct {
AccountID int64
AccessToken string
SessionToken string
RefreshToken string
ClientID string
Email string
Username string
Remark string
UseCount int
PlanType string
PlanTitle string
SubscriptionEnd *time.Time
SoraSupported bool
SoraInviteCode string
SoraRedeemedCount int
SoraRemainingCount int
SoraTotalCount int
SoraCooldownUntil *time.Time
CooledUntil *time.Time
ImageEnabled bool
VideoEnabled bool
ImageConcurrency int
VideoConcurrency int
IsExpired bool
CreatedAt time.Time
UpdatedAt time.Time
}
// SoraUsageStat 表示 Sora 调用统计。
type SoraUsageStat struct {
AccountID int64
ImageCount int
VideoCount int
ErrorCount int
LastErrorAt *time.Time
TodayImageCount int
TodayVideoCount int
TodayErrorCount int
TodayDate *time.Time
ConsecutiveErrorCount int
CreatedAt time.Time
UpdatedAt time.Time
}
// SoraTask 表示 Sora 任务记录。
type SoraTask struct {
TaskID string
AccountID int64
Model string
Prompt string
Status string
Progress float64
ResultURLs string
ErrorMessage string
RetryCount int
CreatedAt time.Time
CompletedAt *time.Time
}
// SoraCacheFile 表示 Sora 缓存文件记录。
type SoraCacheFile struct {
ID int64
TaskID string
AccountID int64
UserID int64
MediaType string
OriginalURL string
CachePath string
CacheURL string
SizeBytes int64
CreatedAt time.Time
}
// SoraAccountRepository 定义 Sora 账号仓储接口。
type SoraAccountRepository interface {
GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error)
GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*SoraAccount, error)
Upsert(ctx context.Context, accountID int64, updates map[string]any) error
}
// SoraUsageStatRepository 定义 Sora 调用统计仓储接口。
type SoraUsageStatRepository interface {
RecordSuccess(ctx context.Context, accountID int64, isVideo bool) error
RecordError(ctx context.Context, accountID int64) (int, error)
ResetConsecutiveErrors(ctx context.Context, accountID int64) error
GetByAccountID(ctx context.Context, accountID int64) (*SoraUsageStat, error)
GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*SoraUsageStat, error)
List(ctx context.Context, params pagination.PaginationParams) ([]*SoraUsageStat, *pagination.PaginationResult, error)
}
// SoraTaskRepository 定义 Sora 任务仓储接口。
type SoraTaskRepository interface {
Create(ctx context.Context, task *SoraTask) error
UpdateStatus(ctx context.Context, taskID string, status string, progress float64, resultURLs string, errorMessage string, completedAt *time.Time) error
}
// SoraCacheFileRepository 定义 Sora 缓存文件仓储接口。
type SoraCacheFileRepository interface {
Create(ctx context.Context, file *SoraCacheFile) error
ListOldest(ctx context.Context, limit int) ([]*SoraCacheFile, error)
DeleteByIDs(ctx context.Context, ids []int64) error
}
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
const defaultSoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
// SoraTokenRefreshService handles Sora access token refresh.
type SoraTokenRefreshService struct {
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository
settingService *SettingService
httpUpstream HTTPUpstream
cfg *config.Config
stopCh chan struct{}
stopOnce sync.Once
}
func NewSoraTokenRefreshService(
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository,
settingService *SettingService,
httpUpstream HTTPUpstream,
cfg *config.Config,
) *SoraTokenRefreshService {
return &SoraTokenRefreshService{
accountRepo: accountRepo,
soraAccountRepo: soraAccountRepo,
settingService: settingService,
httpUpstream: httpUpstream,
cfg: cfg,
stopCh: make(chan struct{}),
}
}
func (s *SoraTokenRefreshService) Start() {
if s == nil {
return
}
go s.refreshLoop()
}
func (s *SoraTokenRefreshService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
close(s.stopCh)
})
}
func (s *SoraTokenRefreshService) refreshLoop() {
for {
wait := s.nextRunDelay()
timer := time.NewTimer(wait)
select {
case <-timer.C:
s.refreshOnce()
case <-s.stopCh:
timer.Stop()
return
}
}
}
func (s *SoraTokenRefreshService) refreshOnce() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()
if !s.isEnabled(ctx) {
log.Println("[SoraTokenRefresh] disabled by settings")
return
}
if s.accountRepo == nil || s.soraAccountRepo == nil {
log.Println("[SoraTokenRefresh] repository not configured")
return
}
accounts, err := s.accountRepo.ListByPlatform(ctx, PlatformSora)
if err != nil {
log.Printf("[SoraTokenRefresh] list accounts failed: %v", err)
return
}
if len(accounts) == 0 {
log.Println("[SoraTokenRefresh] no sora accounts")
return
}
ids := make([]int64, 0, len(accounts))
accountMap := make(map[int64]*Account, len(accounts))
for i := range accounts {
acc := accounts[i]
ids = append(ids, acc.ID)
accountMap[acc.ID] = &acc
}
accountExtras, err := s.soraAccountRepo.GetByAccountIDs(ctx, ids)
if err != nil {
log.Printf("[SoraTokenRefresh] load sora accounts failed: %v", err)
return
}
success := 0
failed := 0
skipped := 0
for accountID, account := range accountMap {
extra := accountExtras[accountID]
if extra == nil {
skipped++
continue
}
result, err := s.refreshForAccount(ctx, account, extra)
if err != nil {
failed++
log.Printf("[SoraTokenRefresh] account %d refresh failed: %v", accountID, err)
continue
}
if result == nil {
skipped++
continue
}
updates := map[string]any{
"access_token": result.AccessToken,
}
if result.RefreshToken != "" {
updates["refresh_token"] = result.RefreshToken
}
if result.Email != "" {
updates["email"] = result.Email
}
if err := s.soraAccountRepo.Upsert(ctx, accountID, updates); err != nil {
failed++
log.Printf("[SoraTokenRefresh] account %d update failed: %v", accountID, err)
continue
}
success++
}
log.Printf("[SoraTokenRefresh] done: success=%d failed=%d skipped=%d", success, failed, skipped)
}
func (s *SoraTokenRefreshService) refreshForAccount(ctx context.Context, account *Account, extra *SoraAccount) (*soraRefreshResult, error) {
if extra == nil {
return nil, nil
}
if strings.TrimSpace(extra.SessionToken) == "" && strings.TrimSpace(extra.RefreshToken) == "" {
return nil, nil
}
if extra.SessionToken != "" {
result, err := s.refreshWithSessionToken(ctx, account, extra.SessionToken)
if err == nil && result != nil && result.AccessToken != "" {
return result, nil
}
if strings.TrimSpace(extra.RefreshToken) == "" {
return nil, err
}
}
clientID := strings.TrimSpace(extra.ClientID)
if clientID == "" {
clientID = defaultSoraClientID
}
return s.refreshWithRefreshToken(ctx, account, extra.RefreshToken, clientID)
}
type soraRefreshResult struct {
AccessToken string
RefreshToken string
Email string
}
type soraSessionResponse struct {
AccessToken string `json:"accessToken"`
User struct {
Email string `json:"email"`
} `json:"user"`
}
type soraRefreshResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
func (s *SoraTokenRefreshService) refreshWithSessionToken(ctx context.Context, account *Account, sessionToken string) (*soraRefreshResult, error) {
if s.httpUpstream == nil {
return nil, fmt.Errorf("upstream not configured")
}
req, err := http.NewRequestWithContext(ctx, "GET", "https://sora.chatgpt.com/api/auth/session", nil)
if err != nil {
return nil, err
}
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken)
req.Header.Set("Accept", "application/json")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
enableTLS := false
if s.cfg != nil {
enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled
}
proxyURL := ""
accountConcurrency := 0
accountID := int64(0)
if account != nil {
accountID = account.ID
accountConcurrency = account.Concurrency
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("session refresh failed: %d", resp.StatusCode)
}
var payload soraSessionResponse
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return nil, err
}
if payload.AccessToken == "" {
return nil, errors.New("session refresh missing access token")
}
return &soraRefreshResult{AccessToken: payload.AccessToken, Email: payload.User.Email}, nil
}
func (s *SoraTokenRefreshService) refreshWithRefreshToken(ctx context.Context, account *Account, refreshToken, clientID string) (*soraRefreshResult, error) {
if s.httpUpstream == nil {
return nil, fmt.Errorf("upstream not configured")
}
payload := map[string]any{
"client_id": clientID,
"grant_type": "refresh_token",
"redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback",
"refresh_token": refreshToken,
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
enableTLS := false
if s.cfg != nil {
enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled
}
proxyURL := ""
accountConcurrency := 0
accountID := int64(0)
if account != nil {
accountID = account.ID
accountConcurrency = account.Concurrency
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("refresh token failed: %d", resp.StatusCode)
}
var payloadResp soraRefreshResponse
if err := json.NewDecoder(resp.Body).Decode(&payloadResp); err != nil {
return nil, err
}
if payloadResp.AccessToken == "" {
return nil, errors.New("refresh token missing access token")
}
return &soraRefreshResult{AccessToken: payloadResp.AccessToken, RefreshToken: payloadResp.RefreshToken}, nil
}
func (s *SoraTokenRefreshService) nextRunDelay() time.Duration {
location := time.Local
if s.cfg != nil && strings.TrimSpace(s.cfg.Timezone) != "" {
if tz, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil {
location = tz
}
}
now := time.Now().In(location)
next := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, location).Add(24 * time.Hour)
return time.Until(next)
}
func (s *SoraTokenRefreshService) isEnabled(ctx context.Context) bool {
if s.settingService == nil {
return s.cfg != nil && s.cfg.Sora.TokenRefresh.Enabled
}
cfg := s.settingService.GetSoraConfig(ctx)
return cfg.TokenRefresh.Enabled
}
...@@ -51,6 +51,30 @@ func ProvideTokenRefreshService( ...@@ -51,6 +51,30 @@ func ProvideTokenRefreshService(
return svc return svc
} }
// ProvideSoraTokenRefreshService creates and starts SoraTokenRefreshService.
func ProvideSoraTokenRefreshService(
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository,
settingService *SettingService,
httpUpstream HTTPUpstream,
cfg *config.Config,
) *SoraTokenRefreshService {
svc := NewSoraTokenRefreshService(accountRepo, soraAccountRepo, settingService, httpUpstream, cfg)
svc.Start()
return svc
}
// ProvideSoraCacheCleanupService creates and starts SoraCacheCleanupService.
func ProvideSoraCacheCleanupService(
cacheRepo SoraCacheFileRepository,
settingService *SettingService,
cfg *config.Config,
) *SoraCacheCleanupService {
svc := NewSoraCacheCleanupService(cacheRepo, settingService, cfg)
svc.Start()
return svc
}
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务 // ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService { func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
svc := NewDashboardAggregationService(repo, timingWheel, cfg) svc := NewDashboardAggregationService(repo, timingWheel, cfg)
...@@ -222,6 +246,8 @@ var ProviderSet = wire.NewSet( ...@@ -222,6 +246,8 @@ var ProviderSet = wire.NewSet(
NewAdminService, NewAdminService,
NewGatewayService, NewGatewayService,
NewOpenAIGatewayService, NewOpenAIGatewayService,
NewSoraCacheService,
NewSoraGatewayService,
NewOAuthService, NewOAuthService,
NewOpenAIOAuthService, NewOpenAIOAuthService,
NewGeminiOAuthService, NewGeminiOAuthService,
...@@ -255,6 +281,8 @@ var ProviderSet = wire.NewSet( ...@@ -255,6 +281,8 @@ var ProviderSet = wire.NewSet(
NewCRSSyncService, NewCRSSyncService,
ProvideUpdateService, ProvideUpdateService,
ProvideTokenRefreshService, ProvideTokenRefreshService,
ProvideSoraTokenRefreshService,
ProvideSoraCacheCleanupService,
ProvideAccountExpiryService, ProvideAccountExpiryService,
ProvideTimingWheelService, ProvideTimingWheelService,
ProvideDashboardAggregationService, ProvideDashboardAggregationService,
......
-- Add Sora platform tables
CREATE TABLE IF NOT EXISTS sora_accounts (
id BIGSERIAL PRIMARY KEY,
account_id BIGINT NOT NULL UNIQUE,
access_token TEXT,
session_token TEXT,
refresh_token TEXT,
client_id TEXT,
email TEXT,
username TEXT,
remark TEXT,
use_count INT DEFAULT 0,
plan_type TEXT,
plan_title TEXT,
subscription_end TIMESTAMPTZ,
sora_supported BOOLEAN DEFAULT FALSE,
sora_invite_code TEXT,
sora_redeemed_count INT DEFAULT 0,
sora_remaining_count INT DEFAULT 0,
sora_total_count INT DEFAULT 0,
sora_cooldown_until TIMESTAMPTZ,
cooled_until TIMESTAMPTZ,
image_enabled BOOLEAN DEFAULT TRUE,
video_enabled BOOLEAN DEFAULT TRUE,
image_concurrency INT DEFAULT -1,
video_concurrency INT DEFAULT -1,
is_expired BOOLEAN DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
FOREIGN KEY (account_id) REFERENCES accounts(id)
);
CREATE INDEX IF NOT EXISTS idx_sora_accounts_plan_type ON sora_accounts (plan_type);
CREATE INDEX IF NOT EXISTS idx_sora_accounts_sora_supported ON sora_accounts (sora_supported);
CREATE INDEX IF NOT EXISTS idx_sora_accounts_image_enabled ON sora_accounts (image_enabled);
CREATE INDEX IF NOT EXISTS idx_sora_accounts_video_enabled ON sora_accounts (video_enabled);
CREATE TABLE IF NOT EXISTS sora_usage_stats (
id BIGSERIAL PRIMARY KEY,
account_id BIGINT NOT NULL UNIQUE,
image_count INT DEFAULT 0,
video_count INT DEFAULT 0,
error_count INT DEFAULT 0,
last_error_at TIMESTAMPTZ,
today_image_count INT DEFAULT 0,
today_video_count INT DEFAULT 0,
today_error_count INT DEFAULT 0,
today_date DATE,
consecutive_error_count INT DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
FOREIGN KEY (account_id) REFERENCES accounts(id)
);
CREATE INDEX IF NOT EXISTS idx_sora_usage_stats_today_date ON sora_usage_stats (today_date);
CREATE TABLE IF NOT EXISTS sora_tasks (
id BIGSERIAL PRIMARY KEY,
task_id TEXT NOT NULL UNIQUE,
account_id BIGINT NOT NULL,
model TEXT NOT NULL,
prompt TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'processing',
progress DOUBLE PRECISION DEFAULT 0,
result_urls TEXT,
error_message TEXT,
retry_count INT DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
completed_at TIMESTAMPTZ,
FOREIGN KEY (account_id) REFERENCES accounts(id)
);
CREATE INDEX IF NOT EXISTS idx_sora_tasks_account_id ON sora_tasks (account_id);
CREATE INDEX IF NOT EXISTS idx_sora_tasks_status ON sora_tasks (status);
CREATE TABLE IF NOT EXISTS sora_cache_files (
id BIGSERIAL PRIMARY KEY,
task_id TEXT,
account_id BIGINT NOT NULL,
user_id BIGINT NOT NULL,
media_type TEXT NOT NULL,
original_url TEXT NOT NULL,
cache_path TEXT NOT NULL,
cache_url TEXT NOT NULL,
size_bytes BIGINT DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
FOREIGN KEY (account_id) REFERENCES accounts(id),
FOREIGN KEY (user_id) REFERENCES users(id)
);
CREATE INDEX IF NOT EXISTS idx_sora_cache_files_account_id ON sora_cache_files (account_id);
CREATE INDEX IF NOT EXISTS idx_sora_cache_files_user_id ON sora_cache_files (user_id);
CREATE INDEX IF NOT EXISTS idx_sora_cache_files_media_type ON sora_cache_files (media_type);
...@@ -525,3 +525,63 @@ gemini: ...@@ -525,3 +525,63 @@ gemini:
# Cooldown time (minutes) after hitting quota # Cooldown time (minutes) after hitting quota
# 达到配额后的冷却时间(分钟) # 达到配额后的冷却时间(分钟)
cooldown_minutes: 5 cooldown_minutes: 5
# =============================================================================
# Sora
# Sora 配置
# =============================================================================
sora:
# Sora Backend API base URL
# Sora 后端 API 基础地址
base_url: "https://sora.chatgpt.com/backend"
# Request timeout in seconds
# 请求超时时间(秒)
timeout: 120
# Max retry attempts for upstream requests
# 上游请求最大重试次数
max_retries: 3
# Poll interval in seconds for task status
# 任务状态轮询间隔(秒)
poll_interval: 2.5
# Call logic mode: default/native/proxy (default keeps current behavior)
# 调用模式:default/native/proxy(default 保持当前默认策略)
call_logic_mode: "default"
cache:
# Enable media caching
# 是否启用媒体缓存
enabled: false
# Base cache directory (temporary files, intermediate downloads)
# 缓存根目录(临时文件、中间下载)
base_dir: "tmp/sora"
# Video cache directory (separated from images)
# 视频缓存目录(与图片分离)
video_dir: "data/video"
# Max bytes for cache dir (0 = unlimited)
# 缓存目录最大字节数(0 = 不限制)
max_bytes: 0
# Allowed hosts for cache download (empty -> fallback to global allowlist)
# 缓存下载白名单域名(为空则回退全局 allowlist)
allowed_hosts: []
# Enable user directory isolation (data/video/u_{user_id})
# 是否按用户隔离目录(data/video/u_{user_id})
user_dir_enabled: true
watermark_free:
# Enable watermark-free flow
# 是否启用去水印流程
enabled: false
# Parse method: third_party/custom
# 解析方式:third_party/custom
parse_method: "third_party"
# Custom parse server URL
# 自定义解析服务 URL
custom_parse_url: ""
# Custom parse token
# 自定义解析 token
custom_parse_token: ""
# Fallback to watermark video when parse fails
# 去水印失败时是否回退原视频
fallback_on_failure: true
token_refresh:
# Enable periodic token refresh
# 是否启用定时刷新
enabled: false
...@@ -194,6 +194,28 @@ GEMINI_OAUTH_SCOPES= ...@@ -194,6 +194,28 @@ GEMINI_OAUTH_SCOPES=
# GEMINI_QUOTA_POLICY={"tiers":{"LEGACY":{"pro_rpd":50,"flash_rpd":1500,"cooldown_minutes":30},"PRO":{"pro_rpd":1500,"flash_rpd":4000,"cooldown_minutes":5},"ULTRA":{"pro_rpd":2000,"flash_rpd":0,"cooldown_minutes":5}}} # GEMINI_QUOTA_POLICY={"tiers":{"LEGACY":{"pro_rpd":50,"flash_rpd":1500,"cooldown_minutes":30},"PRO":{"pro_rpd":1500,"flash_rpd":4000,"cooldown_minutes":5},"ULTRA":{"pro_rpd":2000,"flash_rpd":0,"cooldown_minutes":5}}}
GEMINI_QUOTA_POLICY= GEMINI_QUOTA_POLICY=
# -----------------------------------------------------------------------------
# Sora Configuration (OPTIONAL)
# -----------------------------------------------------------------------------
SORA_BASE_URL=https://sora.chatgpt.com/backend
SORA_TIMEOUT=120
SORA_MAX_RETRIES=3
SORA_POLL_INTERVAL=2.5
SORA_CALL_LOGIC_MODE=default
SORA_CACHE_ENABLED=false
SORA_CACHE_BASE_DIR=tmp/sora
SORA_CACHE_VIDEO_DIR=data/video
SORA_CACHE_MAX_BYTES=0
# Comma-separated hosts (leave empty to use global allowlist)
SORA_CACHE_ALLOWED_HOSTS=
SORA_CACHE_USER_DIR_ENABLED=true
SORA_WATERMARK_FREE_ENABLED=false
SORA_WATERMARK_FREE_PARSE_METHOD=third_party
SORA_WATERMARK_FREE_CUSTOM_PARSE_URL=
SORA_WATERMARK_FREE_CUSTOM_PARSE_TOKEN=
SORA_WATERMARK_FREE_FALLBACK_ON_FAILURE=true
SORA_TOKEN_REFRESH_ENABLED=false
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Ops Monitoring Configuration (运维监控配置) # Ops Monitoring Configuration (运维监控配置)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
......
...@@ -583,6 +583,66 @@ gemini: ...@@ -583,6 +583,66 @@ gemini:
# 达到配额后的冷却时间(分钟) # 达到配额后的冷却时间(分钟)
cooldown_minutes: 5 cooldown_minutes: 5
# =============================================================================
# Sora
# Sora 配置
# =============================================================================
sora:
# Sora Backend API base URL
# Sora 后端 API 基础地址
base_url: "https://sora.chatgpt.com/backend"
# Request timeout in seconds
# 请求超时时间(秒)
timeout: 120
# Max retry attempts for upstream requests
# 上游请求最大重试次数
max_retries: 3
# Poll interval in seconds for task status
# 任务状态轮询间隔(秒)
poll_interval: 2.5
# Call logic mode: default/native/proxy (default keeps current behavior)
# 调用模式:default/native/proxy(default 保持当前默认策略)
call_logic_mode: "default"
cache:
# Enable media caching
# 是否启用媒体缓存
enabled: false
# Base cache directory (temporary files, intermediate downloads)
# 缓存根目录(临时文件、中间下载)
base_dir: "tmp/sora"
# Video cache directory (separated from images)
# 视频缓存目录(与图片分离)
video_dir: "data/video"
# Max bytes for cache dir (0 = unlimited)
# 缓存目录最大字节数(0 = 不限制)
max_bytes: 0
# Allowed hosts for cache download (empty -> fallback to global allowlist)
# 缓存下载白名单域名(为空则回退全局 allowlist)
allowed_hosts: []
# Enable user directory isolation (data/video/u_{user_id})
# 是否按用户隔离目录(data/video/u_{user_id})
user_dir_enabled: true
watermark_free:
# Enable watermark-free flow
# 是否启用去水印流程
enabled: false
# Parse method: third_party/custom
# 解析方式:third_party/custom
parse_method: "third_party"
# Custom parse server URL
# 自定义解析服务 URL
custom_parse_url: ""
# Custom parse token
# 自定义解析 token
custom_parse_token: ""
# Fallback to watermark video when parse fails
# 去水印失败时是否回退原视频
fallback_on_failure: true
token_refresh:
# Enable periodic token refresh
# 是否启用定时刷新
enabled: false
# ============================================================================= # =============================================================================
# Update Configuration (在线更新配置) # Update Configuration (在线更新配置)
# ============================================================================= # =============================================================================
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment