Commit bb664d9b authored by yangjianbo's avatar yangjianbo
Browse files

feat(sync): full code sync from release

parent bfc7b339
...@@ -97,7 +97,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -97,7 +97,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制 // 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 { if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
clientIP := ip.GetTrustedClientIP(c) clientIP := ip.GetTrustedClientIP(c)
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist) allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist)
if !allowed { if !allowed {
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied") AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
return return
......
...@@ -80,17 +80,25 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs ...@@ -80,17 +80,25 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
abortWithGoogleError(c, 403, "No active subscription found for this group") abortWithGoogleError(c, 403, "No active subscription found for this group")
return return
} }
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
abortWithGoogleError(c, 403, err.Error()) needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
return if err != nil {
} status := 403
_ = subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription) if errors.Is(err, service.ErrDailyLimitExceeded) ||
_ = subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription) errors.Is(err, service.ErrWeeklyLimitExceeded) ||
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil { errors.Is(err, service.ErrMonthlyLimitExceeded) {
abortWithGoogleError(c, 429, err.Error()) status = 429
}
abortWithGoogleError(c, status, err.Error())
return return
} }
c.Set(string(ContextKeySubscription), subscription) c.Set(string(ContextKeySubscription), subscription)
if needsMaintenance {
maintenanceCopy := *subscription
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
}
} else { } else {
if apiKey.User.Balance <= 0 { if apiKey.User.Balance <= 0 {
abortWithGoogleError(c, 403, "Insufficient account balance") abortWithGoogleError(c, 403, "Insufficient account balance")
......
...@@ -23,6 +23,15 @@ type fakeAPIKeyRepo struct { ...@@ -23,6 +23,15 @@ type fakeAPIKeyRepo struct {
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
} }
type fakeGoogleSubscriptionRepo struct {
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
activateWindow func(ctx context.Context, id int64, start time.Time) error
resetDaily func(ctx context.Context, id int64, start time.Time) error
resetWeekly func(ctx context.Context, id int64, start time.Time) error
resetMonthly func(ctx context.Context, id int64, start time.Time) error
}
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error { func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
...@@ -87,6 +96,85 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim ...@@ -87,6 +96,85 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim
return nil return nil
} }
func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
if f.getActive != nil {
return f.getActive(ctx, userID, groupID)
}
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
return false, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
if f.updateStatus != nil {
return f.updateStatus(ctx, subscriptionID, status)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
if f.activateWindow != nil {
return f.activateWindow(ctx, id, start)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, start time.Time) error {
if f.resetDaily != nil {
return f.resetDaily(ctx, id, start)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, start time.Time) error {
if f.resetWeekly != nil {
return f.resetWeekly(ctx, id, start)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, start time.Time) error {
if f.resetMonthly != nil {
return f.resetMonthly(ctx, id, start)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
return 0, errors.New("not implemented")
}
type googleErrorResponse struct { type googleErrorResponse struct {
Error struct { Error struct {
Code int `json:"code"` Code int `json:"code"`
...@@ -505,3 +593,85 @@ func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testi ...@@ -505,3 +593,85 @@ func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testi
require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, touchCalls) require.Equal(t, 1, touchCalls)
} }
func TestApiKeyAuthWithSubscriptionGoogle_SubscriptionLimitExceededReturns429(t *testing.T) {
gin.SetMode(gin.TestMode)
limit := 1.0
group := &service.Group{
ID: 77,
Name: "gemini-sub",
Status: service.StatusActive,
Platform: service.PlatformGemini,
Hydrated: true,
SubscriptionType: service.SubscriptionTypeSubscription,
DailyLimitUSD: &limit,
}
user := &service.User{
ID: 999,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 501,
UserID: user.ID,
Key: "google-sub-limit",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
})
now := time.Now()
sub := &service.UserSubscription{
ID: 601,
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: now.Add(24 * time.Hour),
DailyWindowStart: &now,
DailyUsageUSD: 10,
}
subscriptionService := service.NewSubscriptionService(nil, fakeGoogleSubscriptionRepo{
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
if userID != user.ID || groupID != group.ID {
return nil, service.ErrSubscriptionNotFound
}
clone := *sub
return &clone, nil
},
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
}, nil, nil, &config.Config{RunMode: config.RunModeStandard})
r := gin.New()
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, &config.Config{RunMode: config.RunModeStandard}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("x-goog-api-key", apiKey.Key)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusTooManyRequests, resp.Error.Code)
require.Equal(t, "RESOURCE_EXHAUSTED", resp.Error.Status)
require.Contains(t, resp.Error.Message, "daily usage limit exceeded")
}
...@@ -54,6 +54,10 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { ...@@ -54,6 +54,10 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY") c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin") c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
if isAPIRoutePath(c) {
c.Next()
return
}
if cfg.Enabled { if cfg.Enabled {
// Generate nonce for this request // Generate nonce for this request
...@@ -73,6 +77,18 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { ...@@ -73,6 +77,18 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
} }
} }
func isAPIRoutePath(c *gin.Context) bool {
if c == nil || c.Request == nil || c.Request.URL == nil {
return false
}
path := c.Request.URL.Path
return strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/responses")
}
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain. // enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
// This allows the application to work correctly even if the config file has an older CSP policy. // This allows the application to work correctly even if the config file has an older CSP policy.
func enhanceCSPPolicy(policy string) string { func enhanceCSPPolicy(policy string) string {
......
...@@ -131,6 +131,26 @@ func TestSecurityHeaders(t *testing.T) { ...@@ -131,6 +131,26 @@ func TestSecurityHeaders(t *testing.T) {
assert.Contains(t, csp, CloudflareInsightsDomain) assert.Contains(t, csp, CloudflareInsightsDomain)
}) })
t.Run("api_route_skips_csp_nonce_generation", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
middleware(c)
assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
assert.Empty(t, w.Header().Get("Content-Security-Policy"))
assert.Empty(t, GetNonceFromContext(c))
})
t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) { t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
cfg := config.CSPConfig{ cfg := config.CSPConfig{
Enabled: true, Enabled: true,
......
...@@ -75,6 +75,7 @@ func registerRoutes( ...@@ -75,6 +75,7 @@ func registerRoutes(
// 注册各模块路由 // 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient) routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
routes.RegisterUserRoutes(v1, h, jwtAuth) routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
} }
...@@ -55,6 +55,9 @@ func RegisterAdminRoutes( ...@@ -55,6 +55,9 @@ func RegisterAdminRoutes(
// 系统设置 // 系统设置
registerSettingsRoutes(admin, h) registerSettingsRoutes(admin, h)
// 数据管理
registerDataManagementRoutes(admin, h)
// 运维监控(Ops) // 运维监控(Ops)
registerOpsRoutes(admin, h) registerOpsRoutes(admin, h)
...@@ -231,6 +234,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -231,6 +234,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError) accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage) accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats) accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit) accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable) accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable) accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
...@@ -370,6 +374,38 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -370,6 +374,38 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// 流超时处理配置 // 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings) adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings) adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
// Sora S3 存储配置
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection)
adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles)
adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile)
adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile)
adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile)
adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile)
}
}
func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dataManagement := admin.Group("/data-management")
{
dataManagement.GET("/agent/health", h.Admin.DataManagement.GetAgentHealth)
dataManagement.GET("/config", h.Admin.DataManagement.GetConfig)
dataManagement.PUT("/config", h.Admin.DataManagement.UpdateConfig)
dataManagement.GET("/sources/:source_type/profiles", h.Admin.DataManagement.ListSourceProfiles)
dataManagement.POST("/sources/:source_type/profiles", h.Admin.DataManagement.CreateSourceProfile)
dataManagement.PUT("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.UpdateSourceProfile)
dataManagement.DELETE("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.DeleteSourceProfile)
dataManagement.POST("/sources/:source_type/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveSourceProfile)
dataManagement.POST("/s3/test", h.Admin.DataManagement.TestS3)
dataManagement.GET("/s3/profiles", h.Admin.DataManagement.ListS3Profiles)
dataManagement.POST("/s3/profiles", h.Admin.DataManagement.CreateS3Profile)
dataManagement.PUT("/s3/profiles/:profile_id", h.Admin.DataManagement.UpdateS3Profile)
dataManagement.DELETE("/s3/profiles/:profile_id", h.Admin.DataManagement.DeleteS3Profile)
dataManagement.POST("/s3/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveS3Profile)
dataManagement.POST("/backups", h.Admin.DataManagement.CreateBackupJob)
dataManagement.GET("/backups", h.Admin.DataManagement.ListBackupJobs)
dataManagement.GET("/backups/:job_id", h.Admin.DataManagement.GetBackupJob)
} }
} }
......
...@@ -43,6 +43,7 @@ func RegisterGatewayRoutes( ...@@ -43,6 +43,7 @@ func RegisterGatewayRoutes(
gateway.GET("/usage", h.Gateway.Usage) gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API // OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses) gateway.POST("/responses", h.OpenAIGateway.Responses)
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。 // 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
gateway.POST("/chat/completions", func(c *gin.Context) { gateway.POST("/chat/completions", func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{ c.JSON(http.StatusBadRequest, gin.H{
...@@ -69,6 +70,7 @@ func RegisterGatewayRoutes( ...@@ -69,6 +70,7 @@ func RegisterGatewayRoutes(
// OpenAI Responses API(不带v1前缀的别名) // OpenAI Responses API(不带v1前缀的别名)
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ResponsesWebSocket)
// Antigravity 模型列表 // Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels) r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
......
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。
func RegisterSoraClientRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
) {
if h.SoraClient == nil {
return
}
authenticated := v1.Group("/sora")
authenticated.Use(gin.HandlerFunc(jwtAuth))
{
authenticated.POST("/generate", h.SoraClient.Generate)
authenticated.GET("/generations", h.SoraClient.ListGenerations)
authenticated.GET("/generations/:id", h.SoraClient.GetGeneration)
authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration)
authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration)
authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage)
authenticated.GET("/quota", h.SoraClient.GetQuota)
authenticated.GET("/models", h.SoraClient.GetModels)
authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus)
}
}
...@@ -3,6 +3,8 @@ package service ...@@ -3,6 +3,8 @@ package service
import ( import (
"encoding/json" "encoding/json"
"hash/fnv"
"reflect"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
...@@ -50,6 +52,14 @@ type Account struct { ...@@ -50,6 +52,14 @@ type Account struct {
AccountGroups []AccountGroup AccountGroups []AccountGroup
GroupIDs []int64 GroupIDs []int64
Groups []*Group Groups []*Group
// model_mapping 热路径缓存(非持久化字段)
modelMappingCache map[string]string
modelMappingCacheReady bool
modelMappingCacheCredentialsPtr uintptr
modelMappingCacheRawPtr uintptr
modelMappingCacheRawLen int
modelMappingCacheRawSig uint64
} }
type TempUnschedulableRule struct { type TempUnschedulableRule struct {
...@@ -349,6 +359,39 @@ func parseTempUnschedInt(value any) int { ...@@ -349,6 +359,39 @@ func parseTempUnschedInt(value any) int {
} }
func (a *Account) GetModelMapping() map[string]string { func (a *Account) GetModelMapping() map[string]string {
credentialsPtr := mapPtr(a.Credentials)
rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
rawPtr := mapPtr(rawMapping)
rawLen := len(rawMapping)
rawSig := uint64(0)
rawSigReady := false
if a.modelMappingCacheReady &&
a.modelMappingCacheCredentialsPtr == credentialsPtr &&
a.modelMappingCacheRawPtr == rawPtr &&
a.modelMappingCacheRawLen == rawLen {
rawSig = modelMappingSignature(rawMapping)
rawSigReady = true
if a.modelMappingCacheRawSig == rawSig {
return a.modelMappingCache
}
}
mapping := a.resolveModelMapping(rawMapping)
if !rawSigReady {
rawSig = modelMappingSignature(rawMapping)
}
a.modelMappingCache = mapping
a.modelMappingCacheReady = true
a.modelMappingCacheCredentialsPtr = credentialsPtr
a.modelMappingCacheRawPtr = rawPtr
a.modelMappingCacheRawLen = rawLen
a.modelMappingCacheRawSig = rawSig
return mapping
}
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
if a.Credentials == nil { if a.Credentials == nil {
// Antigravity 平台使用默认映射 // Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity { if a.Platform == domain.PlatformAntigravity {
...@@ -356,32 +399,31 @@ func (a *Account) GetModelMapping() map[string]string { ...@@ -356,32 +399,31 @@ func (a *Account) GetModelMapping() map[string]string {
} }
return nil return nil
} }
raw, ok := a.Credentials["model_mapping"] if len(rawMapping) == 0 {
if !ok || raw == nil {
// Antigravity 平台使用默认映射 // Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity { if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping return domain.DefaultAntigravityModelMapping
} }
return nil return nil
} }
if m, ok := raw.(map[string]any); ok {
result := make(map[string]string) result := make(map[string]string)
for k, v := range m { for k, v := range rawMapping {
if s, ok := v.(string); ok { if s, ok := v.(string); ok {
result[k] = s result[k] = s
}
} }
if len(result) > 0 { }
if a.Platform == domain.PlatformAntigravity { if len(result) > 0 {
ensureAntigravityDefaultPassthroughs(result, []string{ if a.Platform == domain.PlatformAntigravity {
"gemini-3-flash", ensureAntigravityDefaultPassthroughs(result, []string{
"gemini-3.1-pro-high", "gemini-3-flash",
"gemini-3.1-pro-low", "gemini-3.1-pro-high",
}) "gemini-3.1-pro-low",
} })
return result
} }
return result
} }
// Antigravity 平台使用默认映射 // Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity { if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping return domain.DefaultAntigravityModelMapping
...@@ -389,6 +431,37 @@ func (a *Account) GetModelMapping() map[string]string { ...@@ -389,6 +431,37 @@ func (a *Account) GetModelMapping() map[string]string {
return nil return nil
} }
func mapPtr(m map[string]any) uintptr {
if m == nil {
return 0
}
return reflect.ValueOf(m).Pointer()
}
func modelMappingSignature(rawMapping map[string]any) uint64 {
if len(rawMapping) == 0 {
return 0
}
keys := make([]string, 0, len(rawMapping))
for k := range rawMapping {
keys = append(keys, k)
}
sort.Strings(keys)
h := fnv.New64a()
for _, k := range keys {
_, _ = h.Write([]byte(k))
_, _ = h.Write([]byte{0})
if v, ok := rawMapping[k].(string); ok {
_, _ = h.Write([]byte(v))
} else {
_, _ = h.Write([]byte{1})
}
_, _ = h.Write([]byte{0xff})
}
return h.Sum64()
}
func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) { func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) {
if mapping == nil || model == "" { if mapping == nil || model == "" {
return return
...@@ -742,6 +815,159 @@ func (a *Account) IsOpenAIPassthroughEnabled() bool { ...@@ -742,6 +815,159 @@ func (a *Account) IsOpenAIPassthroughEnabled() bool {
return false return false
} }
// IsOpenAIResponsesWebSocketV2Enabled 返回 OpenAI 账号是否开启 Responses WebSocket v2。
//
// 分类型新字段:
// - OAuth 账号:accounts.extra.openai_oauth_responses_websockets_v2_enabled
// - API Key 账号:accounts.extra.openai_apikey_responses_websockets_v2_enabled
//
// 兼容字段:
// - accounts.extra.responses_websockets_v2_enabled
// - accounts.extra.openai_ws_enabled(历史开关)
//
// 优先级:
// 1. 按账号类型读取分类型字段
// 2. 分类型字段缺失时,回退兼容字段
func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
return false
}
if a.IsOpenAIOAuth() {
if enabled, ok := a.Extra["openai_oauth_responses_websockets_v2_enabled"].(bool); ok {
return enabled
}
}
if a.IsOpenAIApiKey() {
if enabled, ok := a.Extra["openai_apikey_responses_websockets_v2_enabled"].(bool); ok {
return enabled
}
}
if enabled, ok := a.Extra["responses_websockets_v2_enabled"].(bool); ok {
return enabled
}
if enabled, ok := a.Extra["openai_ws_enabled"].(bool); ok {
return enabled
}
return false
}
const (
OpenAIWSIngressModeOff = "off"
OpenAIWSIngressModeShared = "shared"
OpenAIWSIngressModeDedicated = "dedicated"
)
func normalizeOpenAIWSIngressMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case OpenAIWSIngressModeOff:
return OpenAIWSIngressModeOff
case OpenAIWSIngressModeShared:
return OpenAIWSIngressModeShared
case OpenAIWSIngressModeDedicated:
return OpenAIWSIngressModeDedicated
default:
return ""
}
}
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
return normalized
}
return OpenAIWSIngressModeShared
}
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。
//
// 优先级:
// 1. 分类型 mode 新字段(string)
// 2. 分类型 enabled 旧字段(bool)
// 3. 兼容 enabled 旧字段(bool)
// 4. defaultMode(非法时回退 shared)
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
if a == nil || !a.IsOpenAI() {
return OpenAIWSIngressModeOff
}
if a.Extra == nil {
return resolvedDefault
}
resolveModeString := func(key string) (string, bool) {
raw, ok := a.Extra[key]
if !ok {
return "", false
}
mode, ok := raw.(string)
if !ok {
return "", false
}
normalized := normalizeOpenAIWSIngressMode(mode)
if normalized == "" {
return "", false
}
return normalized, true
}
resolveBoolMode := func(key string) (string, bool) {
raw, ok := a.Extra[key]
if !ok {
return "", false
}
enabled, ok := raw.(bool)
if !ok {
return "", false
}
if enabled {
return OpenAIWSIngressModeShared, true
}
return OpenAIWSIngressModeOff, true
}
if a.IsOpenAIOAuth() {
if mode, ok := resolveModeString("openai_oauth_responses_websockets_v2_mode"); ok {
return mode
}
if mode, ok := resolveBoolMode("openai_oauth_responses_websockets_v2_enabled"); ok {
return mode
}
}
if a.IsOpenAIApiKey() {
if mode, ok := resolveModeString("openai_apikey_responses_websockets_v2_mode"); ok {
return mode
}
if mode, ok := resolveBoolMode("openai_apikey_responses_websockets_v2_enabled"); ok {
return mode
}
}
if mode, ok := resolveBoolMode("responses_websockets_v2_enabled"); ok {
return mode
}
if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
return mode
}
return resolvedDefault
}
// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。
// 字段:accounts.extra.openai_ws_force_http。
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
return false
}
enabled, ok := a.Extra["openai_ws_force_http"].(bool)
return ok && enabled
}
// IsOpenAIWSAllowStoreRecoveryEnabled 返回账号级 store 恢复开关。
// 字段:accounts.extra.openai_ws_allow_store_recovery。
func (a *Account) IsOpenAIWSAllowStoreRecoveryEnabled() bool {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
return false
}
enabled, ok := a.Extra["openai_ws_allow_store_recovery"].(bool)
return ok && enabled
}
// IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。 // IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。
func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool { func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled() return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
......
...@@ -134,3 +134,161 @@ func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) { ...@@ -134,3 +134,161 @@ func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) {
require.False(t, otherPlatform.IsCodexCLIOnlyEnabled()) require.False(t, otherPlatform.IsCodexCLIOnlyEnabled())
}) })
} }
func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
t.Run("OAuth使用OAuth专用开关", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_enabled": true,
},
}
require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("API Key使用API Key专用开关", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("分类型新键优先于兼容键", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_enabled": false,
"responses_websockets_v2_enabled": true,
"openai_ws_enabled": true,
},
}
require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("分类型键缺失时回退兼容键", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("非OpenAI账号默认关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
}
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
t.Run("default fallback to shared", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{},
}
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
})
t.Run("oauth mode field has highest priority", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
"openai_oauth_responses_websockets_v2_enabled": false,
"responses_websockets_v2_enabled": false,
},
}
require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
})
t.Run("legacy enabled maps to shared", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
})
t.Run("legacy disabled maps to off", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": false,
"responses_websockets_v2_enabled": true,
},
}
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
})
t.Run("non openai always off", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
},
}
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated))
})
}
func TestAccount_OpenAIWSExtraFlags(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_ws_force_http": true,
"openai_ws_allow_store_recovery": true,
},
}
require.True(t, account.IsOpenAIWSForceHTTPEnabled())
require.True(t, account.IsOpenAIWSAllowStoreRecoveryEnabled())
off := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}}
require.False(t, off.IsOpenAIWSForceHTTPEnabled())
require.False(t, off.IsOpenAIWSAllowStoreRecoveryEnabled())
var nilAccount *Account
require.False(t, nilAccount.IsOpenAIWSAllowStoreRecoveryEnabled())
nonOpenAI := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_ws_allow_store_recovery": true,
},
}
require.False(t, nonOpenAI.IsOpenAIWSAllowStoreRecoveryEnabled())
}
...@@ -119,6 +119,10 @@ type AccountService struct { ...@@ -119,6 +119,10 @@ type AccountService struct {
groupRepo GroupRepository groupRepo GroupRepository
} }
type groupExistenceBatchChecker interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAccountService 创建账号服务实例 // NewAccountService 创建账号服务实例
func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService { func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService {
return &AccountService{ return &AccountService{
...@@ -131,11 +135,8 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) ...@@ -131,11 +135,8 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository)
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) { func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
// 验证分组是否存在(如果指定了分组) // 验证分组是否存在(如果指定了分组)
if len(req.GroupIDs) > 0 { if len(req.GroupIDs) > 0 {
for _, groupID := range req.GroupIDs { if err := s.validateGroupIDsExist(ctx, req.GroupIDs); err != nil {
_, err := s.groupRepo.GetByID(ctx, groupID) return nil, err
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
} }
} }
...@@ -256,11 +257,8 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount ...@@ -256,11 +257,8 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
// 先验证分组是否存在(在任何写操作之前) // 先验证分组是否存在(在任何写操作之前)
if req.GroupIDs != nil { if req.GroupIDs != nil {
for _, groupID := range *req.GroupIDs { if err := s.validateGroupIDsExist(ctx, *req.GroupIDs); err != nil {
_, err := s.groupRepo.GetByID(ctx, groupID) return nil, err
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
} }
} }
...@@ -300,6 +298,39 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error { ...@@ -300,6 +298,39 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
return nil return nil
} }
func (s *AccountService) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
if len(groupIDs) == 0 {
return nil
}
if s.groupRepo == nil {
return fmt.Errorf("group repository not configured")
}
if batchChecker, ok := s.groupRepo.(groupExistenceBatchChecker); ok {
existsByID, err := batchChecker.ExistsByIDs(ctx, groupIDs)
if err != nil {
return fmt.Errorf("check groups exists: %w", err)
}
for _, groupID := range groupIDs {
if groupID <= 0 {
return fmt.Errorf("get group: %w", ErrGroupNotFound)
}
if !existsByID[groupID] {
return fmt.Errorf("get group: %w", ErrGroupNotFound)
}
}
return nil
}
for _, groupID := range groupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
return fmt.Errorf("get group: %w", err)
}
}
return nil
}
// UpdateStatus 更新账号状态 // UpdateStatus 更新账号状态
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error { func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
......
...@@ -598,9 +598,102 @@ func ceilSeconds(d time.Duration) int { ...@@ -598,9 +598,102 @@ func ceilSeconds(d time.Duration) int {
return sec return sec
} }
// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。
// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。
func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error {
ctx := c.Request.Context()
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证")
}
baseURL := account.GetBaseURL()
if baseURL == "" {
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url")
}
// 验证 base_url 格式
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error()))
}
upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions"
// 设置 SSE 头
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
return s.sendErrorAndEnd(c, msg)
}
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"})
// 构建轻量级 prompt-enhance 请求作为连通性测试
testPayload := map[string]any{
"model": "prompt-enhance-short-10s",
"messages": []map[string]string{{"role": "user", "content": "test"}},
"stream": false,
}
payloadBytes, _ := json.Marshal(testPayload)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "构建测试请求失败")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
// 获取代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
if resp.StatusCode == http.StatusOK {
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode))
}
// 其他错误但能连通(如 400 参数错误)也算连通性测试通过
if resp.StatusCode == http.StatusBadRequest {
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256)))
}
// testSoraAccountConnection 测试 Sora 账号的连接 // testSoraAccountConnection 测试 Sora 账号的连接
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token) // OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error { func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
// apikey 类型走独立测试流程
if account.Type == AccountTypeAPIKey {
return s.testSoraAPIKeyAccountConnection(c, account)
}
ctx := c.Request.Context() ctx := c.Request.Context()
recorder := &soraProbeRecorder{} recorder := &soraProbeRecorder{}
......
...@@ -9,7 +9,9 @@ import ( ...@@ -9,7 +9,9 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"golang.org/x/sync/errgroup"
) )
type UsageLogRepository interface { type UsageLogRepository interface {
...@@ -33,8 +35,8 @@ type UsageLogRepository interface { ...@@ -33,8 +35,8 @@ type UsageLogRepository interface {
// Admin dashboard stats // Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
...@@ -62,6 +64,10 @@ type UsageLogRepository interface { ...@@ -62,6 +64,10 @@ type UsageLogRepository interface {
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
} }
type accountWindowStatsBatchReader interface {
GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error)
}
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at) // apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
type apiUsageCache struct { type apiUsageCache struct {
response *ClaudeUsageResponse response *ClaudeUsageResponse
...@@ -297,7 +303,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou ...@@ -297,7 +303,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
} }
dayStart := geminiDailyWindowStart(now) dayStart := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil) stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err) return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
} }
...@@ -319,7 +325,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou ...@@ -319,7 +325,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m) // Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute) minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute) minuteResetAt := minuteStart.Add(time.Minute)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil) minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err) return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
} }
...@@ -440,6 +446,78 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64 ...@@ -440,6 +446,78 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
}, nil }, nil
} }
// GetTodayStatsBatch 批量获取账号今日统计,优先走批量 SQL,失败时回退单账号查询。
func (s *AccountUsageService) GetTodayStatsBatch(ctx context.Context, accountIDs []int64) (map[int64]*WindowStats, error) {
uniqueIDs := make([]int64, 0, len(accountIDs))
seen := make(map[int64]struct{}, len(accountIDs))
for _, accountID := range accountIDs {
if accountID <= 0 {
continue
}
if _, exists := seen[accountID]; exists {
continue
}
seen[accountID] = struct{}{}
uniqueIDs = append(uniqueIDs, accountID)
}
result := make(map[int64]*WindowStats, len(uniqueIDs))
if len(uniqueIDs) == 0 {
return result, nil
}
startTime := timezone.Today()
if batchReader, ok := s.usageLogRepo.(accountWindowStatsBatchReader); ok {
statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, uniqueIDs, startTime)
if err == nil {
for _, accountID := range uniqueIDs {
result[accountID] = windowStatsFromAccountStats(statsByAccount[accountID])
}
return result, nil
}
}
var mu sync.Mutex
g, gctx := errgroup.WithContext(ctx)
g.SetLimit(8)
for _, accountID := range uniqueIDs {
id := accountID
g.Go(func() error {
stats, err := s.usageLogRepo.GetAccountWindowStats(gctx, id, startTime)
if err != nil {
return nil
}
mu.Lock()
result[id] = windowStatsFromAccountStats(stats)
mu.Unlock()
return nil
})
}
_ = g.Wait()
for _, accountID := range uniqueIDs {
if _, ok := result[accountID]; !ok {
result[accountID] = &WindowStats{}
}
}
return result, nil
}
func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
if stats == nil {
return &WindowStats{}
}
return &WindowStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
StandardCost: stats.StandardCost,
UserCost: stats.UserCost,
}
}
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) { func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime) stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
if err != nil { if err != nil {
......
...@@ -314,3 +314,72 @@ func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T ...@@ -314,3 +314,72 @@ func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T
t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped) t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped)
} }
} }
func TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-3-5-sonnet": "upstream-a",
},
},
}
first := account.GetModelMapping()
if first["claude-3-5-sonnet"] != "upstream-a" {
t.Fatalf("unexpected first mapping: %v", first)
}
account.Credentials = map[string]any{
"model_mapping": map[string]any{
"claude-3-5-sonnet": "upstream-b",
},
}
second := account.GetModelMapping()
if second["claude-3-5-sonnet"] != "upstream-b" {
t.Fatalf("expected cache invalidated after credentials replace, got: %v", second)
}
}
func TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange(t *testing.T) {
rawMapping := map[string]any{
"claude-sonnet": "sonnet-a",
}
account := &Account{
Credentials: map[string]any{
"model_mapping": rawMapping,
},
}
first := account.GetModelMapping()
if len(first) != 1 {
t.Fatalf("unexpected first mapping length: %d", len(first))
}
rawMapping["claude-opus"] = "opus-b"
second := account.GetModelMapping()
if second["claude-opus"] != "opus-b" {
t.Fatalf("expected cache invalidated after mapping len change, got: %v", second)
}
}
func TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange(t *testing.T) {
rawMapping := map[string]any{
"claude-sonnet": "sonnet-a",
}
account := &Account{
Credentials: map[string]any{
"model_mapping": rawMapping,
},
}
first := account.GetModelMapping()
if first["claude-sonnet"] != "sonnet-a" {
t.Fatalf("unexpected first mapping: %v", first)
}
rawMapping["claude-sonnet"] = "sonnet-b"
second := account.GetModelMapping()
if second["claude-sonnet"] != "sonnet-b" {
t.Fatalf("expected cache invalidated after in-place value change, got: %v", second)
}
}
...@@ -83,13 +83,14 @@ type AdminService interface { ...@@ -83,13 +83,14 @@ type AdminService interface {
// CreateUserInput represents input for creating a new user via admin operations. // CreateUserInput represents input for creating a new user via admin operations.
type CreateUserInput struct { type CreateUserInput struct {
Email string Email string
Password string Password string
Username string Username string
Notes string Notes string
Balance float64 Balance float64
Concurrency int Concurrency int
AllowedGroups []int64 AllowedGroups []int64
SoraStorageQuotaBytes int64
} }
type UpdateUserInput struct { type UpdateUserInput struct {
...@@ -103,7 +104,8 @@ type UpdateUserInput struct { ...@@ -103,7 +104,8 @@ type UpdateUserInput struct {
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置 // GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率 // map[groupID]*rate,nil 表示删除该分组的专属倍率
GroupRates map[int64]*float64 GroupRates map[int64]*float64
SoraStorageQuotaBytes *int64
} }
type CreateGroupInput struct { type CreateGroupInput struct {
...@@ -135,6 +137,8 @@ type CreateGroupInput struct { ...@@ -135,6 +137,8 @@ type CreateGroupInput struct {
MCPXMLInject *bool MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string SupportedModelScopes []string
// Sora 存储配额
SoraStorageQuotaBytes int64
// 从指定分组复制账号(创建分组后在同一事务内绑定) // 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64 CopyAccountsFromGroupIDs []int64
} }
...@@ -169,6 +173,8 @@ type UpdateGroupInput struct { ...@@ -169,6 +173,8 @@ type UpdateGroupInput struct {
MCPXMLInject *bool MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string SupportedModelScopes *[]string
// Sora 存储配额
SoraStorageQuotaBytes *int64
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 CopyAccountsFromGroupIDs []int64
} }
...@@ -402,6 +408,14 @@ type adminServiceImpl struct { ...@@ -402,6 +408,14 @@ type adminServiceImpl struct {
authCacheInvalidator APIKeyAuthCacheInvalidator authCacheInvalidator APIKeyAuthCacheInvalidator
} }
type userGroupRateBatchReader interface {
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
}
type groupExistenceBatchReader interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAdminService creates a new AdminService // NewAdminService creates a new AdminService
func NewAdminService( func NewAdminService(
userRepo UserRepository, userRepo UserRepository,
...@@ -442,18 +456,43 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi ...@@ -442,18 +456,43 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
} }
// 批量加载用户专属分组倍率 // 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 { if s.userGroupRateRepo != nil && len(users) > 0 {
for i := range users { if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) userIDs := make([]int64, 0, len(users))
for i := range users {
userIDs = append(userIDs, users[i].ID)
}
ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs)
if err != nil { if err != nil {
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err) logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err)
continue s.loadUserGroupRatesOneByOne(ctx, users)
} else {
for i := range users {
if rates, ok := ratesByUser[users[i].ID]; ok {
users[i].GroupRates = rates
}
}
} }
users[i].GroupRates = rates } else {
s.loadUserGroupRatesOneByOne(ctx, users)
} }
} }
return users, result.Total, nil return users, result.Total, nil
} }
func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) {
if s.userGroupRateRepo == nil {
return
}
for i := range users {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
if err != nil {
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
continue
}
users[i].GroupRates = rates
}
}
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id) user, err := s.userRepo.GetByID(ctx, id)
if err != nil { if err != nil {
...@@ -473,14 +512,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) ...@@ -473,14 +512,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
user := &User{ user := &User{
Email: input.Email, Email: input.Email,
Username: input.Username, Username: input.Username,
Notes: input.Notes, Notes: input.Notes,
Role: RoleUser, // Always create as regular user, never admin Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance, Balance: input.Balance,
Concurrency: input.Concurrency, Concurrency: input.Concurrency,
Status: StatusActive, Status: StatusActive,
AllowedGroups: input.AllowedGroups, AllowedGroups: input.AllowedGroups,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
} }
if err := user.SetPassword(input.Password); err != nil { if err := user.SetPassword(input.Password); err != nil {
return nil, err return nil, err
...@@ -534,6 +574,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -534,6 +574,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.AllowedGroups = *input.AllowedGroups user.AllowedGroups = *input.AllowedGroups
} }
if input.SoraStorageQuotaBytes != nil {
user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
if err := s.userRepo.Update(ctx, user); err != nil { if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err return nil, err
} }
...@@ -820,6 +864,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -820,6 +864,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ModelRouting: input.ModelRouting, ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject, MCPXMLInject: mcpXMLInject,
SupportedModelScopes: input.SupportedModelScopes, SupportedModelScopes: input.SupportedModelScopes,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
...@@ -982,6 +1027,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -982,6 +1027,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SoraVideoPricePerRequestHD != nil { if input.SoraVideoPricePerRequestHD != nil {
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD) group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
} }
if input.SoraStorageQuotaBytes != nil {
group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
// Claude Code 客户端限制 // Claude Code 客户端限制
if input.ClaudeCodeOnly != nil { if input.ClaudeCodeOnly != nil {
...@@ -1188,6 +1236,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou ...@@ -1188,6 +1236,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
} }
} }
// Sora apikey 账号的 base_url 必填校验
if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
baseURL, _ := input.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
account := &Account{ account := &Account{
Name: input.Name, Name: input.Name,
Notes: normalizeAccountNotes(input.Notes), Notes: normalizeAccountNotes(input.Notes),
...@@ -1301,12 +1361,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -1301,12 +1361,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.AutoPauseOnExpired = *input.AutoPauseOnExpired account.AutoPauseOnExpired = *input.AutoPauseOnExpired
} }
// Sora apikey 账号的 base_url 必填校验
if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
baseURL, _ := account.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
// 先验证分组是否存在(在任何写操作之前) // 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil { if input.GroupIDs != nil {
for _, groupID := range *input.GroupIDs { if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { return nil, err
return nil, fmt.Errorf("get group: %w", err)
}
} }
// 检查混合渠道风险(除非用户已确认) // 检查混合渠道风险(除非用户已确认)
...@@ -1348,11 +1418,18 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -1348,11 +1418,18 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if len(input.AccountIDs) == 0 { if len(input.AccountIDs) == 0 {
return result, nil return result, nil
} }
if input.GroupIDs != nil {
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
return nil, err
}
}
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。 // 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
platformByID := map[int64]string{} platformByID := map[int64]string{}
groupAccountsByID := map[int64][]Account{}
groupNameByID := map[int64]string{}
if needMixedChannelCheck { if needMixedChannelCheck {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil { if err != nil {
...@@ -1366,6 +1443,13 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -1366,6 +1443,13 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
} }
} }
} }
loadedAccounts, loadedNames, err := s.preloadMixedChannelRiskData(ctx, *input.GroupIDs)
if err != nil {
return nil, err
}
groupAccountsByID = loadedAccounts
groupNameByID = loadedNames
} }
if input.RateMultiplier != nil { if input.RateMultiplier != nil {
...@@ -1409,11 +1493,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -1409,11 +1493,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
// Handle group bindings per account (requires individual operations). // Handle group bindings per account (requires individual operations).
for _, accountID := range input.AccountIDs { for _, accountID := range input.AccountIDs {
entry := BulkUpdateAccountResult{AccountID: accountID} entry := BulkUpdateAccountResult{AccountID: accountID}
platform := ""
if input.GroupIDs != nil { if input.GroupIDs != nil {
// 检查混合渠道风险(除非用户已确认) // 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck { if !input.SkipMixedChannelCheck {
platform := platformByID[accountID] platform = platformByID[accountID]
if platform == "" { if platform == "" {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil { if err != nil {
...@@ -1426,7 +1511,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -1426,7 +1511,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
} }
platform = account.Platform platform = account.Platform
} }
if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil { if err := s.checkMixedChannelRiskWithPreloaded(accountID, platform, *input.GroupIDs, groupAccountsByID, groupNameByID); err != nil {
entry.Success = false entry.Success = false
entry.Error = err.Error() entry.Error = err.Error()
result.Failed++ result.Failed++
...@@ -1444,6 +1529,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -1444,6 +1529,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
result.Results = append(result.Results, entry) result.Results = append(result.Results, entry)
continue continue
} }
if !input.SkipMixedChannelCheck && platform != "" {
updateMixedChannelPreloadedAccounts(groupAccountsByID, *input.GroupIDs, accountID, platform)
}
} }
entry.Success = true entry.Success = true
...@@ -2115,6 +2203,135 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc ...@@ -2115,6 +2203,135 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil return nil
} }
func (s *adminServiceImpl) preloadMixedChannelRiskData(ctx context.Context, groupIDs []int64) (map[int64][]Account, map[int64]string, error) {
accountsByGroup := make(map[int64][]Account)
groupNameByID := make(map[int64]string)
if len(groupIDs) == 0 {
return accountsByGroup, groupNameByID, nil
}
seen := make(map[int64]struct{}, len(groupIDs))
for _, groupID := range groupIDs {
if groupID <= 0 {
continue
}
if _, ok := seen[groupID]; ok {
continue
}
seen[groupID] = struct{}{}
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil {
return nil, nil, fmt.Errorf("get accounts in group %d: %w", groupID, err)
}
accountsByGroup[groupID] = accounts
group, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
continue
}
if group != nil {
groupNameByID[groupID] = group.Name
}
}
return accountsByGroup, groupNameByID, nil
}
func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
if len(groupIDs) == 0 {
return nil
}
if s.groupRepo == nil {
return errors.New("group repository not configured")
}
if batchReader, ok := s.groupRepo.(groupExistenceBatchReader); ok {
existsByID, err := batchReader.ExistsByIDs(ctx, groupIDs)
if err != nil {
return fmt.Errorf("check groups exists: %w", err)
}
for _, groupID := range groupIDs {
if groupID <= 0 || !existsByID[groupID] {
return fmt.Errorf("get group: %w", ErrGroupNotFound)
}
}
return nil
}
for _, groupID := range groupIDs {
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
return fmt.Errorf("get group: %w", err)
}
}
return nil
}
func (s *adminServiceImpl) checkMixedChannelRiskWithPreloaded(currentAccountID int64, currentAccountPlatform string, groupIDs []int64, accountsByGroup map[int64][]Account, groupNameByID map[int64]string) error {
currentPlatform := getAccountPlatform(currentAccountPlatform)
if currentPlatform == "" {
return nil
}
for _, groupID := range groupIDs {
accounts := accountsByGroup[groupID]
for _, account := range accounts {
if currentAccountID > 0 && account.ID == currentAccountID {
continue
}
otherPlatform := getAccountPlatform(account.Platform)
if otherPlatform == "" {
continue
}
if currentPlatform != otherPlatform {
groupName := fmt.Sprintf("Group %d", groupID)
if name := strings.TrimSpace(groupNameByID[groupID]); name != "" {
groupName = name
}
return &MixedChannelError{
GroupID: groupID,
GroupName: groupName,
CurrentPlatform: currentPlatform,
OtherPlatform: otherPlatform,
}
}
}
}
return nil
}
func updateMixedChannelPreloadedAccounts(accountsByGroup map[int64][]Account, groupIDs []int64, accountID int64, platform string) {
if len(groupIDs) == 0 || accountID <= 0 || platform == "" {
return
}
for _, groupID := range groupIDs {
if groupID <= 0 {
continue
}
accounts := accountsByGroup[groupID]
found := false
for i := range accounts {
if accounts[i].ID != accountID {
continue
}
accounts[i].Platform = platform
found = true
break
}
if !found {
accounts = append(accounts, Account{
ID: accountID,
Platform: platform,
})
}
accountsByGroup[groupID] = accounts
}
}
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform. // CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs) return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)
......
...@@ -15,6 +15,7 @@ type accountRepoStubForBulkUpdate struct { ...@@ -15,6 +15,7 @@ type accountRepoStubForBulkUpdate struct {
bulkUpdateErr error bulkUpdateErr error
bulkUpdateIDs []int64 bulkUpdateIDs []int64
bindGroupErrByID map[int64]error bindGroupErrByID map[int64]error
bindGroupsCalls []int64
getByIDsAccounts []*Account getByIDsAccounts []*Account
getByIDsErr error getByIDsErr error
getByIDsCalled bool getByIDsCalled bool
...@@ -22,6 +23,8 @@ type accountRepoStubForBulkUpdate struct { ...@@ -22,6 +23,8 @@ type accountRepoStubForBulkUpdate struct {
getByIDAccounts map[int64]*Account getByIDAccounts map[int64]*Account
getByIDErrByID map[int64]error getByIDErrByID map[int64]error
getByIDCalled []int64 getByIDCalled []int64
listByGroupData map[int64][]Account
listByGroupErr map[int64]error
} }
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
...@@ -33,6 +36,7 @@ func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64 ...@@ -33,6 +36,7 @@ func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64
} }
func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error { func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error {
s.bindGroupsCalls = append(s.bindGroupsCalls, accountID)
if err, ok := s.bindGroupErrByID[accountID]; ok { if err, ok := s.bindGroupErrByID[accountID]; ok {
return err return err
} }
...@@ -59,6 +63,16 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac ...@@ -59,6 +63,16 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
return nil, errors.New("account not found") return nil, errors.New("account not found")
} }
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
if err, ok := s.listByGroupErr[groupID]; ok {
return nil, err
}
if rows, ok := s.listByGroupData[groupID]; ok {
return rows, nil
}
return nil, nil
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 // TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{} repo := &accountRepoStubForBulkUpdate{}
...@@ -86,7 +100,10 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { ...@@ -86,7 +100,10 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
2: errors.New("bind failed"), 2: errors.New("bind failed"),
}, },
} }
svc := &adminServiceImpl{accountRepo: repo} svc := &adminServiceImpl{
accountRepo: repo,
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "g10"}},
}
groupIDs := []int64{10} groupIDs := []int64{10}
schedulable := false schedulable := false
...@@ -105,3 +122,51 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { ...@@ -105,3 +122,51 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
require.ElementsMatch(t, []int64{2}, result.FailedIDs) require.ElementsMatch(t, []int64{2}, result.FailedIDs)
require.Len(t, result.Results, 3) require.Len(t, result.Results, 3)
} }
func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}
svc := &adminServiceImpl{accountRepo: repo}
groupIDs := []int64{10}
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1},
GroupIDs: &groupIDs,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.Nil(t, result)
require.Error(t, err)
require.Contains(t, err.Error(), "group repository not configured")
}
func TestAdminService_BulkUpdateAccounts_MixedChannelCheckUsesUpdatedSnapshot(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{
getByIDsAccounts: []*Account{
{ID: 1, Platform: PlatformAnthropic},
{ID: 2, Platform: PlatformAntigravity},
},
listByGroupData: map[int64][]Account{
10: {},
},
}
svc := &adminServiceImpl{
accountRepo: repo,
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "目标分组"}},
}
groupIDs := []int64{10}
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1, 2},
GroupIDs: &groupIDs,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.NoError(t, err)
require.Equal(t, 1, result.Success)
require.Equal(t, 1, result.Failed)
require.ElementsMatch(t, []int64{1}, result.SuccessIDs)
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
require.Len(t, result.Results, 2)
require.Contains(t, result.Results[1].Error, "mixed channel")
require.Equal(t, []int64{1}, repo.bindGroupsCalls)
}
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type userRepoStubForListUsers struct {
userRepoStub
users []User
err error
}
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
if s.err != nil {
return nil, nil, s.err
}
out := make([]User, len(s.users))
copy(out, s.users)
return out, &pagination.PaginationResult{
Total: int64(len(out)),
Page: params.Page,
PageSize: params.PageSize,
}, nil
}
type userGroupRateRepoStubForListUsers struct {
batchCalls int
singleCall []int64
batchErr error
batchData map[int64]map[int64]float64
singleErr map[int64]error
singleData map[int64]map[int64]float64
}
func (s *userGroupRateRepoStubForListUsers) GetByUserIDs(_ context.Context, _ []int64) (map[int64]map[int64]float64, error) {
s.batchCalls++
if s.batchErr != nil {
return nil, s.batchErr
}
return s.batchData, nil
}
func (s *userGroupRateRepoStubForListUsers) GetByUserID(_ context.Context, userID int64) (map[int64]float64, error) {
s.singleCall = append(s.singleCall, userID)
if err, ok := s.singleErr[userID]; ok {
return nil, err
}
if rates, ok := s.singleData[userID]; ok {
return rates, nil
}
return map[int64]float64{}, nil
}
func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, userID, groupID int64) (*float64, error) {
panic("unexpected GetByUserAndGroup call")
}
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
panic("unexpected SyncUserGroupRates call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
panic("unexpected DeleteByGroupID call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByUserID(_ context.Context, userID int64) error {
panic("unexpected DeleteByUserID call")
}
func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
userRepo := &userRepoStubForListUsers{
users: []User{
{ID: 101, Username: "u1"},
{ID: 202, Username: "u2"},
},
}
rateRepo := &userGroupRateRepoStubForListUsers{
batchErr: errors.New("batch unavailable"),
singleData: map[int64]map[int64]float64{
101: {11: 1.1},
202: {22: 2.2},
},
}
svc := &adminServiceImpl{
userRepo: userRepo,
userGroupRateRepo: rateRepo,
}
users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{})
require.NoError(t, err)
require.Equal(t, int64(2), total)
require.Len(t, users, 2)
require.Equal(t, 1, rateRepo.batchCalls)
require.ElementsMatch(t, []int64{101, 202}, rateRepo.singleCall)
require.Equal(t, 1.1, users[0].GroupRates[11])
require.Equal(t, 2.2, users[1].GroupRates[22])
}
...@@ -21,7 +21,6 @@ import ( ...@@ -21,7 +21,6 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
...@@ -2291,7 +2290,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { ...@@ -2291,7 +2290,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
// isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记 // isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记
func isSingleAccountRetry(ctx context.Context) bool { func isSingleAccountRetry(ctx context.Context) bool {
v, _ := ctx.Value(ctxkey.SingleAccountRetry).(bool) v, _ := SingleAccountRetryFromContext(ctx)
return v return v
} }
......
package service package service
import "time" import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
)
// API Key status constants // API Key status constants
const ( const (
...@@ -19,11 +23,14 @@ type APIKey struct { ...@@ -19,11 +23,14 @@ type APIKey struct {
Status string Status string
IPWhitelist []string IPWhitelist []string
IPBlacklist []string IPBlacklist []string
LastUsedAt *time.Time // 预编译的 IP 规则,用于认证热路径避免重复 ParseIP/ParseCIDR。
CreatedAt time.Time CompiledIPWhitelist *ip.CompiledIPRules `json:"-"`
UpdatedAt time.Time CompiledIPBlacklist *ip.CompiledIPRules `json:"-"`
User *User LastUsedAt *time.Time
Group *Group CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
// Quota fields // Quota fields
Quota float64 // Quota limit in USD (0 = unlimited) Quota float64 // Quota limit in USD (0 = unlimited)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment