Commit 399dd78b authored by yangjianbo's avatar yangjianbo
Browse files

feat(Sora): 直连生成并移除sora2api依赖

实现直连 Sora 客户端、媒体落地与清理策略\n更新网关与前端配置以支持 Sora 平台\n补齐单元测试与契约测试,新增 curl 测试脚本\n\n测试: go test ./... -tags=unit
parent 78d0ca37
......@@ -67,6 +67,7 @@ func provideCleanup(
opsAlertEvaluator *service.OpsAlertEvaluatorService,
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
soraMediaCleanup *service.SoraMediaCleanupService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
......@@ -100,6 +101,12 @@ func provideCleanup(
}
return nil
}},
{"SoraMediaCleanupService", func() error {
if soraMediaCleanup != nil {
soraMediaCleanup.Stop()
}
return nil
}},
{"OpsAlertEvaluatorService", func() error {
if opsAlertEvaluator != nil {
opsAlertEvaluator.Stop()
......
......@@ -87,12 +87,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
schedulerCache := repository.NewSchedulerCache(redisClient)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
soraAccountRepository := repository.NewSoraAccountRepository(db)
sora2APIService := service.NewSora2APIService(configConfig)
sora2APISyncService := service.NewSora2APISyncService(sora2APIService, accountRepository)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, sora2APISyncService, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
......@@ -164,11 +162,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
modelHandler := admin.NewModelHandler(sora2APIService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, modelHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, sora2APIService, concurrencyService, billingCacheService, configConfig)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
soraGatewayService := service.NewSoraGatewayService(sora2APIService, httpUpstream, rateLimitService, configConfig)
soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler)
......@@ -182,9 +181,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, sora2APISyncService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
......@@ -214,6 +214,7 @@ func provideCleanup(
opsAlertEvaluator *service.OpsAlertEvaluatorService,
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
soraMediaCleanup *service.SoraMediaCleanupService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
......@@ -246,6 +247,12 @@ func provideCleanup(
}
return nil
}},
{"SoraMediaCleanupService", func() error {
if soraMediaCleanup != nil {
soraMediaCleanup.Stop()
}
return nil
}},
{"OpsAlertEvaluatorService", func() error {
if opsAlertEvaluator != nil {
opsAlertEvaluator.Stop()
......
......@@ -58,7 +58,7 @@ type Config struct {
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
Sora2API Sora2APIConfig `mapstructure:"sora2api"`
Sora SoraConfig `mapstructure:"sora"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
......@@ -205,22 +205,40 @@ type ConcurrencyConfig struct {
PingInterval int `mapstructure:"ping_interval"`
}
// Sora2APIConfig Sora2API 服务配置
type Sora2APIConfig struct {
// BaseURL Sora2API 服务地址(例如 http://localhost:8000)
BaseURL string `mapstructure:"base_url"`
// APIKey Sora2API OpenAI 兼容接口的 API Key
APIKey string `mapstructure:"api_key"`
// AdminUsername 管理员用户名(用于 token 同步)
AdminUsername string `mapstructure:"admin_username"`
// AdminPassword 管理员密码(用于 token 同步)
AdminPassword string `mapstructure:"admin_password"`
// AdminTokenTTLSeconds 管理员 Token 缓存时长(秒)
AdminTokenTTLSeconds int `mapstructure:"admin_token_ttl_seconds"`
// AdminTimeoutSeconds 管理接口请求超时(秒)
AdminTimeoutSeconds int `mapstructure:"admin_timeout_seconds"`
// TokenImportMode token 导入模式:at/offline
TokenImportMode string `mapstructure:"token_import_mode"`
// SoraConfig 直连 Sora 配置
type SoraConfig struct {
Client SoraClientConfig `mapstructure:"client"`
Storage SoraStorageConfig `mapstructure:"storage"`
}
// SoraClientConfig 直连 Sora 客户端配置
type SoraClientConfig struct {
BaseURL string `mapstructure:"base_url"`
TimeoutSeconds int `mapstructure:"timeout_seconds"`
MaxRetries int `mapstructure:"max_retries"`
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
Debug bool `mapstructure:"debug"`
Headers map[string]string `mapstructure:"headers"`
UserAgent string `mapstructure:"user_agent"`
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
}
// SoraStorageConfig 媒体存储配置
type SoraStorageConfig struct {
Type string `mapstructure:"type"`
LocalPath string `mapstructure:"local_path"`
FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
Debug bool `mapstructure:"debug"`
Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
}
// SoraStorageCleanupConfig 媒体清理配置
type SoraStorageCleanupConfig struct {
Enabled bool `mapstructure:"enabled"`
Schedule string `mapstructure:"schedule"`
RetentionDays int `mapstructure:"retention_days"`
}
// GatewayConfig API网关相关配置
......@@ -905,6 +923,26 @@ func setDefaults() {
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
viper.SetDefault("concurrency.ping_interval", 10)
// Sora 直连配置
viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
viper.SetDefault("sora.client.timeout_seconds", 120)
viper.SetDefault("sora.client.max_retries", 3)
viper.SetDefault("sora.client.poll_interval_seconds", 2)
viper.SetDefault("sora.client.max_poll_attempts", 600)
viper.SetDefault("sora.client.debug", false)
viper.SetDefault("sora.client.headers", map[string]string{})
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
viper.SetDefault("sora.storage.type", "local")
viper.SetDefault("sora.storage.local_path", "")
viper.SetDefault("sora.storage.fallback_to_upstream", true)
viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
viper.SetDefault("sora.storage.debug", false)
viper.SetDefault("sora.storage.cleanup.enabled", true)
viper.SetDefault("sora.storage.cleanup.retention_days", 7)
viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *")
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
......@@ -920,15 +958,6 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
// Sora2API
viper.SetDefault("sora2api.base_url", "")
viper.SetDefault("sora2api.api_key", "")
viper.SetDefault("sora2api.admin_username", "")
viper.SetDefault("sora2api.admin_password", "")
viper.SetDefault("sora2api.admin_token_ttl_seconds", 900)
viper.SetDefault("sora2api.admin_timeout_seconds", 10)
viper.SetDefault("sora2api.token_import_mode", "at")
}
func (c *Config) Validate() error {
......@@ -1164,6 +1193,36 @@ func (c *Config) Validate() error {
return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error")
}
}
if c.Sora.Client.TimeoutSeconds < 0 {
return fmt.Errorf("sora.client.timeout_seconds must be non-negative")
}
if c.Sora.Client.MaxRetries < 0 {
return fmt.Errorf("sora.client.max_retries must be non-negative")
}
if c.Sora.Client.PollIntervalSeconds < 0 {
return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
}
if c.Sora.Client.MaxPollAttempts < 0 {
return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
}
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
}
if c.Sora.Storage.Cleanup.Enabled {
if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")
}
if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" {
return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled")
}
} else {
if c.Sora.Storage.Cleanup.RetentionDays < 0 {
return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative")
}
}
if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" {
return fmt.Errorf("sora.storage.type must be 'local'")
}
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
switch c.Gateway.ConnectionPoolIsolation {
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
......@@ -1260,11 +1319,6 @@ func (c *Config) Validate() error {
c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds {
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= outbox_lag_warn_seconds")
}
if strings.TrimSpace(c.Sora2API.BaseURL) != "" {
if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil {
return fmt.Errorf("sora2api.base_url invalid: %w", err)
}
}
if c.Ops.MetricsCollectorCache.TTL < 0 {
return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")
}
......
package admin
import (
"net/http"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ModelHandler handles admin model listing requests.
type ModelHandler struct {
sora2apiService *service.Sora2APIService
}
// NewModelHandler creates a new ModelHandler.
func NewModelHandler(sora2apiService *service.Sora2APIService) *ModelHandler {
return &ModelHandler{
sora2apiService: sora2apiService,
}
}
// List handles listing models for a specific platform
// GET /api/v1/admin/models?platform=sora
func (h *ModelHandler) List(c *gin.Context) {
platform := strings.TrimSpace(strings.ToLower(c.Query("platform")))
if platform == "" {
response.BadRequest(c, "platform is required")
return
}
switch platform {
case service.PlatformSora:
if h.sora2apiService == nil || !h.sora2apiService.Enabled() {
response.Error(c, http.StatusServiceUnavailable, "sora2api not configured")
return
}
models, err := h.sora2apiService.ListModels(c.Request.Context())
if err != nil {
response.Error(c, http.StatusServiceUnavailable, "failed to fetch sora models")
return
}
ids := make([]string, 0, len(models))
for _, m := range models {
if strings.TrimSpace(m.ID) != "" {
ids = append(ids, m.ID)
}
}
response.Success(c, ids)
default:
response.BadRequest(c, "unsupported platform")
}
}
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
func TestModelHandlerListSoraSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"object":"list","data":[{"id":"m1"},{"id":"m2"}]}`))
}))
t.Cleanup(upstream.Close)
cfg := &config.Config{}
cfg.Sora2API.BaseURL = upstream.URL
cfg.Sora2API.APIKey = "test-key"
soraService := service.NewSora2APIService(cfg)
h := NewModelHandler(soraService)
router := gin.New()
router.GET("/admin/models", h.List)
req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String())
}
var resp response.Response
if err := json.Unmarshal(recorder.Body.Bytes(), &resp); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if resp.Code != 0 {
t.Fatalf("响应 code=%d", resp.Code)
}
data, ok := resp.Data.([]any)
if !ok {
t.Fatalf("响应 data 类型错误")
}
if len(data) != 2 {
t.Fatalf("模型数量不符: %d", len(data))
}
}
func TestModelHandlerListSoraNotConfigured(t *testing.T) {
gin.SetMode(gin.TestMode)
h := NewModelHandler(&service.Sora2APIService{})
router := gin.New()
router.GET("/admin/models", h.List)
req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String())
}
}
func TestModelHandlerListInvalidPlatform(t *testing.T) {
gin.SetMode(gin.TestMode)
h := NewModelHandler(&service.Sora2APIService{})
router := gin.New()
router.GET("/admin/models", h.List)
req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=unknown", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusBadRequest {
t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String())
}
}
......@@ -29,11 +29,11 @@ type GatewayHandler struct {
geminiCompatService *service.GeminiMessagesCompatService
antigravityGatewayService *service.AntigravityGatewayService
userService *service.UserService
sora2apiService *service.Sora2APIService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
maxAccountSwitchesGemini int
cfg *config.Config
}
// NewGatewayHandler creates a new GatewayHandler
......@@ -42,7 +42,6 @@ func NewGatewayHandler(
geminiCompatService *service.GeminiMessagesCompatService,
antigravityGatewayService *service.AntigravityGatewayService,
userService *service.UserService,
sora2apiService *service.Sora2APIService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
cfg *config.Config,
......@@ -64,11 +63,11 @@ func NewGatewayHandler(
geminiCompatService: geminiCompatService,
antigravityGatewayService: antigravityGatewayService,
userService: userService,
sora2apiService: sora2apiService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
cfg: cfg,
}
}
......@@ -486,18 +485,9 @@ func (h *GatewayHandler) Models(c *gin.Context) {
}
if platform == service.PlatformSora {
if h.sora2apiService == nil || !h.sora2apiService.Enabled() {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "sora2api not configured")
return
}
models, err := h.sora2apiService.ListModels(c.Request.Context())
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Failed to fetch Sora models")
return
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": models,
"data": service.DefaultSoraModels(h.cfg),
})
return
}
......
......@@ -23,7 +23,6 @@ type AdminHandlers struct {
Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
Model *admin.ModelHandler
}
// Handlers contains all HTTP handlers
......
......@@ -10,7 +10,9 @@ import (
"io"
"log"
"net/http"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
......@@ -31,9 +33,8 @@ type SoraGatewayHandler struct {
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
streamMode string
sora2apiBaseURL string
soraMediaSigningKey string
mediaClient *http.Client
soraMediaRoot string
}
// NewSoraGatewayHandler creates a new SoraGatewayHandler
......@@ -48,6 +49,7 @@ func NewSoraGatewayHandler(
maxAccountSwitches := 3
streamMode := "force"
signKey := ""
mediaRoot := "/app/data/sora"
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
......@@ -57,14 +59,9 @@ func NewSoraGatewayHandler(
streamMode = mode
}
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
}
baseURL := ""
if cfg != nil {
baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/")
}
mediaTimeout := 180 * time.Second
if cfg != nil && cfg.Gateway.SoraRequestTimeoutSeconds > 0 {
mediaTimeout = time.Duration(cfg.Gateway.SoraRequestTimeoutSeconds) * time.Second
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
mediaRoot = root
}
}
return &SoraGatewayHandler{
gatewayService: gatewayService,
......@@ -73,9 +70,8 @@ func NewSoraGatewayHandler(
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
streamMode: strings.ToLower(streamMode),
sora2apiBaseURL: baseURL,
soraMediaSigningKey: signKey,
mediaClient: &http.Client{Timeout: mediaTimeout},
soraMediaRoot: mediaRoot,
}
}
......@@ -377,34 +373,24 @@ func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType,
})
}
// MediaProxy proxies /tmp or /static media files from sora2api
// MediaProxy serves local Sora media files.
func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) {
h.proxySoraMedia(c, false)
}
// MediaProxySigned proxies /tmp or /static media files with signature verification
// MediaProxySigned serves local Sora media files with signature verification.
func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) {
h.proxySoraMedia(c, true)
}
func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) {
if h.sora2apiBaseURL == "" {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "sora2api 未配置",
},
})
return
}
rawPath := c.Param("filepath")
if rawPath == "" {
c.Status(http.StatusNotFound)
return
}
cleaned := path.Clean(rawPath)
if !strings.HasPrefix(cleaned, "/tmp/") && !strings.HasPrefix(cleaned, "/static/") {
if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
c.Status(http.StatusNotFound)
return
}
......@@ -445,40 +431,25 @@ func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature boo
return
}
}
targetURL := h.sora2apiBaseURL + cleaned
if rawQuery := query.Encode(); rawQuery != "" {
targetURL += "?" + rawQuery
}
req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL, nil)
if err != nil {
c.Status(http.StatusBadGateway)
return
}
copyHeaders := []string{"Range", "If-Range", "If-Modified-Since", "If-None-Match", "Accept", "User-Agent"}
for _, key := range copyHeaders {
if val := c.GetHeader(key); val != "" {
req.Header.Set(key, val)
}
}
client := h.mediaClient
if client == nil {
client = http.DefaultClient
}
resp, err := client.Do(req)
if err != nil {
c.Status(http.StatusBadGateway)
if strings.TrimSpace(h.soraMediaRoot) == "" {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "Sora 媒体目录未配置",
},
})
return
}
defer func() { _ = resp.Body.Close() }()
for _, key := range []string{"Content-Type", "Content-Length", "Accept-Ranges", "Content-Range", "Cache-Control", "Last-Modified", "ETag"} {
if val := resp.Header.Get(key); val != "" {
c.Header(key, val)
relative := strings.TrimPrefix(cleaned, "/")
localPath := filepath.Join(h.soraMediaRoot, filepath.FromSlash(relative))
if _, err := os.Stat(localPath); err != nil {
if os.IsNotExist(err) {
c.Status(http.StatusNotFound)
return
}
c.Status(http.StatusInternalServerError)
return
}
c.Status(resp.StatusCode)
_, _ = io.Copy(c.Writer, resp.Body)
c.File(localPath)
}
//go:build unit
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type stubSoraClient struct {
imageURLs []string
}
func (s *stubSoraClient) Enabled() bool { return true }
func (s *stubSoraClient) UploadImage(ctx context.Context, account *service.Account, data []byte, filename string) (string, error) {
return "upload", nil
}
func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.Account, req service.SoraImageRequest) (string, error) {
return "task-image", nil
}
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
}
func (s *stubSoraClient) GetVideoTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraVideoTaskStatus, error) {
return &service.SoraVideoTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
}
type stubConcurrencyCache struct{}
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
}
func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
return nil
}
func (c stubConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (c stubConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
func (c stubConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
return nil
}
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (c stubConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
}
func (c stubConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
return nil
}
func (c stubConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (c stubConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return true, nil
}
func (c stubConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
return nil
}
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
result := make(map[int64]*service.AccountLoadInfo, len(accounts))
for _, acc := range accounts {
result[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
}
return result, nil
}
func (c stubConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
return nil
}
type stubAccountRepo struct {
accounts map[int64]*service.Account
}
func (r *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { return nil }
func (r *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) {
if acc, ok := r.accounts[id]; ok {
return acc, nil
}
return nil, service.ErrAccountNotFound
}
func (r *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
var result []*service.Account
for _, id := range ids {
if acc, ok := r.accounts[id]; ok {
result = append(result, acc)
}
}
return result, nil
}
func (r *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) {
_, ok := r.accounts[id]
return ok, nil
}
func (r *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return nil }
func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) { return nil, nil }
func (r *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (r *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { return nil }
func (r *stubAccountRepo) ClearError(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (r *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
return 0, nil
}
func (r *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}
func (r *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) {
return r.listSchedulable(), nil
}
func (r *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
return r.listSchedulable(), nil
}
func (r *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
return r.listSchedulableByPlatform(platform), nil
}
func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
var result []service.Account
for _, acc := range r.accounts {
for _, platform := range platforms {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, *acc)
break
}
}
}
return result, nil
}
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
return r.ListSchedulableByPlatforms(ctx, platforms)
}
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (r *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (r *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
func (r *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (r *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (r *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
func (r *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error { return nil }
func (r *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
func (r *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
return nil
}
func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
return 0, nil
}
func (r *stubAccountRepo) listSchedulable() []service.Account {
var result []service.Account
for _, acc := range r.accounts {
if acc.IsSchedulable() {
result = append(result, *acc)
}
}
return result
}
func (r *stubAccountRepo) listSchedulableByPlatform(platform string) []service.Account {
var result []service.Account
for _, acc := range r.accounts {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, *acc)
}
}
return result
}
type stubGroupRepo struct {
group *service.Group
}
func (r *stubGroupRepo) Create(ctx context.Context, group *service.Group) error { return nil }
func (r *stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) {
return r.group, nil
}
func (r *stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
return r.group, nil
}
func (r *stubGroupRepo) Update(ctx context.Context, group *service.Group) error { return nil }
func (r *stubGroupRepo) Delete(ctx context.Context, id int64) error { return nil }
func (r *stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
return nil, nil
}
func (r *stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { return nil, nil }
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return nil, nil
}
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
type stubUsageLogRepo struct{}
func (s *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
return true, nil
}
func (s *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
return nil, nil
}
func (s *stubUsageLogRepo) Delete(ctx context.Context, id int64) error { return nil }
func (s *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) {
return nil, nil
}
func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
RunMode: config.RunModeSimple,
Gateway: config.GatewayConfig{
SoraStreamMode: "force",
MaxAccountSwitches: 1,
Scheduling: config.GatewaySchedulingConfig{
LoadBatchEnabled: false,
},
},
Concurrency: config.ConcurrencyConfig{PingInterval: 0},
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.test",
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
account := &service.Account{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}
accountRepo := &stubAccountRepo{accounts: map[int64]*service.Account{account.ID: account}}
group := &service.Group{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Hydrated: true}
groupRepo := &stubGroupRepo{group: group}
usageLogRepo := &stubUsageLogRepo{}
deferredService := service.NewDeferredService(accountRepo, nil, 0)
billingService := service.NewBillingService(cfg, nil)
concurrencyService := service.NewConcurrencyService(stubConcurrencyCache{})
billingCacheService := service.NewBillingCacheService(nil, nil, nil, cfg)
t.Cleanup(func() {
billingCacheService.Stop()
})
gatewayService := service.NewGatewayService(
accountRepo,
groupRepo,
usageLogRepo,
nil,
nil,
nil,
cfg,
nil,
concurrencyService,
billingService,
nil,
billingCacheService,
nil,
nil,
deferredService,
nil,
nil,
)
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg)
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, cfg)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := `{"model":"gpt-image","messages":[{"role":"user","content":"hello"}]}`
c.Request = httptest.NewRequest(http.MethodPost, "/sora/v1/chat/completions", strings.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
apiKey := &service.APIKey{
ID: 1,
UserID: 1,
Status: service.StatusActive,
GroupID: &group.ID,
User: &service.User{ID: 1, Concurrency: 1, Status: service.StatusActive},
Group: group,
}
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: apiKey.User.Concurrency})
handler.ChatCompletions(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp["media_url"])
}
......@@ -26,7 +26,6 @@ func ProvideAdminHandlers(
subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler,
modelHandler *admin.ModelHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
......@@ -46,7 +45,6 @@ func ProvideAdminHandlers(
Subscription: subscriptionHandler,
Usage: usageHandler,
UserAttribute: userAttributeHandler,
Model: modelHandler,
}
}
......@@ -121,7 +119,6 @@ var ProviderSet = wire.NewSet(
admin.NewSubscriptionHandler,
admin.NewUsageHandler,
admin.NewUserAttributeHandler,
admin.NewModelHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
......
......@@ -178,6 +178,10 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"sora_image_price_360": null,
"sora_image_price_540": null,
"sora_video_price_per_request": null,
"sora_video_price_per_request_hd": null,
"claude_code_only": false,
"fallback_group_id": null,
"created_at": "2025-01-02T03:04:05Z",
......@@ -394,6 +398,7 @@ func TestAPIContracts(t *testing.T) {
"first_token_ms": 50,
"image_count": 0,
"image_size": null,
"media_type": null,
"created_at": "2025-01-02T03:04:05Z",
"user_agent": null
}
......@@ -887,6 +892,10 @@ func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID st
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error {
return errors.New("not implemented")
}
......
......@@ -64,9 +64,6 @@ func RegisterAdminRoutes(
// 用户属性管理
registerUserAttributeRoutes(admin, h)
// 模型列表
registerModelRoutes(admin, h)
}
}
......@@ -374,7 +371,3 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
}
}
func registerModelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
admin.GET("/models", h.Admin.Model.List)
}
......@@ -491,7 +491,7 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
return s.sendErrorAndEnd(c, "Failed to create request")
}
// 使用 Sora 客户端标准请求头(参考 sora2api)
// 使用 Sora 客户端标准请求头
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
req.Header.Set("Accept", "application/json")
......
......@@ -283,7 +283,6 @@ type adminServiceImpl struct {
groupRepo GroupRepository
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
soraSyncService *Sora2APISyncService // Sora2API 同步服务
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
......@@ -299,7 +298,6 @@ func NewAdminService(
groupRepo GroupRepository,
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository,
soraSyncService *Sora2APISyncService,
proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
......@@ -313,7 +311,6 @@ func NewAdminService(
groupRepo: groupRepo,
accountRepo: accountRepo,
soraAccountRepo: soraAccountRepo,
soraSyncService: soraSyncService,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
......@@ -917,9 +914,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
}
// 同步到 sora2api(异步,不阻塞创建)
s.syncSoraAccountAsync(account)
return account, nil
}
......@@ -1014,7 +1008,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if err != nil {
return nil, err
}
s.syncSoraAccountAsync(updated)
return updated, nil
}
......@@ -1032,17 +1025,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
needSoraSync := s != nil && s.soraSyncService != nil
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
platformByID := map[int64]string{}
if needMixedChannelCheck || needSoraSync {
if needMixedChannelCheck {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
if needMixedChannelCheck {
return nil, err
}
log.Printf("[AdminService] 预加载账号平台信息失败,将逐个降级同步: err=%v", err)
} else {
for _, account := range accounts {
if account != nil {
......@@ -1134,45 +1125,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
result.Success++
result.SuccessIDs = append(result.SuccessIDs, accountID)
result.Results = append(result.Results, entry)
// 批量更新后同步 sora2api
if needSoraSync {
platform := platformByID[accountID]
if platform == "" {
updated, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err)
continue
}
if updated.Platform == PlatformSora {
s.syncSoraAccountAsync(updated)
}
continue
}
if platform == PlatformSora {
updated, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err)
continue
}
s.syncSoraAccountAsync(updated)
}
}
}
return result, nil
}
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return err
}
if err := s.accountRepo.Delete(ctx, id); err != nil {
return err
}
s.deleteSoraAccountAsync(account)
return nil
}
......@@ -1210,44 +1171,9 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
if err != nil {
return nil, err
}
s.syncSoraAccountAsync(updated)
return updated, nil
}
func (s *adminServiceImpl) syncSoraAccountAsync(account *Account) {
if s == nil || s.soraSyncService == nil || account == nil {
return
}
if account.Platform != PlatformSora {
return
}
syncAccount := *account
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.soraSyncService.SyncAccount(ctx, &syncAccount); err != nil {
log.Printf("[AdminService] 同步 sora2api 失败: account_id=%d err=%v", syncAccount.ID, err)
}
}()
}
func (s *adminServiceImpl) deleteSoraAccountAsync(account *Account) {
if s == nil || s.soraSyncService == nil || account == nil {
return
}
if account.Platform != PlatformSora {
return
}
syncAccount := *account
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.soraSyncService.DeleteAccount(ctx, &syncAccount); err != nil {
log.Printf("[AdminService] 删除 sora2api token 失败: account_id=%d err=%v", syncAccount.ID, err)
}
}()
}
// Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
......
......@@ -105,31 +105,3 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
require.Len(t, result.Results, 3)
}
// TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs 验证无分组更新时仍会触发 Sora 同步。
func TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{
getByIDsAccounts: []*Account{
{ID: 1, Platform: PlatformSora},
},
getByIDAccounts: map[int64]*Account{
1: {ID: 1, Platform: PlatformSora},
},
}
svc := &adminServiceImpl{
accountRepo: repo,
soraSyncService: &Sora2APISyncService{},
}
schedulable := true
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1},
Schedulable: &schedulable,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.NoError(t, err)
require.Equal(t, 1, result.Success)
require.True(t, repo.getByIDsCalled)
require.ElementsMatch(t, []int64{1}, repo.getByIDCalled)
}
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// Sora2APIModel represents a model entry returned by sora2api.
type Sora2APIModel struct {
ID string `json:"id"`
Object string `json:"object"`
OwnedBy string `json:"owned_by,omitempty"`
Description string `json:"description,omitempty"`
}
// Sora2APIModelList represents /v1/models response.
type Sora2APIModelList struct {
Object string `json:"object"`
Data []Sora2APIModel `json:"data"`
}
// Sora2APIImportTokenItem mirrors sora2api ImportTokenItem.
type Sora2APIImportTokenItem struct {
Email string `json:"email"`
AccessToken string `json:"access_token,omitempty"`
SessionToken string `json:"session_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ClientID string `json:"client_id,omitempty"`
ProxyURL string `json:"proxy_url,omitempty"`
Remark string `json:"remark,omitempty"`
IsActive bool `json:"is_active"`
ImageEnabled bool `json:"image_enabled"`
VideoEnabled bool `json:"video_enabled"`
ImageConcurrency int `json:"image_concurrency"`
VideoConcurrency int `json:"video_concurrency"`
}
// Sora2APIToken represents minimal fields for admin list.
type Sora2APIToken struct {
ID int64 `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Remark string `json:"remark"`
}
// Sora2APIService provides access to sora2api endpoints.
type Sora2APIService struct {
cfg *config.Config
baseURL string
apiKey string
adminUsername string
adminPassword string
adminTokenTTL time.Duration
tokenImportMode string
client *http.Client
adminClient *http.Client
adminToken string
adminTokenAt time.Time
adminMu sync.Mutex
modelCache []Sora2APIModel
modelMu sync.RWMutex
}
func NewSora2APIService(cfg *config.Config) *Sora2APIService {
if cfg == nil {
return &Sora2APIService{}
}
adminTTL := time.Duration(cfg.Sora2API.AdminTokenTTLSeconds) * time.Second
if adminTTL <= 0 {
adminTTL = 15 * time.Minute
}
adminTimeout := time.Duration(cfg.Sora2API.AdminTimeoutSeconds) * time.Second
if adminTimeout <= 0 {
adminTimeout = 10 * time.Second
}
return &Sora2APIService{
cfg: cfg,
baseURL: strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/"),
apiKey: strings.TrimSpace(cfg.Sora2API.APIKey),
adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername),
adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword),
adminTokenTTL: adminTTL,
tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)),
client: &http.Client{},
adminClient: &http.Client{Timeout: adminTimeout},
}
}
func (s *Sora2APIService) Enabled() bool {
return s != nil && s.baseURL != "" && s.apiKey != ""
}
func (s *Sora2APIService) AdminEnabled() bool {
return s != nil && s.baseURL != "" && s.adminUsername != "" && s.adminPassword != ""
}
func (s *Sora2APIService) buildURL(path string) string {
if s.baseURL == "" {
return path
}
if strings.HasPrefix(path, "/") {
return s.baseURL + path
}
return s.baseURL + "/" + path
}
// BuildURL 返回完整的 sora2api URL(用于代理媒体)
func (s *Sora2APIService) BuildURL(path string) string {
return s.buildURL(path)
}
func (s *Sora2APIService) NewAPIRequest(ctx context.Context, method string, path string, body []byte) (*http.Request, error) {
if !s.Enabled() {
return nil, errors.New("sora2api not configured")
}
req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+s.apiKey)
req.Header.Set("Content-Type", "application/json")
return req, nil
}
func (s *Sora2APIService) ListModels(ctx context.Context) ([]Sora2APIModel, error) {
if !s.Enabled() {
return nil, errors.New("sora2api not configured")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.buildURL("/v1/models"), nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+s.apiKey)
resp, err := s.client.Do(req)
if err != nil {
return s.cachedModelsOnError(err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return s.cachedModelsOnError(fmt.Errorf("sora2api models status: %d", resp.StatusCode))
}
var payload Sora2APIModelList
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return s.cachedModelsOnError(err)
}
models := payload.Data
if s.cfg != nil && s.cfg.Gateway.SoraModelFilters.HidePromptEnhance {
filtered := make([]Sora2APIModel, 0, len(models))
for _, m := range models {
if strings.HasPrefix(strings.ToLower(m.ID), "prompt-enhance") {
continue
}
filtered = append(filtered, m)
}
models = filtered
}
s.modelMu.Lock()
s.modelCache = models
s.modelMu.Unlock()
return models, nil
}
func (s *Sora2APIService) cachedModelsOnError(err error) ([]Sora2APIModel, error) {
s.modelMu.RLock()
cached := append([]Sora2APIModel(nil), s.modelCache...)
s.modelMu.RUnlock()
if len(cached) > 0 {
log.Printf("[Sora2API] 模型列表拉取失败,回退缓存: %v", err)
return cached, nil
}
return nil, err
}
func (s *Sora2APIService) ImportTokens(ctx context.Context, items []Sora2APIImportTokenItem) error {
if !s.AdminEnabled() {
return errors.New("sora2api admin not configured")
}
mode := s.tokenImportMode
if mode == "" {
mode = "at"
}
payload := map[string]any{
"tokens": items,
"mode": mode,
}
_, err := s.doAdminRequest(ctx, http.MethodPost, "/api/tokens/import", payload, nil)
return err
}
func (s *Sora2APIService) ListTokens(ctx context.Context) ([]Sora2APIToken, error) {
if !s.AdminEnabled() {
return nil, errors.New("sora2api admin not configured")
}
var tokens []Sora2APIToken
_, err := s.doAdminRequest(ctx, http.MethodGet, "/api/tokens", nil, &tokens)
return tokens, err
}
func (s *Sora2APIService) DisableToken(ctx context.Context, tokenID int64) error {
if !s.AdminEnabled() {
return errors.New("sora2api admin not configured")
}
path := fmt.Sprintf("/api/tokens/%d/disable", tokenID)
_, err := s.doAdminRequest(ctx, http.MethodPost, path, nil, nil)
return err
}
func (s *Sora2APIService) DeleteToken(ctx context.Context, tokenID int64) error {
if !s.AdminEnabled() {
return errors.New("sora2api admin not configured")
}
path := fmt.Sprintf("/api/tokens/%d", tokenID)
_, err := s.doAdminRequest(ctx, http.MethodDelete, path, nil, nil)
return err
}
func (s *Sora2APIService) doAdminRequest(ctx context.Context, method string, path string, body any, out any) (*http.Response, error) {
if !s.AdminEnabled() {
return nil, errors.New("sora2api admin not configured")
}
token, err := s.getAdminToken(ctx)
if err != nil {
return nil, err
}
resp, err := s.doAdminRequestWithToken(ctx, method, path, token, body, out)
if err == nil && resp != nil && resp.StatusCode != http.StatusUnauthorized {
return resp, nil
}
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
s.invalidateAdminToken()
token, err = s.getAdminToken(ctx)
if err != nil {
return resp, err
}
return s.doAdminRequestWithToken(ctx, method, path, token, body, out)
}
return resp, err
}
func (s *Sora2APIService) doAdminRequestWithToken(ctx context.Context, method string, path string, token string, body any, out any) (*http.Response, error) {
var reader *bytes.Reader
if body != nil {
buf, err := json.Marshal(body)
if err != nil {
return nil, err
}
reader = bytes.NewReader(buf)
} else {
reader = bytes.NewReader(nil)
}
req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), reader)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token)
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := s.adminClient.Do(req)
if err != nil {
return resp, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return resp, fmt.Errorf("sora2api admin status: %d", resp.StatusCode)
}
if out != nil {
if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
return resp, err
}
}
return resp, nil
}
func (s *Sora2APIService) getAdminToken(ctx context.Context) (string, error) {
s.adminMu.Lock()
defer s.adminMu.Unlock()
if s.adminToken != "" && time.Since(s.adminTokenAt) < s.adminTokenTTL {
return s.adminToken, nil
}
if !s.AdminEnabled() {
return "", errors.New("sora2api admin not configured")
}
payload := map[string]string{
"username": s.adminUsername,
"password": s.adminPassword,
}
buf, err := json.Marshal(payload)
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.buildURL("/api/login"), bytes.NewReader(buf))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.adminClient.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("sora2api login failed: %d", resp.StatusCode)
}
var result struct {
Success bool `json:"success"`
Token string `json:"token"`
Message string `json:"message"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", err
}
if !result.Success || result.Token == "" {
if result.Message == "" {
result.Message = "sora2api login failed"
}
return "", errors.New(result.Message)
}
s.adminToken = result.Token
s.adminTokenAt = time.Now()
return result.Token, nil
}
func (s *Sora2APIService) invalidateAdminToken() {
s.adminMu.Lock()
defer s.adminMu.Unlock()
s.adminToken = ""
s.adminTokenAt = time.Time{}
}
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
)
// Sora2APISyncService 用于同步 Sora 账号到 sora2api token 池
type Sora2APISyncService struct {
sora2api *Sora2APIService
accountRepo AccountRepository
httpClient *http.Client
}
func NewSora2APISyncService(sora2api *Sora2APIService, accountRepo AccountRepository) *Sora2APISyncService {
return &Sora2APISyncService{
sora2api: sora2api,
accountRepo: accountRepo,
httpClient: &http.Client{Timeout: 10 * time.Second},
}
}
func (s *Sora2APISyncService) Enabled() bool {
return s != nil && s.sora2api != nil && s.sora2api.AdminEnabled()
}
// SyncAccount 将 Sora 账号同步到 sora2api(导入或更新)
func (s *Sora2APISyncService) SyncAccount(ctx context.Context, account *Account) error {
if !s.Enabled() {
return nil
}
if account == nil || account.Platform != PlatformSora {
return nil
}
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
if accessToken == "" {
return errors.New("sora 账号缺少 access_token")
}
email, updated := s.resolveAccountEmail(ctx, account)
if email == "" {
return errors.New("无法解析 Sora 账号邮箱")
}
if updated && s.accountRepo != nil {
if err := s.accountRepo.Update(ctx, account); err != nil {
log.Printf("[SoraSync] 更新账号邮箱失败: account_id=%d err=%v", account.ID, err)
}
}
item := Sora2APIImportTokenItem{
Email: email,
AccessToken: accessToken,
SessionToken: strings.TrimSpace(account.GetCredential("session_token")),
RefreshToken: strings.TrimSpace(account.GetCredential("refresh_token")),
ClientID: strings.TrimSpace(account.GetCredential("client_id")),
Remark: account.Name,
IsActive: account.IsActive() && account.Schedulable,
ImageEnabled: true,
VideoEnabled: true,
ImageConcurrency: normalizeSoraConcurrency(account.Concurrency),
VideoConcurrency: normalizeSoraConcurrency(account.Concurrency),
}
if err := s.sora2api.ImportTokens(ctx, []Sora2APIImportTokenItem{item}); err != nil {
return err
}
return nil
}
// DisableAccount 禁用 sora2api 中的 token
func (s *Sora2APISyncService) DisableAccount(ctx context.Context, account *Account) error {
if !s.Enabled() {
return nil
}
if account == nil || account.Platform != PlatformSora {
return nil
}
tokenID, err := s.resolveTokenID(ctx, account)
if err != nil {
return err
}
return s.sora2api.DisableToken(ctx, tokenID)
}
// DeleteAccount 删除 sora2api 中的 token
func (s *Sora2APISyncService) DeleteAccount(ctx context.Context, account *Account) error {
if !s.Enabled() {
return nil
}
if account == nil || account.Platform != PlatformSora {
return nil
}
tokenID, err := s.resolveTokenID(ctx, account)
if err != nil {
return err
}
return s.sora2api.DeleteToken(ctx, tokenID)
}
func normalizeSoraConcurrency(value int) int {
if value <= 0 {
return -1
}
return value
}
func (s *Sora2APISyncService) resolveAccountEmail(ctx context.Context, account *Account) (string, bool) {
if account == nil {
return "", false
}
if email := strings.TrimSpace(account.GetCredential("email")); email != "" {
return email, false
}
if email := strings.TrimSpace(account.GetExtraString("email")); email != "" {
if account.Credentials == nil {
account.Credentials = map[string]any{}
}
account.Credentials["email"] = email
return email, true
}
if email := strings.TrimSpace(account.GetExtraString("sora_email")); email != "" {
if account.Credentials == nil {
account.Credentials = map[string]any{}
}
account.Credentials["email"] = email
return email, true
}
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
if accessToken != "" {
if email := extractEmailFromAccessToken(accessToken); email != "" {
if account.Credentials == nil {
account.Credentials = map[string]any{}
}
account.Credentials["email"] = email
return email, true
}
if email := s.fetchEmailFromSora(ctx, accessToken); email != "" {
if account.Credentials == nil {
account.Credentials = map[string]any{}
}
account.Credentials["email"] = email
return email, true
}
}
return "", false
}
func (s *Sora2APISyncService) resolveTokenID(ctx context.Context, account *Account) (int64, error) {
if account == nil {
return 0, errors.New("account is nil")
}
if account.Extra != nil {
if v, ok := account.Extra["sora2api_token_id"]; ok {
if id, ok := v.(float64); ok && id > 0 {
return int64(id), nil
}
if id, ok := v.(int64); ok && id > 0 {
return id, nil
}
if id, ok := v.(int); ok && id > 0 {
return int64(id), nil
}
}
}
email := strings.TrimSpace(account.GetCredential("email"))
if email == "" {
email, _ = s.resolveAccountEmail(ctx, account)
}
if email == "" {
return 0, errors.New("sora2api token email missing")
}
tokenID, err := s.findTokenIDByEmail(ctx, email)
if err != nil {
return 0, err
}
return tokenID, nil
}
func (s *Sora2APISyncService) findTokenIDByEmail(ctx context.Context, email string) (int64, error) {
if !s.Enabled() {
return 0, errors.New("sora2api admin not configured")
}
tokens, err := s.sora2api.ListTokens(ctx)
if err != nil {
return 0, err
}
for _, token := range tokens {
if strings.EqualFold(strings.TrimSpace(token.Email), strings.TrimSpace(email)) {
return token.ID, nil
}
}
return 0, fmt.Errorf("sora2api token not found for email: %s", email)
}
func extractEmailFromAccessToken(accessToken string) string {
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
claims := jwt.MapClaims{}
_, _, err := parser.ParseUnverified(accessToken, claims)
if err != nil {
return ""
}
if email, ok := claims["email"].(string); ok && strings.TrimSpace(email) != "" {
return email
}
if profile, ok := claims["https://api.openai.com/profile"].(map[string]any); ok {
if email, ok := profile["email"].(string); ok && strings.TrimSpace(email) != "" {
return email
}
}
return ""
}
func (s *Sora2APISyncService) fetchEmailFromSora(ctx context.Context, accessToken string) string {
if s.httpClient == nil {
return ""
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, soraMeAPIURL, nil)
if err != nil {
return ""
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
req.Header.Set("Accept", "application/json")
resp, err := s.httpClient.Do(req)
if err != nil {
return ""
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return ""
}
var payload map[string]any
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return ""
}
if email, ok := payload["email"].(string); ok && strings.TrimSpace(email) != "" {
return email
}
return ""
}
This diff is collapsed.
//go:build unit
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestSoraDirectClient_DoRequestSuccess(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{BaseURL: server.URL},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
body, _, err := client.doRequest(context.Background(), &Account{ID: 1}, http.MethodGet, server.URL, http.Header{}, nil, false)
require.NoError(t, err)
require.Contains(t, string(body), "ok")
}
func TestSoraDirectClient_BuildBaseHeaders(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
Headers: map[string]string{
"X-Test": "yes",
"Authorization": "should-ignore",
"openai-sentinel-token": "skip",
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
headers := client.buildBaseHeaders("token-123", "UA")
require.Equal(t, "Bearer token-123", headers.Get("Authorization"))
require.Equal(t, "UA", headers.Get("User-Agent"))
require.Equal(t, "yes", headers.Get("X-Test"))
require.Empty(t, headers.Get("openai-sentinel-token"))
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment