Commit 618a614c authored by yangjianbo's avatar yangjianbo
Browse files

feat(Sora): 完成Sora网关接入与媒体能力

新增 Sora 网关路由、账号调度与同步服务\n补充媒体代理与签名 URL、模型列表动态拉取\n完善计费配置、前端支持与相关测试
parent 99dc3b59
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
func TestModelHandlerListSoraSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"object":"list","data":[{"id":"m1"},{"id":"m2"}]}`))
}))
t.Cleanup(upstream.Close)
cfg := &config.Config{}
cfg.Sora2API.BaseURL = upstream.URL
cfg.Sora2API.APIKey = "test-key"
soraService := service.NewSora2APIService(cfg)
h := NewModelHandler(soraService)
router := gin.New()
router.GET("/admin/models", h.List)
req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String())
}
var resp response.Response
if err := json.Unmarshal(recorder.Body.Bytes(), &resp); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if resp.Code != 0 {
t.Fatalf("响应 code=%d", resp.Code)
}
data, ok := resp.Data.([]any)
if !ok {
t.Fatalf("响应 data 类型错误")
}
if len(data) != 2 {
t.Fatalf("模型数量不符: %d", len(data))
}
}
func TestModelHandlerListSoraNotConfigured(t *testing.T) {
gin.SetMode(gin.TestMode)
h := NewModelHandler(&service.Sora2APIService{})
router := gin.New()
router.GET("/admin/models", h.List)
req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String())
}
}
func TestModelHandlerListInvalidPlatform(t *testing.T) {
gin.SetMode(gin.TestMode)
h := NewModelHandler(&service.Sora2APIService{})
router := gin.New()
router.GET("/admin/models", h.List)
req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=unknown", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusBadRequest {
t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String())
}
}
...@@ -136,6 +136,10 @@ func groupFromServiceBase(g *service.Group) Group { ...@@ -136,6 +136,10 @@ func groupFromServiceBase(g *service.Group) Group {
ImagePrice1K: g.ImagePrice1K, ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K, ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K, ImagePrice4K: g.ImagePrice4K,
SoraImagePrice360: g.SoraImagePrice360,
SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
...@@ -379,6 +383,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { ...@@ -379,6 +383,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
FirstTokenMs: l.FirstTokenMs, FirstTokenMs: l.FirstTokenMs,
ImageCount: l.ImageCount, ImageCount: l.ImageCount,
ImageSize: l.ImageSize, ImageSize: l.ImageSize,
MediaType: l.MediaType,
UserAgent: l.UserAgent, UserAgent: l.UserAgent,
CreatedAt: l.CreatedAt, CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User), User: UserFromServiceShallow(l.User),
......
...@@ -61,6 +61,12 @@ type Group struct { ...@@ -61,6 +61,12 @@ type Group struct {
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
// Sora 按次计费配置
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
// Claude Code 客户端限制 // Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
...@@ -246,6 +252,7 @@ type UsageLog struct { ...@@ -246,6 +252,7 @@ type UsageLog struct {
// 图片生成字段 // 图片生成字段
ImageCount int `json:"image_count"` ImageCount int `json:"image_count"`
ImageSize *string `json:"image_size"` ImageSize *string `json:"image_size"`
MediaType *string `json:"media_type"`
// User-Agent // User-Agent
UserAgent *string `json:"user_agent"` UserAgent *string `json:"user_agent"`
......
...@@ -29,6 +29,7 @@ type GatewayHandler struct { ...@@ -29,6 +29,7 @@ type GatewayHandler struct {
geminiCompatService *service.GeminiMessagesCompatService geminiCompatService *service.GeminiMessagesCompatService
antigravityGatewayService *service.AntigravityGatewayService antigravityGatewayService *service.AntigravityGatewayService
userService *service.UserService userService *service.UserService
sora2apiService *service.Sora2APIService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
...@@ -41,6 +42,7 @@ func NewGatewayHandler( ...@@ -41,6 +42,7 @@ func NewGatewayHandler(
geminiCompatService *service.GeminiMessagesCompatService, geminiCompatService *service.GeminiMessagesCompatService,
antigravityGatewayService *service.AntigravityGatewayService, antigravityGatewayService *service.AntigravityGatewayService,
userService *service.UserService, userService *service.UserService,
sora2apiService *service.Sora2APIService,
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
cfg *config.Config, cfg *config.Config,
...@@ -62,6 +64,7 @@ func NewGatewayHandler( ...@@ -62,6 +64,7 @@ func NewGatewayHandler(
geminiCompatService: geminiCompatService, geminiCompatService: geminiCompatService,
antigravityGatewayService: antigravityGatewayService, antigravityGatewayService: antigravityGatewayService,
userService: userService, userService: userService,
sora2apiService: sora2apiService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
...@@ -478,6 +481,26 @@ func (h *GatewayHandler) Models(c *gin.Context) { ...@@ -478,6 +481,26 @@ func (h *GatewayHandler) Models(c *gin.Context) {
groupID = &apiKey.Group.ID groupID = &apiKey.Group.ID
platform = apiKey.Group.Platform platform = apiKey.Group.Platform
} }
if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok && strings.TrimSpace(forcedPlatform) != "" {
platform = forcedPlatform
}
if platform == service.PlatformSora {
if h.sora2apiService == nil || !h.sora2apiService.Enabled() {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "sora2api not configured")
return
}
models, err := h.sora2apiService.ListModels(c.Request.Context())
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Failed to fetch Sora models")
return
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": models,
})
return
}
// Get available models from account configurations (without platform filter) // Get available models from account configurations (without platform filter)
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "") availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
......
...@@ -23,6 +23,7 @@ type AdminHandlers struct { ...@@ -23,6 +23,7 @@ type AdminHandlers struct {
Subscription *admin.SubscriptionHandler Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler UserAttribute *admin.UserAttributeHandler
Model *admin.ModelHandler
} }
// Handlers contains all HTTP handlers // Handlers contains all HTTP handlers
...@@ -36,6 +37,7 @@ type Handlers struct { ...@@ -36,6 +37,7 @@ type Handlers struct {
Admin *AdminHandlers Admin *AdminHandlers
Gateway *GatewayHandler Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler OpenAIGateway *OpenAIGatewayHandler
SoraGateway *SoraGatewayHandler
Setting *SettingHandler Setting *SettingHandler
} }
......
package handler
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"path"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SoraGatewayHandler handles Sora chat completions requests
type SoraGatewayHandler struct {
gatewayService *service.GatewayService
soraGatewayService *service.SoraGatewayService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
streamMode string
sora2apiBaseURL string
soraMediaSigningKey string
}
// NewSoraGatewayHandler creates a new SoraGatewayHandler
func NewSoraGatewayHandler(
gatewayService *service.GatewayService,
soraGatewayService *service.SoraGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
cfg *config.Config,
) *SoraGatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 3
streamMode := "force"
signKey := ""
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
if cfg.Gateway.MaxAccountSwitches > 0 {
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
}
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
streamMode = mode
}
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
}
baseURL := ""
if cfg != nil {
baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/")
}
return &SoraGatewayHandler{
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
streamMode: strings.ToLower(streamMode),
sora2apiBaseURL: baseURL,
soraMediaSigningKey: signKey,
}
}
// ChatCompletions handles Sora /v1/chat/completions endpoint
func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
setOpsRequestContext(c, "", false, body)
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
reqModel, _ := reqBody["model"].(string)
if reqModel == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqMessages, _ := reqBody["messages"].([]any)
if len(reqMessages) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
return
}
clientStream, _ := reqBody["stream"].(bool)
if !clientStream {
if h.streamMode == "error" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
return
}
reqBody["stream"] = true
updated, err := json.Marshal(reqBody)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
body = updated
}
setOpsRequestContext(c, reqModel, clientStream, body)
platform := ""
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
platform = forced
} else if apiKey.Group != nil {
platform = apiKey.Group.Platform
}
if platform != service.PlatformSora {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform")
return
}
streamStarted := false
subscription, _ := middleware2.GetSubscriptionFromContext(c)
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
} else if !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
sessionHash := generateOpenAISessionHash(c, reqBody)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
if err != nil {
log.Printf("[Sora Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}()
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
clientStream,
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
return
}
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent, clientIP)
return
}
}
func generateOpenAISessionHash(c *gin.Context, reqBody map[string]any) string {
if c == nil {
return ""
}
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && reqBody != nil {
if v, ok := reqBody["prompt_cache_key"].(string); ok {
sessionID = strings.TrimSpace(v)
}
}
if sessionID == "" {
return ""
}
hash := sha256.Sum256([]byte(sessionID))
return hex.EncodeToString(hash[:])
}
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
switch statusCode {
case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
}
}
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
flusher, ok := c.Writer.(http.Flusher)
if ok {
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush()
}
return
}
h.errorResponse(c, status, errType, message)
}
func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}
// MediaProxy proxies /tmp or /static media files from sora2api
func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) {
h.proxySoraMedia(c, false)
}
// MediaProxySigned proxies /tmp or /static media files with signature verification
func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) {
h.proxySoraMedia(c, true)
}
func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) {
if h.sora2apiBaseURL == "" {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "sora2api 未配置",
},
})
return
}
rawPath := c.Param("filepath")
if rawPath == "" {
c.Status(http.StatusNotFound)
return
}
cleaned := path.Clean(rawPath)
if !strings.HasPrefix(cleaned, "/tmp/") && !strings.HasPrefix(cleaned, "/static/") {
c.Status(http.StatusNotFound)
return
}
query := c.Request.URL.Query()
if requireSignature {
if h.soraMediaSigningKey == "" {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "Sora 媒体签名未配置",
},
})
return
}
expiresStr := strings.TrimSpace(query.Get("expires"))
signature := strings.TrimSpace(query.Get("sig"))
expires, err := strconv.ParseInt(expiresStr, 10, 64)
if err != nil || expires <= time.Now().Unix() {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"type": "authentication_error",
"message": "Sora 媒体签名已过期",
},
})
return
}
query.Del("sig")
query.Del("expires")
signingQuery := query.Encode()
if !service.VerifySoraMediaURL(cleaned, signingQuery, expires, signature, h.soraMediaSigningKey) {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"type": "authentication_error",
"message": "Sora 媒体签名无效",
},
})
return
}
}
targetURL := h.sora2apiBaseURL + cleaned
if rawQuery := query.Encode(); rawQuery != "" {
targetURL += "?" + rawQuery
}
req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL, nil)
if err != nil {
c.Status(http.StatusBadGateway)
return
}
copyHeaders := []string{"Range", "If-Range", "If-Modified-Since", "If-None-Match", "Accept", "User-Agent"}
for _, key := range copyHeaders {
if val := c.GetHeader(key); val != "" {
req.Header.Set(key, val)
}
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
c.Status(http.StatusBadGateway)
return
}
defer func() { _ = resp.Body.Close() }()
for _, key := range []string{"Content-Type", "Content-Length", "Accept-Ranges", "Content-Range", "Cache-Control", "Last-Modified", "ETag"} {
if val := resp.Header.Get(key); val != "" {
c.Header(key, val)
}
}
c.Status(resp.StatusCode)
_, _ = io.Copy(c.Writer, resp.Body)
}
...@@ -26,6 +26,7 @@ func ProvideAdminHandlers( ...@@ -26,6 +26,7 @@ func ProvideAdminHandlers(
subscriptionHandler *admin.SubscriptionHandler, subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler, usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler, userAttributeHandler *admin.UserAttributeHandler,
modelHandler *admin.ModelHandler,
) *AdminHandlers { ) *AdminHandlers {
return &AdminHandlers{ return &AdminHandlers{
Dashboard: dashboardHandler, Dashboard: dashboardHandler,
...@@ -45,6 +46,7 @@ func ProvideAdminHandlers( ...@@ -45,6 +46,7 @@ func ProvideAdminHandlers(
Subscription: subscriptionHandler, Subscription: subscriptionHandler,
Usage: usageHandler, Usage: usageHandler,
UserAttribute: userAttributeHandler, UserAttribute: userAttributeHandler,
Model: modelHandler,
} }
} }
...@@ -69,6 +71,7 @@ func ProvideHandlers( ...@@ -69,6 +71,7 @@ func ProvideHandlers(
adminHandlers *AdminHandlers, adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler, gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler,
soraGatewayHandler *SoraGatewayHandler,
settingHandler *SettingHandler, settingHandler *SettingHandler,
) *Handlers { ) *Handlers {
return &Handlers{ return &Handlers{
...@@ -81,6 +84,7 @@ func ProvideHandlers( ...@@ -81,6 +84,7 @@ func ProvideHandlers(
Admin: adminHandlers, Admin: adminHandlers,
Gateway: gatewayHandler, Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler, OpenAIGateway: openaiGatewayHandler,
SoraGateway: soraGatewayHandler,
Setting: settingHandler, Setting: settingHandler,
} }
} }
...@@ -96,6 +100,7 @@ var ProviderSet = wire.NewSet( ...@@ -96,6 +100,7 @@ var ProviderSet = wire.NewSet(
NewSubscriptionHandler, NewSubscriptionHandler,
NewGatewayHandler, NewGatewayHandler,
NewOpenAIGatewayHandler, NewOpenAIGatewayHandler,
NewSoraGatewayHandler,
ProvideSettingHandler, ProvideSettingHandler,
// Admin handlers // Admin handlers
...@@ -116,6 +121,7 @@ var ProviderSet = wire.NewSet( ...@@ -116,6 +121,7 @@ var ProviderSet = wire.NewSet(
admin.NewSubscriptionHandler, admin.NewSubscriptionHandler,
admin.NewUsageHandler, admin.NewUsageHandler,
admin.NewUserAttributeHandler, admin.NewUserAttributeHandler,
admin.NewModelHandler,
// AdminHandlers and Handlers constructors // AdminHandlers and Handlers constructors
ProvideAdminHandlers, ProvideAdminHandlers,
......
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"os"
"strings" "strings"
"testing" "testing"
"time" "time"
...@@ -38,9 +39,7 @@ type TLSInfo struct { ...@@ -38,9 +39,7 @@ type TLSInfo struct {
// TestDialerBasicConnection tests that the dialer can establish TLS connections. // TestDialerBasicConnection tests that the dialer can establish TLS connections.
func TestDialerBasicConnection(t *testing.T) { func TestDialerBasicConnection(t *testing.T) {
if testing.Short() { skipNetworkTest(t)
t.Skip("skipping network test in short mode")
}
// Create a dialer with default profile // Create a dialer with default profile
profile := &Profile{ profile := &Profile{
...@@ -74,10 +73,7 @@ func TestDialerBasicConnection(t *testing.T) { ...@@ -74,10 +73,7 @@ func TestDialerBasicConnection(t *testing.T) {
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x) // Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP) // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) { func TestJA3Fingerprint(t *testing.T) {
// Skip if network is unavailable or if running in short mode skipNetworkTest(t)
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
profile := &Profile{ profile := &Profile{
Name: "Claude CLI Test", Name: "Claude CLI Test",
...@@ -178,6 +174,15 @@ func TestJA3Fingerprint(t *testing.T) { ...@@ -178,6 +174,15 @@ func TestJA3Fingerprint(t *testing.T) {
} }
} }
func skipNetworkTest(t *testing.T) {
if testing.Short() {
t.Skip("跳过网络测试(short 模式)")
}
if os.Getenv("TLSFINGERPRINT_NETWORK_TESTS") != "1" {
t.Skip("跳过网络测试(需要设置 TLSFINGERPRINT_NETWORK_TESTS=1)")
}
}
// TestDialerWithProfile tests that different profiles produce different fingerprints. // TestDialerWithProfile tests that different profiles produce different fingerprints.
func TestDialerWithProfile(t *testing.T) { func TestDialerWithProfile(t *testing.T) {
// Create two dialers with different profiles // Create two dialers with different profiles
...@@ -317,9 +322,7 @@ type TestProfileExpectation struct { ...@@ -317,9 +322,7 @@ type TestProfileExpectation struct {
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws. // TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/... // Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) { func TestAllProfiles(t *testing.T) {
if testing.Short() { skipNetworkTest(t)
t.Skip("skipping integration test in short mode")
}
// Define all profiles to test with their expected fingerprints // Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles // These profiles are from config.yaml gateway.tls_fingerprint.profiles
......
...@@ -134,6 +134,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se ...@@ -134,6 +134,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice1k, group.FieldImagePrice1k,
group.FieldImagePrice2k, group.FieldImagePrice2k,
group.FieldImagePrice4k, group.FieldImagePrice4k,
group.FieldSoraImagePrice360,
group.FieldSoraImagePrice540,
group.FieldSoraVideoPricePerRequest,
group.FieldSoraVideoPricePerRequestHd,
group.FieldClaudeCodeOnly, group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID, group.FieldFallbackGroupID,
group.FieldModelRoutingEnabled, group.FieldModelRoutingEnabled,
...@@ -421,6 +425,10 @@ func groupEntityToService(g *dbent.Group) *service.Group { ...@@ -421,6 +425,10 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice1K: g.ImagePrice1k, ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k, ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k, ImagePrice4K: g.ImagePrice4k,
SoraImagePrice360: g.SoraImagePrice360,
SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
DefaultValidityDays: g.DefaultValidityDays, DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,
......
...@@ -47,6 +47,10 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -47,6 +47,10 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID). SetNillableFallbackGroupID(groupIn.FallbackGroupID).
...@@ -106,6 +110,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -106,6 +110,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K). SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
......
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, created_at"
type usageLogRepository struct { type usageLogRepository struct {
client *dbent.Client client *dbent.Client
...@@ -114,6 +114,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -114,6 +114,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
ip_address, ip_address,
image_count, image_count,
image_size, image_size,
media_type,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $1, $2, $3, $4, $5,
...@@ -121,7 +122,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -121,7 +122,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11, $8, $9, $10, $11,
$12, $13, $12, $13,
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30 $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -134,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -134,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
userAgent := nullString(log.UserAgent) userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress) ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize) imageSize := nullString(log.ImageSize)
mediaType := nullString(log.MediaType)
var requestIDArg any var requestIDArg any
if requestID != "" { if requestID != "" {
...@@ -170,6 +172,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -170,6 +172,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
ipAddress, ipAddress,
log.ImageCount, log.ImageCount,
imageSize, imageSize,
mediaType,
createdAt, createdAt,
} }
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
...@@ -2090,6 +2093,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2090,6 +2093,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
ipAddress sql.NullString ipAddress sql.NullString
imageCount int imageCount int
imageSize sql.NullString imageSize sql.NullString
mediaType sql.NullString
createdAt time.Time createdAt time.Time
) )
...@@ -2124,6 +2128,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2124,6 +2128,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&ipAddress, &ipAddress,
&imageCount, &imageCount,
&imageSize, &imageSize,
&mediaType,
&createdAt, &createdAt,
); err != nil { ); err != nil {
return nil, err return nil, err
...@@ -2183,6 +2188,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2183,6 +2188,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if imageSize.Valid { if imageSize.Valid {
log.ImageSize = &imageSize.String log.ImageSize = &imageSize.String
} }
if mediaType.Valid {
log.MediaType = &mediaType.String
}
return log, nil return log, nil
} }
......
...@@ -64,6 +64,9 @@ func RegisterAdminRoutes( ...@@ -64,6 +64,9 @@ func RegisterAdminRoutes(
// 用户属性管理 // 用户属性管理
registerUserAttributeRoutes(admin, h) registerUserAttributeRoutes(admin, h)
// 模型列表
registerModelRoutes(admin, h)
} }
} }
...@@ -371,3 +374,7 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -371,3 +374,7 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition) attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
} }
} }
func registerModelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
admin.GET("/models", h.Admin.Model.List)
}
...@@ -20,6 +20,11 @@ func RegisterGatewayRoutes( ...@@ -20,6 +20,11 @@ func RegisterGatewayRoutes(
cfg *config.Config, cfg *config.Config,
) { ) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
soraMaxBodySize := cfg.Gateway.SoraMaxBodySize
if soraMaxBodySize <= 0 {
soraMaxBodySize = cfg.Gateway.MaxBodySize
}
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
clientRequestID := middleware.ClientRequestID() clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
...@@ -38,6 +43,16 @@ func RegisterGatewayRoutes( ...@@ -38,6 +43,16 @@ func RegisterGatewayRoutes(
gateway.POST("/responses", h.OpenAIGateway.Responses) gateway.POST("/responses", h.OpenAIGateway.Responses)
} }
// Sora Chat Completions
soraGateway := r.Group("/v1")
soraGateway.Use(soraBodyLimit)
soraGateway.Use(clientRequestID)
soraGateway.Use(opsErrorLogger)
soraGateway.Use(gin.HandlerFunc(apiKeyAuth))
{
soraGateway.POST("/chat/completions", h.SoraGateway.ChatCompletions)
}
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta") gemini := r.Group("/v1beta")
gemini.Use(bodyLimit) gemini.Use(bodyLimit)
...@@ -82,4 +97,25 @@ func RegisterGatewayRoutes( ...@@ -82,4 +97,25 @@ func RegisterGatewayRoutes(
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
} }
// Sora 专用路由(强制使用 sora 平台)
soraV1 := r.Group("/sora/v1")
soraV1.Use(soraBodyLimit)
soraV1.Use(clientRequestID)
soraV1.Use(opsErrorLogger)
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
{
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
soraV1.GET("/models", h.Gateway.Models)
}
// Sora 媒体代理(可选 API Key 验证)
if cfg.Gateway.SoraMediaRequireAPIKey {
r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy)
} else {
r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy)
}
// Sora 媒体代理(签名 URL,无需 API Key)
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
} }
...@@ -102,11 +102,16 @@ type CreateGroupInput struct { ...@@ -102,11 +102,16 @@ type CreateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD) WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用) // 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64 ImagePrice1K *float64
ImagePrice2K *float64 ImagePrice2K *float64
ImagePrice4K *float64 ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 // Sora 按次计费配置
FallbackGroupID *int64 // 降级分组 ID SoraImagePrice360 *float64
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 ModelRouting map[string][]int64
ModelRoutingEnabled bool // 是否启用模型路由 ModelRoutingEnabled bool // 是否启用模型路由
...@@ -124,11 +129,16 @@ type UpdateGroupInput struct { ...@@ -124,11 +129,16 @@ type UpdateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD) WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用) // 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64 ImagePrice1K *float64
ImagePrice2K *float64 ImagePrice2K *float64
ImagePrice4K *float64 ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 // Sora 按次计费配置
FallbackGroupID *int64 // 降级分组 ID SoraImagePrice360 *float64
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 ModelRouting map[string][]int64
ModelRoutingEnabled *bool // 是否启用模型路由 ModelRoutingEnabled *bool // 是否启用模型路由
...@@ -273,6 +283,7 @@ type adminServiceImpl struct { ...@@ -273,6 +283,7 @@ type adminServiceImpl struct {
groupRepo GroupRepository groupRepo GroupRepository
accountRepo AccountRepository accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储 soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
soraSyncService *Sora2APISyncService // Sora2API 同步服务
proxyRepo ProxyRepository proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository redeemCodeRepo RedeemCodeRepository
...@@ -288,6 +299,7 @@ func NewAdminService( ...@@ -288,6 +299,7 @@ func NewAdminService(
groupRepo GroupRepository, groupRepo GroupRepository,
accountRepo AccountRepository, accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository, soraAccountRepo SoraAccountRepository,
soraSyncService *Sora2APISyncService,
proxyRepo ProxyRepository, proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository, apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository, redeemCodeRepo RedeemCodeRepository,
...@@ -301,6 +313,7 @@ func NewAdminService( ...@@ -301,6 +313,7 @@ func NewAdminService(
groupRepo: groupRepo, groupRepo: groupRepo,
accountRepo: accountRepo, accountRepo: accountRepo,
soraAccountRepo: soraAccountRepo, soraAccountRepo: soraAccountRepo,
soraSyncService: soraSyncService,
proxyRepo: proxyRepo, proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo, redeemCodeRepo: redeemCodeRepo,
...@@ -567,6 +580,10 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -567,6 +580,10 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
imagePrice1K := normalizePrice(input.ImagePrice1K) imagePrice1K := normalizePrice(input.ImagePrice1K)
imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K) imagePrice4K := normalizePrice(input.ImagePrice4K)
soraImagePrice360 := normalizePrice(input.SoraImagePrice360)
soraImagePrice540 := normalizePrice(input.SoraImagePrice540)
soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest)
soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD)
// 校验降级分组 // 校验降级分组
if input.FallbackGroupID != nil { if input.FallbackGroupID != nil {
...@@ -576,22 +593,26 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -576,22 +593,26 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
} }
group := &Group{ group := &Group{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
Platform: platform, Platform: platform,
RateMultiplier: input.RateMultiplier, RateMultiplier: input.RateMultiplier,
IsExclusive: input.IsExclusive, IsExclusive: input.IsExclusive,
Status: StatusActive, Status: StatusActive,
SubscriptionType: subscriptionType, SubscriptionType: subscriptionType,
DailyLimitUSD: dailyLimit, DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: weeklyLimit, WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: monthlyLimit, MonthlyLimitUSD: monthlyLimit,
ImagePrice1K: imagePrice1K, ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K, ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K, ImagePrice4K: imagePrice4K,
ClaudeCodeOnly: input.ClaudeCodeOnly, SoraImagePrice360: soraImagePrice360,
FallbackGroupID: input.FallbackGroupID, SoraImagePrice540: soraImagePrice540,
ModelRouting: input.ModelRouting, SoraVideoPricePerRequest: soraVideoPrice,
SoraVideoPricePerRequestHD: soraVideoPriceHD,
ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID,
ModelRouting: input.ModelRouting,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
...@@ -702,6 +723,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -702,6 +723,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.ImagePrice4K != nil { if input.ImagePrice4K != nil {
group.ImagePrice4K = normalizePrice(input.ImagePrice4K) group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
} }
if input.SoraImagePrice360 != nil {
group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360)
}
if input.SoraImagePrice540 != nil {
group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540)
}
if input.SoraVideoPricePerRequest != nil {
group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest)
}
if input.SoraVideoPricePerRequestHD != nil {
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
}
// Claude Code 客户端限制 // Claude Code 客户端限制
if input.ClaudeCodeOnly != nil { if input.ClaudeCodeOnly != nil {
...@@ -884,6 +917,9 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou ...@@ -884,6 +917,9 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
} }
} }
// 同步到 sora2api(异步,不阻塞创建)
s.syncSoraAccountAsync(account)
return account, nil return account, nil
} }
...@@ -974,7 +1010,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -974,7 +1010,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
} }
// 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象) // 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象)
return s.accountRepo.GetByID(ctx, id) updated, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
s.syncSoraAccountAsync(updated)
return updated, nil
} }
// BulkUpdateAccounts updates multiple accounts in one request. // BulkUpdateAccounts updates multiple accounts in one request.
...@@ -990,16 +1031,23 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -990,16 +1031,23 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
return result, nil return result, nil
} }
// Preload account platforms for mixed channel risk checks if group bindings are requested. needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
needSoraSync := s != nil && s.soraSyncService != nil
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
platformByID := map[int64]string{} platformByID := map[int64]string{}
if input.GroupIDs != nil && !input.SkipMixedChannelCheck { if needMixedChannelCheck || needSoraSync {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil { if err != nil {
return nil, err if needMixedChannelCheck {
} return nil, err
for _, account := range accounts { }
if account != nil { log.Printf("[AdminService] 预加载账号平台信息失败,将逐个降级同步: err=%v", err)
platformByID[account.ID] = account.Platform } else {
for _, account := range accounts {
if account != nil {
platformByID[account.ID] = account.Platform
}
} }
} }
} }
...@@ -1086,13 +1134,46 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -1086,13 +1134,46 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
result.Success++ result.Success++
result.SuccessIDs = append(result.SuccessIDs, accountID) result.SuccessIDs = append(result.SuccessIDs, accountID)
result.Results = append(result.Results, entry) result.Results = append(result.Results, entry)
// 批量更新后同步 sora2api
if needSoraSync {
platform := platformByID[accountID]
if platform == "" {
updated, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err)
continue
}
if updated.Platform == PlatformSora {
s.syncSoraAccountAsync(updated)
}
continue
}
if platform == PlatformSora {
updated, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err)
continue
}
s.syncSoraAccountAsync(updated)
}
}
} }
return result, nil return result, nil
} }
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return err
}
if err := s.accountRepo.Delete(ctx, id); err != nil {
return err
}
s.deleteSoraAccountAsync(account)
return nil
} }
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) { func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) {
...@@ -1125,7 +1206,46 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, ...@@ -1125,7 +1206,46 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err return nil, err
} }
return s.accountRepo.GetByID(ctx, id) updated, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
s.syncSoraAccountAsync(updated)
return updated, nil
}
func (s *adminServiceImpl) syncSoraAccountAsync(account *Account) {
if s == nil || s.soraSyncService == nil || account == nil {
return
}
if account.Platform != PlatformSora {
return
}
syncAccount := *account
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.soraSyncService.SyncAccount(ctx, &syncAccount); err != nil {
log.Printf("[AdminService] 同步 sora2api 失败: account_id=%d err=%v", syncAccount.ID, err)
}
}()
}
func (s *adminServiceImpl) deleteSoraAccountAsync(account *Account) {
if s == nil || s.soraSyncService == nil || account == nil {
return
}
if account.Platform != PlatformSora {
return
}
syncAccount := *account
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.soraSyncService.DeleteAccount(ctx, &syncAccount); err != nil {
log.Printf("[AdminService] 删除 sora2api token 失败: account_id=%d err=%v", syncAccount.ID, err)
}
}()
} }
// Proxy management implementations // Proxy management implementations
......
...@@ -15,6 +15,13 @@ type accountRepoStubForBulkUpdate struct { ...@@ -15,6 +15,13 @@ type accountRepoStubForBulkUpdate struct {
bulkUpdateErr error bulkUpdateErr error
bulkUpdateIDs []int64 bulkUpdateIDs []int64
bindGroupErrByID map[int64]error bindGroupErrByID map[int64]error
getByIDsAccounts []*Account
getByIDsErr error
getByIDsCalled bool
getByIDsIDs []int64
getByIDAccounts map[int64]*Account
getByIDErrByID map[int64]error
getByIDCalled []int64
} }
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
...@@ -32,6 +39,26 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i ...@@ -32,6 +39,26 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i
return nil return nil
} }
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
s.getByIDsCalled = true
s.getByIDsIDs = append([]int64{}, ids...)
if s.getByIDsErr != nil {
return nil, s.getByIDsErr
}
return s.getByIDsAccounts, nil
}
func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Account, error) {
s.getByIDCalled = append(s.getByIDCalled, id)
if err, ok := s.getByIDErrByID[id]; ok {
return nil, err
}
if account, ok := s.getByIDAccounts[id]; ok {
return account, nil
}
return nil, errors.New("account not found")
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 // TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{} repo := &accountRepoStubForBulkUpdate{}
...@@ -78,3 +105,31 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { ...@@ -78,3 +105,31 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
require.ElementsMatch(t, []int64{2}, result.FailedIDs) require.ElementsMatch(t, []int64{2}, result.FailedIDs)
require.Len(t, result.Results, 3) require.Len(t, result.Results, 3)
} }
// TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs 验证无分组更新时仍会触发 Sora 同步。
func TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{
getByIDsAccounts: []*Account{
{ID: 1, Platform: PlatformSora},
},
getByIDAccounts: map[int64]*Account{
1: {ID: 1, Platform: PlatformSora},
},
}
svc := &adminServiceImpl{
accountRepo: repo,
soraSyncService: &Sora2APISyncService{},
}
schedulable := true
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1},
Schedulable: &schedulable,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.NoError(t, err)
require.Equal(t, 1, result.Success)
require.True(t, repo.getByIDsCalled)
require.ElementsMatch(t, []int64{1}, repo.getByIDCalled)
}
...@@ -35,6 +35,10 @@ type APIKeyAuthGroupSnapshot struct { ...@@ -35,6 +35,10 @@ type APIKeyAuthGroupSnapshot struct {
ImagePrice1K *float64 `json:"image_price_1k,omitempty"` ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"` ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
......
...@@ -235,6 +235,10 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ...@@ -235,6 +235,10 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K, ImagePrice4K: apiKey.Group.ImagePrice4K,
SoraImagePrice360: apiKey.Group.SoraImagePrice360,
SoraImagePrice540: apiKey.Group.SoraImagePrice540,
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID, FallbackGroupID: apiKey.Group.FallbackGroupID,
ModelRouting: apiKey.Group.ModelRouting, ModelRouting: apiKey.Group.ModelRouting,
...@@ -279,6 +283,10 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ...@@ -279,6 +283,10 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K, ImagePrice4K: snapshot.Group.ImagePrice4K,
SoraImagePrice360: snapshot.Group.SoraImagePrice360,
SoraImagePrice540: snapshot.Group.SoraImagePrice540,
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID, FallbackGroupID: snapshot.Group.FallbackGroupID,
ModelRouting: snapshot.Group.ModelRouting, ModelRouting: snapshot.Group.ModelRouting,
......
...@@ -303,6 +303,14 @@ type ImagePriceConfig struct { ...@@ -303,6 +303,14 @@ type ImagePriceConfig struct {
Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值) Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值)
} }
// SoraPriceConfig Sora 按次计费配置
type SoraPriceConfig struct {
ImagePrice360 *float64
ImagePrice540 *float64
VideoPricePerRequest *float64
VideoPricePerRequestHD *float64
}
// CalculateImageCost 计算图片生成费用 // CalculateImageCost 计算图片生成费用
// model: 请求的模型名称(用于获取 LiteLLM 默认价格) // model: 请求的模型名称(用于获取 LiteLLM 默认价格)
// imageSize: 图片尺寸 "1K", "2K", "4K" // imageSize: 图片尺寸 "1K", "2K", "4K"
...@@ -332,6 +340,65 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag ...@@ -332,6 +340,65 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
} }
} }
// CalculateSoraImageCost 计算 Sora 图片按次费用
func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
if imageCount <= 0 {
return &CostBreakdown{}
}
unitPrice := 0.0
if groupConfig != nil {
switch imageSize {
case "540":
if groupConfig.ImagePrice540 != nil {
unitPrice = *groupConfig.ImagePrice540
}
default:
if groupConfig.ImagePrice360 != nil {
unitPrice = *groupConfig.ImagePrice360
}
}
}
totalCost := unitPrice * float64(imageCount)
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
actualCost := totalCost * rateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
}
}
// CalculateSoraVideoCost 计算 Sora 视频按次费用
func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
unitPrice := 0.0
if groupConfig != nil {
modelLower := strings.ToLower(model)
if strings.Contains(modelLower, "sora2pro-hd") {
if groupConfig.VideoPricePerRequestHD != nil {
unitPrice = *groupConfig.VideoPricePerRequestHD
}
}
if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil {
unitPrice = *groupConfig.VideoPricePerRequest
}
}
totalCost := unitPrice
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
actualCost := totalCost * rateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
}
}
// getImageUnitPrice 获取图片单价 // getImageUnitPrice 获取图片单价
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 { func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
// 优先使用分组配置的价格 // 优先使用分组配置的价格
......
...@@ -184,6 +184,10 @@ type ForwardResult struct { ...@@ -184,6 +184,10 @@ type ForwardResult struct {
// 图片生成计费字段(仅 gemini-3-pro-image 使用) // 图片生成计费字段(仅 gemini-3-pro-image 使用)
ImageCount int // 生成的图片数量 ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K" ImageSize string // 图片尺寸 "1K", "2K", "4K"
// Sora 媒体字段
MediaType string // image / video / prompt
MediaURL string // 生成后的媒体地址(可选)
} }
// UpstreamFailoverError indicates an upstream error that should trigger account failover. // UpstreamFailoverError indicates an upstream error that should trigger account failover.
...@@ -3461,7 +3465,22 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -3461,7 +3465,22 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
var cost *CostBreakdown var cost *CostBreakdown
// 根据请求类型选择计费方式 // 根据请求类型选择计费方式
if result.ImageCount > 0 { if result.MediaType == "image" || result.MediaType == "video" || result.MediaType == "prompt" {
var soraConfig *SoraPriceConfig
if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{
ImagePrice360: apiKey.Group.SoraImagePrice360,
ImagePrice540: apiKey.Group.SoraImagePrice540,
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
}
}
if result.MediaType == "image" {
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
} else {
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier)
}
} else if result.ImageCount > 0 {
// 图片生成计费 // 图片生成计费
var groupConfig *ImagePriceConfig var groupConfig *ImagePriceConfig
if apiKey.Group != nil { if apiKey.Group != nil {
...@@ -3501,6 +3520,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -3501,6 +3520,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.ImageSize != "" { if result.ImageSize != "" {
imageSize = &result.ImageSize imageSize = &result.ImageSize
} }
var mediaType *string
if strings.TrimSpace(result.MediaType) != "" {
mediaType = &result.MediaType
}
accountRateMultiplier := account.BillingRateMultiplier() accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
...@@ -3526,6 +3549,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -3526,6 +3549,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
FirstTokenMs: result.FirstTokenMs, FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount, ImageCount: result.ImageCount,
ImageSize: imageSize, ImageSize: imageSize,
MediaType: mediaType,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
......
...@@ -26,6 +26,12 @@ type Group struct { ...@@ -26,6 +26,12 @@ type Group struct {
ImagePrice2K *float64 ImagePrice2K *float64
ImagePrice4K *float64 ImagePrice4K *float64
// Sora 按次计费配置(阶段 1)
SoraImagePrice360 *float64
SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
// Claude Code 客户端限制 // Claude Code 客户端限制
ClaudeCodeOnly bool ClaudeCodeOnly bool
FallbackGroupID *int64 FallbackGroupID *int64
...@@ -83,6 +89,18 @@ func (g *Group) GetImagePrice(imageSize string) *float64 { ...@@ -83,6 +89,18 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
} }
} }
// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540)
func (g *Group) GetSoraImagePrice(imageSize string) *float64 {
switch imageSize {
case "360":
return g.SoraImagePrice360
case "540":
return g.SoraImagePrice540
default:
return g.SoraImagePrice360
}
}
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions. // IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
func IsGroupContextValid(group *Group) bool { func IsGroupContextValid(group *Group) bool {
if group == nil { if group == nil {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment