Unverified Commit 9d795061 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #682 from mt21625457/pr/all-code-sync-20260228

feat(openai-ws): support websocket mode v2, optimize relay performance, enhance sora
parents bfc7b339 1d1fc019
......@@ -97,7 +97,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
clientIP := ip.GetTrustedClientIP(c)
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist)
if !allowed {
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
return
......
......@@ -80,17 +80,25 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
abortWithGoogleError(c, 403, "No active subscription found for this group")
return
}
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
abortWithGoogleError(c, 403, err.Error())
return
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
if err != nil {
status := 403
if errors.Is(err, service.ErrDailyLimitExceeded) ||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
status = 429
}
_ = subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription)
_ = subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription)
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
abortWithGoogleError(c, 429, err.Error())
abortWithGoogleError(c, status, err.Error())
return
}
c.Set(string(ContextKeySubscription), subscription)
if needsMaintenance {
maintenanceCopy := *subscription
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
}
} else {
if apiKey.User.Balance <= 0 {
abortWithGoogleError(c, 403, "Insufficient account balance")
......
......@@ -23,6 +23,15 @@ type fakeAPIKeyRepo struct {
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
}
type fakeGoogleSubscriptionRepo struct {
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
activateWindow func(ctx context.Context, id int64, start time.Time) error
resetDaily func(ctx context.Context, id int64, start time.Time) error
resetWeekly func(ctx context.Context, id int64, start time.Time) error
resetMonthly func(ctx context.Context, id int64, start time.Time) error
}
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
......@@ -87,6 +96,85 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim
return nil
}
func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
if f.getActive != nil {
return f.getActive(ctx, userID, groupID)
}
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
return false, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
if f.updateStatus != nil {
return f.updateStatus(ctx, subscriptionID, status)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
if f.activateWindow != nil {
return f.activateWindow(ctx, id, start)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, start time.Time) error {
if f.resetDaily != nil {
return f.resetDaily(ctx, id, start)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, start time.Time) error {
if f.resetWeekly != nil {
return f.resetWeekly(ctx, id, start)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, start time.Time) error {
if f.resetMonthly != nil {
return f.resetMonthly(ctx, id, start)
}
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
return 0, errors.New("not implemented")
}
type googleErrorResponse struct {
Error struct {
Code int `json:"code"`
......@@ -505,3 +593,85 @@ func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testi
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, touchCalls)
}
func TestApiKeyAuthWithSubscriptionGoogle_SubscriptionLimitExceededReturns429(t *testing.T) {
gin.SetMode(gin.TestMode)
limit := 1.0
group := &service.Group{
ID: 77,
Name: "gemini-sub",
Status: service.StatusActive,
Platform: service.PlatformGemini,
Hydrated: true,
SubscriptionType: service.SubscriptionTypeSubscription,
DailyLimitUSD: &limit,
}
user := &service.User{
ID: 999,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 501,
UserID: user.ID,
Key: "google-sub-limit",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
})
now := time.Now()
sub := &service.UserSubscription{
ID: 601,
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: now.Add(24 * time.Hour),
DailyWindowStart: &now,
DailyUsageUSD: 10,
}
subscriptionService := service.NewSubscriptionService(nil, fakeGoogleSubscriptionRepo{
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
if userID != user.ID || groupID != group.ID {
return nil, service.ErrSubscriptionNotFound
}
clone := *sub
return &clone, nil
},
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
}, nil, nil, &config.Config{RunMode: config.RunModeStandard})
r := gin.New()
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, &config.Config{RunMode: config.RunModeStandard}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("x-goog-api-key", apiKey.Key)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusTooManyRequests, resp.Error.Code)
require.Equal(t, "RESOURCE_EXHAUSTED", resp.Error.Status)
require.Contains(t, resp.Error.Message, "daily usage limit exceeded")
}
......@@ -54,6 +54,10 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
if isAPIRoutePath(c) {
c.Next()
return
}
if cfg.Enabled {
// Generate nonce for this request
......@@ -73,6 +77,18 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
}
}
func isAPIRoutePath(c *gin.Context) bool {
if c == nil || c.Request == nil || c.Request.URL == nil {
return false
}
path := c.Request.URL.Path
return strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/responses")
}
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
// This allows the application to work correctly even if the config file has an older CSP policy.
func enhanceCSPPolicy(policy string) string {
......
......@@ -131,6 +131,26 @@ func TestSecurityHeaders(t *testing.T) {
assert.Contains(t, csp, CloudflareInsightsDomain)
})
t.Run("api_route_skips_csp_nonce_generation", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
middleware(c)
assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
assert.Empty(t, w.Header().Get("Content-Security-Policy"))
assert.Empty(t, GetNonceFromContext(c))
})
t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
......
......@@ -75,6 +75,7 @@ func registerRoutes(
// 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
}
......@@ -55,6 +55,9 @@ func RegisterAdminRoutes(
// 系统设置
registerSettingsRoutes(admin, h)
// 数据管理
registerDataManagementRoutes(admin, h)
// 运维监控(Ops)
registerOpsRoutes(admin, h)
......@@ -231,6 +234,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
......@@ -370,6 +374,38 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
// Sora S3 存储配置
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection)
adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles)
adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile)
adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile)
adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile)
adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile)
}
}
func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dataManagement := admin.Group("/data-management")
{
dataManagement.GET("/agent/health", h.Admin.DataManagement.GetAgentHealth)
dataManagement.GET("/config", h.Admin.DataManagement.GetConfig)
dataManagement.PUT("/config", h.Admin.DataManagement.UpdateConfig)
dataManagement.GET("/sources/:source_type/profiles", h.Admin.DataManagement.ListSourceProfiles)
dataManagement.POST("/sources/:source_type/profiles", h.Admin.DataManagement.CreateSourceProfile)
dataManagement.PUT("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.UpdateSourceProfile)
dataManagement.DELETE("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.DeleteSourceProfile)
dataManagement.POST("/sources/:source_type/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveSourceProfile)
dataManagement.POST("/s3/test", h.Admin.DataManagement.TestS3)
dataManagement.GET("/s3/profiles", h.Admin.DataManagement.ListS3Profiles)
dataManagement.POST("/s3/profiles", h.Admin.DataManagement.CreateS3Profile)
dataManagement.PUT("/s3/profiles/:profile_id", h.Admin.DataManagement.UpdateS3Profile)
dataManagement.DELETE("/s3/profiles/:profile_id", h.Admin.DataManagement.DeleteS3Profile)
dataManagement.POST("/s3/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveS3Profile)
dataManagement.POST("/backups", h.Admin.DataManagement.CreateBackupJob)
dataManagement.GET("/backups", h.Admin.DataManagement.ListBackupJobs)
dataManagement.GET("/backups/:job_id", h.Admin.DataManagement.GetBackupJob)
}
}
......
......@@ -43,6 +43,7 @@ func RegisterGatewayRoutes(
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
gateway.POST("/chat/completions", func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{
......@@ -69,6 +70,7 @@ func RegisterGatewayRoutes(
// OpenAI Responses API(不带v1前缀的别名)
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ResponsesWebSocket)
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
......
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。
func RegisterSoraClientRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
) {
if h.SoraClient == nil {
return
}
authenticated := v1.Group("/sora")
authenticated.Use(gin.HandlerFunc(jwtAuth))
{
authenticated.POST("/generate", h.SoraClient.Generate)
authenticated.GET("/generations", h.SoraClient.ListGenerations)
authenticated.GET("/generations/:id", h.SoraClient.GetGeneration)
authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration)
authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration)
authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage)
authenticated.GET("/quota", h.SoraClient.GetQuota)
authenticated.GET("/models", h.SoraClient.GetModels)
authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus)
}
}
......@@ -3,6 +3,8 @@ package service
import (
"encoding/json"
"hash/fnv"
"reflect"
"sort"
"strconv"
"strings"
......@@ -50,6 +52,14 @@ type Account struct {
AccountGroups []AccountGroup
GroupIDs []int64
Groups []*Group
// model_mapping 热路径缓存(非持久化字段)
modelMappingCache map[string]string
modelMappingCacheReady bool
modelMappingCacheCredentialsPtr uintptr
modelMappingCacheRawPtr uintptr
modelMappingCacheRawLen int
modelMappingCacheRawSig uint64
}
type TempUnschedulableRule struct {
......@@ -349,6 +359,39 @@ func parseTempUnschedInt(value any) int {
}
func (a *Account) GetModelMapping() map[string]string {
credentialsPtr := mapPtr(a.Credentials)
rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
rawPtr := mapPtr(rawMapping)
rawLen := len(rawMapping)
rawSig := uint64(0)
rawSigReady := false
if a.modelMappingCacheReady &&
a.modelMappingCacheCredentialsPtr == credentialsPtr &&
a.modelMappingCacheRawPtr == rawPtr &&
a.modelMappingCacheRawLen == rawLen {
rawSig = modelMappingSignature(rawMapping)
rawSigReady = true
if a.modelMappingCacheRawSig == rawSig {
return a.modelMappingCache
}
}
mapping := a.resolveModelMapping(rawMapping)
if !rawSigReady {
rawSig = modelMappingSignature(rawMapping)
}
a.modelMappingCache = mapping
a.modelMappingCacheReady = true
a.modelMappingCacheCredentialsPtr = credentialsPtr
a.modelMappingCacheRawPtr = rawPtr
a.modelMappingCacheRawLen = rawLen
a.modelMappingCacheRawSig = rawSig
return mapping
}
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
if a.Credentials == nil {
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
......@@ -356,17 +399,16 @@ func (a *Account) GetModelMapping() map[string]string {
}
return nil
}
raw, ok := a.Credentials["model_mapping"]
if !ok || raw == nil {
if len(rawMapping) == 0 {
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
}
return nil
}
if m, ok := raw.(map[string]any); ok {
result := make(map[string]string)
for k, v := range m {
for k, v := range rawMapping {
if s, ok := v.(string); ok {
result[k] = s
}
......@@ -381,7 +423,7 @@ func (a *Account) GetModelMapping() map[string]string {
}
return result
}
}
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
......@@ -389,6 +431,37 @@ func (a *Account) GetModelMapping() map[string]string {
return nil
}
func mapPtr(m map[string]any) uintptr {
if m == nil {
return 0
}
return reflect.ValueOf(m).Pointer()
}
func modelMappingSignature(rawMapping map[string]any) uint64 {
if len(rawMapping) == 0 {
return 0
}
keys := make([]string, 0, len(rawMapping))
for k := range rawMapping {
keys = append(keys, k)
}
sort.Strings(keys)
h := fnv.New64a()
for _, k := range keys {
_, _ = h.Write([]byte(k))
_, _ = h.Write([]byte{0})
if v, ok := rawMapping[k].(string); ok {
_, _ = h.Write([]byte(v))
} else {
_, _ = h.Write([]byte{1})
}
_, _ = h.Write([]byte{0xff})
}
return h.Sum64()
}
func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) {
if mapping == nil || model == "" {
return
......@@ -742,6 +815,159 @@ func (a *Account) IsOpenAIPassthroughEnabled() bool {
return false
}
// IsOpenAIResponsesWebSocketV2Enabled 返回 OpenAI 账号是否开启 Responses WebSocket v2。
//
// 分类型新字段:
// - OAuth 账号:accounts.extra.openai_oauth_responses_websockets_v2_enabled
// - API Key 账号:accounts.extra.openai_apikey_responses_websockets_v2_enabled
//
// 兼容字段:
// - accounts.extra.responses_websockets_v2_enabled
// - accounts.extra.openai_ws_enabled(历史开关)
//
// 优先级:
// 1. 按账号类型读取分类型字段
// 2. 分类型字段缺失时,回退兼容字段
func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
return false
}
if a.IsOpenAIOAuth() {
if enabled, ok := a.Extra["openai_oauth_responses_websockets_v2_enabled"].(bool); ok {
return enabled
}
}
if a.IsOpenAIApiKey() {
if enabled, ok := a.Extra["openai_apikey_responses_websockets_v2_enabled"].(bool); ok {
return enabled
}
}
if enabled, ok := a.Extra["responses_websockets_v2_enabled"].(bool); ok {
return enabled
}
if enabled, ok := a.Extra["openai_ws_enabled"].(bool); ok {
return enabled
}
return false
}
const (
OpenAIWSIngressModeOff = "off"
OpenAIWSIngressModeShared = "shared"
OpenAIWSIngressModeDedicated = "dedicated"
)
func normalizeOpenAIWSIngressMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case OpenAIWSIngressModeOff:
return OpenAIWSIngressModeOff
case OpenAIWSIngressModeShared:
return OpenAIWSIngressModeShared
case OpenAIWSIngressModeDedicated:
return OpenAIWSIngressModeDedicated
default:
return ""
}
}
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
return normalized
}
return OpenAIWSIngressModeShared
}
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。
//
// 优先级:
// 1. 分类型 mode 新字段(string)
// 2. 分类型 enabled 旧字段(bool)
// 3. 兼容 enabled 旧字段(bool)
// 4. defaultMode(非法时回退 shared)
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
if a == nil || !a.IsOpenAI() {
return OpenAIWSIngressModeOff
}
if a.Extra == nil {
return resolvedDefault
}
resolveModeString := func(key string) (string, bool) {
raw, ok := a.Extra[key]
if !ok {
return "", false
}
mode, ok := raw.(string)
if !ok {
return "", false
}
normalized := normalizeOpenAIWSIngressMode(mode)
if normalized == "" {
return "", false
}
return normalized, true
}
resolveBoolMode := func(key string) (string, bool) {
raw, ok := a.Extra[key]
if !ok {
return "", false
}
enabled, ok := raw.(bool)
if !ok {
return "", false
}
if enabled {
return OpenAIWSIngressModeShared, true
}
return OpenAIWSIngressModeOff, true
}
if a.IsOpenAIOAuth() {
if mode, ok := resolveModeString("openai_oauth_responses_websockets_v2_mode"); ok {
return mode
}
if mode, ok := resolveBoolMode("openai_oauth_responses_websockets_v2_enabled"); ok {
return mode
}
}
if a.IsOpenAIApiKey() {
if mode, ok := resolveModeString("openai_apikey_responses_websockets_v2_mode"); ok {
return mode
}
if mode, ok := resolveBoolMode("openai_apikey_responses_websockets_v2_enabled"); ok {
return mode
}
}
if mode, ok := resolveBoolMode("responses_websockets_v2_enabled"); ok {
return mode
}
if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
return mode
}
return resolvedDefault
}
// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。
// 字段:accounts.extra.openai_ws_force_http。
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
return false
}
enabled, ok := a.Extra["openai_ws_force_http"].(bool)
return ok && enabled
}
// IsOpenAIWSAllowStoreRecoveryEnabled 返回账号级 store 恢复开关。
// 字段:accounts.extra.openai_ws_allow_store_recovery。
func (a *Account) IsOpenAIWSAllowStoreRecoveryEnabled() bool {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
return false
}
enabled, ok := a.Extra["openai_ws_allow_store_recovery"].(bool)
return ok && enabled
}
// IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。
func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
......
......@@ -134,3 +134,161 @@ func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) {
require.False(t, otherPlatform.IsCodexCLIOnlyEnabled())
})
}
func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
t.Run("OAuth使用OAuth专用开关", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_enabled": true,
},
}
require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("API Key使用API Key专用开关", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("分类型新键优先于兼容键", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_enabled": false,
"responses_websockets_v2_enabled": true,
"openai_ws_enabled": true,
},
}
require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("分类型键缺失时回退兼容键", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("非OpenAI账号默认关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
}
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
t.Run("default fallback to shared", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{},
}
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
})
t.Run("oauth mode field has highest priority", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
"openai_oauth_responses_websockets_v2_enabled": false,
"responses_websockets_v2_enabled": false,
},
}
require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
})
t.Run("legacy enabled maps to shared", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
})
t.Run("legacy disabled maps to off", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": false,
"responses_websockets_v2_enabled": true,
},
}
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
})
t.Run("non openai always off", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
},
}
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated))
})
}
func TestAccount_OpenAIWSExtraFlags(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_ws_force_http": true,
"openai_ws_allow_store_recovery": true,
},
}
require.True(t, account.IsOpenAIWSForceHTTPEnabled())
require.True(t, account.IsOpenAIWSAllowStoreRecoveryEnabled())
off := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}}
require.False(t, off.IsOpenAIWSForceHTTPEnabled())
require.False(t, off.IsOpenAIWSAllowStoreRecoveryEnabled())
var nilAccount *Account
require.False(t, nilAccount.IsOpenAIWSAllowStoreRecoveryEnabled())
nonOpenAI := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_ws_allow_store_recovery": true,
},
}
require.False(t, nonOpenAI.IsOpenAIWSAllowStoreRecoveryEnabled())
}
......@@ -119,6 +119,10 @@ type AccountService struct {
groupRepo GroupRepository
}
type groupExistenceBatchChecker interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAccountService 创建账号服务实例
func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService {
return &AccountService{
......@@ -131,11 +135,8 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository)
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
// 验证分组是否存在(如果指定了分组)
if len(req.GroupIDs) > 0 {
for _, groupID := range req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
if err := s.validateGroupIDsExist(ctx, req.GroupIDs); err != nil {
return nil, err
}
}
......@@ -256,11 +257,8 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
// 先验证分组是否存在(在任何写操作之前)
if req.GroupIDs != nil {
for _, groupID := range *req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
if err := s.validateGroupIDsExist(ctx, *req.GroupIDs); err != nil {
return nil, err
}
}
......@@ -300,6 +298,39 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
return nil
}
func (s *AccountService) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
if len(groupIDs) == 0 {
return nil
}
if s.groupRepo == nil {
return fmt.Errorf("group repository not configured")
}
if batchChecker, ok := s.groupRepo.(groupExistenceBatchChecker); ok {
existsByID, err := batchChecker.ExistsByIDs(ctx, groupIDs)
if err != nil {
return fmt.Errorf("check groups exists: %w", err)
}
for _, groupID := range groupIDs {
if groupID <= 0 {
return fmt.Errorf("get group: %w", ErrGroupNotFound)
}
if !existsByID[groupID] {
return fmt.Errorf("get group: %w", ErrGroupNotFound)
}
}
return nil
}
for _, groupID := range groupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
return fmt.Errorf("get group: %w", err)
}
}
return nil
}
// UpdateStatus 更新账号状态
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
account, err := s.accountRepo.GetByID(ctx, id)
......
......@@ -598,9 +598,102 @@ func ceilSeconds(d time.Duration) int {
return sec
}
// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。
// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。
func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error {
ctx := c.Request.Context()
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证")
}
baseURL := account.GetBaseURL()
if baseURL == "" {
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url")
}
// 验证 base_url 格式
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error()))
}
upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions"
// 设置 SSE 头
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
return s.sendErrorAndEnd(c, msg)
}
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"})
// 构建轻量级 prompt-enhance 请求作为连通性测试
testPayload := map[string]any{
"model": "prompt-enhance-short-10s",
"messages": []map[string]string{{"role": "user", "content": "test"}},
"stream": false,
}
payloadBytes, _ := json.Marshal(testPayload)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "构建测试请求失败")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
// 获取代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
if resp.StatusCode == http.StatusOK {
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode))
}
// 其他错误但能连通(如 400 参数错误)也算连通性测试通过
if resp.StatusCode == http.StatusBadRequest {
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256)))
}
// testSoraAccountConnection 测试 Sora 账号的连接
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token)
// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
// apikey 类型走独立测试流程
if account.Type == AccountTypeAPIKey {
return s.testSoraAPIKeyAccountConnection(c, account)
}
ctx := c.Request.Context()
recorder := &soraProbeRecorder{}
......
......@@ -9,7 +9,9 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"golang.org/x/sync/errgroup"
)
type UsageLogRepository interface {
......@@ -33,8 +35,8 @@ type UsageLogRepository interface {
// Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
......@@ -62,6 +64,10 @@ type UsageLogRepository interface {
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
}
type accountWindowStatsBatchReader interface {
GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error)
}
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
type apiUsageCache struct {
response *ClaudeUsageResponse
......@@ -297,7 +303,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
}
dayStart := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
......@@ -319,7 +325,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
}
......@@ -440,6 +446,78 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
}, nil
}
// GetTodayStatsBatch 批量获取账号今日统计,优先走批量 SQL,失败时回退单账号查询。
func (s *AccountUsageService) GetTodayStatsBatch(ctx context.Context, accountIDs []int64) (map[int64]*WindowStats, error) {
uniqueIDs := make([]int64, 0, len(accountIDs))
seen := make(map[int64]struct{}, len(accountIDs))
for _, accountID := range accountIDs {
if accountID <= 0 {
continue
}
if _, exists := seen[accountID]; exists {
continue
}
seen[accountID] = struct{}{}
uniqueIDs = append(uniqueIDs, accountID)
}
result := make(map[int64]*WindowStats, len(uniqueIDs))
if len(uniqueIDs) == 0 {
return result, nil
}
startTime := timezone.Today()
if batchReader, ok := s.usageLogRepo.(accountWindowStatsBatchReader); ok {
statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, uniqueIDs, startTime)
if err == nil {
for _, accountID := range uniqueIDs {
result[accountID] = windowStatsFromAccountStats(statsByAccount[accountID])
}
return result, nil
}
}
var mu sync.Mutex
g, gctx := errgroup.WithContext(ctx)
g.SetLimit(8)
for _, accountID := range uniqueIDs {
id := accountID
g.Go(func() error {
stats, err := s.usageLogRepo.GetAccountWindowStats(gctx, id, startTime)
if err != nil {
return nil
}
mu.Lock()
result[id] = windowStatsFromAccountStats(stats)
mu.Unlock()
return nil
})
}
_ = g.Wait()
for _, accountID := range uniqueIDs {
if _, ok := result[accountID]; !ok {
result[accountID] = &WindowStats{}
}
}
return result, nil
}
func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
if stats == nil {
return &WindowStats{}
}
return &WindowStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
StandardCost: stats.StandardCost,
UserCost: stats.UserCost,
}
}
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
if err != nil {
......
......@@ -314,3 +314,72 @@ func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T
t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped)
}
}
func TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-3-5-sonnet": "upstream-a",
},
},
}
first := account.GetModelMapping()
if first["claude-3-5-sonnet"] != "upstream-a" {
t.Fatalf("unexpected first mapping: %v", first)
}
account.Credentials = map[string]any{
"model_mapping": map[string]any{
"claude-3-5-sonnet": "upstream-b",
},
}
second := account.GetModelMapping()
if second["claude-3-5-sonnet"] != "upstream-b" {
t.Fatalf("expected cache invalidated after credentials replace, got: %v", second)
}
}
func TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange(t *testing.T) {
rawMapping := map[string]any{
"claude-sonnet": "sonnet-a",
}
account := &Account{
Credentials: map[string]any{
"model_mapping": rawMapping,
},
}
first := account.GetModelMapping()
if len(first) != 1 {
t.Fatalf("unexpected first mapping length: %d", len(first))
}
rawMapping["claude-opus"] = "opus-b"
second := account.GetModelMapping()
if second["claude-opus"] != "opus-b" {
t.Fatalf("expected cache invalidated after mapping len change, got: %v", second)
}
}
func TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange(t *testing.T) {
rawMapping := map[string]any{
"claude-sonnet": "sonnet-a",
}
account := &Account{
Credentials: map[string]any{
"model_mapping": rawMapping,
},
}
first := account.GetModelMapping()
if first["claude-sonnet"] != "sonnet-a" {
t.Fatalf("unexpected first mapping: %v", first)
}
rawMapping["claude-sonnet"] = "sonnet-b"
second := account.GetModelMapping()
if second["claude-sonnet"] != "sonnet-b" {
t.Fatalf("expected cache invalidated after in-place value change, got: %v", second)
}
}
......@@ -90,6 +90,7 @@ type CreateUserInput struct {
Balance float64
Concurrency int
AllowedGroups []int64
SoraStorageQuotaBytes int64
}
type UpdateUserInput struct {
......@@ -104,6 +105,7 @@ type UpdateUserInput struct {
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
GroupRates map[int64]*float64
SoraStorageQuotaBytes *int64
}
type CreateGroupInput struct {
......@@ -135,6 +137,8 @@ type CreateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
// Sora 存储配额
SoraStorageQuotaBytes int64
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
......@@ -169,6 +173,8 @@ type UpdateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
// Sora 存储配额
SoraStorageQuotaBytes *int64
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
......@@ -402,6 +408,14 @@ type adminServiceImpl struct {
authCacheInvalidator APIKeyAuthCacheInvalidator
}
type userGroupRateBatchReader interface {
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
}
type groupExistenceBatchReader interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAdminService creates a new AdminService
func NewAdminService(
userRepo UserRepository,
......@@ -442,6 +456,33 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
}
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
userIDs := make([]int64, 0, len(users))
for i := range users {
userIDs = append(userIDs, users[i].ID)
}
ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs)
if err != nil {
logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err)
s.loadUserGroupRatesOneByOne(ctx, users)
} else {
for i := range users {
if rates, ok := ratesByUser[users[i].ID]; ok {
users[i].GroupRates = rates
}
}
}
} else {
s.loadUserGroupRatesOneByOne(ctx, users)
}
}
return users, result.Total, nil
}
func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) {
if s.userGroupRateRepo == nil {
return
}
for i := range users {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
if err != nil {
......@@ -450,8 +491,6 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
}
users[i].GroupRates = rates
}
}
return users, result.Total, nil
}
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
......@@ -481,6 +520,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
Concurrency: input.Concurrency,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
}
if err := user.SetPassword(input.Password); err != nil {
return nil, err
......@@ -534,6 +574,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.AllowedGroups = *input.AllowedGroups
}
if input.SoraStorageQuotaBytes != nil {
user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
......@@ -820,6 +864,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject,
SupportedModelScopes: input.SupportedModelScopes,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
......@@ -982,6 +1027,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SoraVideoPricePerRequestHD != nil {
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
}
if input.SoraStorageQuotaBytes != nil {
group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
// Claude Code 客户端限制
if input.ClaudeCodeOnly != nil {
......@@ -1188,6 +1236,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
}
// Sora apikey 账号的 base_url 必填校验
if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
baseURL, _ := input.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
account := &Account{
Name: input.Name,
Notes: normalizeAccountNotes(input.Notes),
......@@ -1301,12 +1361,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
}
// Sora apikey 账号的 base_url 必填校验
if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
baseURL, _ := account.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
// 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil {
for _, groupID := range *input.GroupIDs {
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
return nil, err
}
// 检查混合渠道风险(除非用户已确认)
......@@ -1348,11 +1418,18 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if len(input.AccountIDs) == 0 {
return result, nil
}
if input.GroupIDs != nil {
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
return nil, err
}
}
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
platformByID := map[int64]string{}
groupAccountsByID := map[int64][]Account{}
groupNameByID := map[int64]string{}
if needMixedChannelCheck {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
......@@ -1366,6 +1443,13 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
}
}
loadedAccounts, loadedNames, err := s.preloadMixedChannelRiskData(ctx, *input.GroupIDs)
if err != nil {
return nil, err
}
groupAccountsByID = loadedAccounts
groupNameByID = loadedNames
}
if input.RateMultiplier != nil {
......@@ -1409,11 +1493,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
// Handle group bindings per account (requires individual operations).
for _, accountID := range input.AccountIDs {
entry := BulkUpdateAccountResult{AccountID: accountID}
platform := ""
if input.GroupIDs != nil {
// 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck {
platform := platformByID[accountID]
platform = platformByID[accountID]
if platform == "" {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
......@@ -1426,7 +1511,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
platform = account.Platform
}
if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
if err := s.checkMixedChannelRiskWithPreloaded(accountID, platform, *input.GroupIDs, groupAccountsByID, groupNameByID); err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
......@@ -1444,6 +1529,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
result.Results = append(result.Results, entry)
continue
}
if !input.SkipMixedChannelCheck && platform != "" {
updateMixedChannelPreloadedAccounts(groupAccountsByID, *input.GroupIDs, accountID, platform)
}
}
entry.Success = true
......@@ -2115,6 +2203,135 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil
}
func (s *adminServiceImpl) preloadMixedChannelRiskData(ctx context.Context, groupIDs []int64) (map[int64][]Account, map[int64]string, error) {
accountsByGroup := make(map[int64][]Account)
groupNameByID := make(map[int64]string)
if len(groupIDs) == 0 {
return accountsByGroup, groupNameByID, nil
}
seen := make(map[int64]struct{}, len(groupIDs))
for _, groupID := range groupIDs {
if groupID <= 0 {
continue
}
if _, ok := seen[groupID]; ok {
continue
}
seen[groupID] = struct{}{}
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil {
return nil, nil, fmt.Errorf("get accounts in group %d: %w", groupID, err)
}
accountsByGroup[groupID] = accounts
group, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
continue
}
if group != nil {
groupNameByID[groupID] = group.Name
}
}
return accountsByGroup, groupNameByID, nil
}
func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
if len(groupIDs) == 0 {
return nil
}
if s.groupRepo == nil {
return errors.New("group repository not configured")
}
if batchReader, ok := s.groupRepo.(groupExistenceBatchReader); ok {
existsByID, err := batchReader.ExistsByIDs(ctx, groupIDs)
if err != nil {
return fmt.Errorf("check groups exists: %w", err)
}
for _, groupID := range groupIDs {
if groupID <= 0 || !existsByID[groupID] {
return fmt.Errorf("get group: %w", ErrGroupNotFound)
}
}
return nil
}
for _, groupID := range groupIDs {
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
return fmt.Errorf("get group: %w", err)
}
}
return nil
}
func (s *adminServiceImpl) checkMixedChannelRiskWithPreloaded(currentAccountID int64, currentAccountPlatform string, groupIDs []int64, accountsByGroup map[int64][]Account, groupNameByID map[int64]string) error {
currentPlatform := getAccountPlatform(currentAccountPlatform)
if currentPlatform == "" {
return nil
}
for _, groupID := range groupIDs {
accounts := accountsByGroup[groupID]
for _, account := range accounts {
if currentAccountID > 0 && account.ID == currentAccountID {
continue
}
otherPlatform := getAccountPlatform(account.Platform)
if otherPlatform == "" {
continue
}
if currentPlatform != otherPlatform {
groupName := fmt.Sprintf("Group %d", groupID)
if name := strings.TrimSpace(groupNameByID[groupID]); name != "" {
groupName = name
}
return &MixedChannelError{
GroupID: groupID,
GroupName: groupName,
CurrentPlatform: currentPlatform,
OtherPlatform: otherPlatform,
}
}
}
}
return nil
}
func updateMixedChannelPreloadedAccounts(accountsByGroup map[int64][]Account, groupIDs []int64, accountID int64, platform string) {
if len(groupIDs) == 0 || accountID <= 0 || platform == "" {
return
}
for _, groupID := range groupIDs {
if groupID <= 0 {
continue
}
accounts := accountsByGroup[groupID]
found := false
for i := range accounts {
if accounts[i].ID != accountID {
continue
}
accounts[i].Platform = platform
found = true
break
}
if !found {
accounts = append(accounts, Account{
ID: accountID,
Platform: platform,
})
}
accountsByGroup[groupID] = accounts
}
}
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)
......
......@@ -15,6 +15,7 @@ type accountRepoStubForBulkUpdate struct {
bulkUpdateErr error
bulkUpdateIDs []int64
bindGroupErrByID map[int64]error
bindGroupsCalls []int64
getByIDsAccounts []*Account
getByIDsErr error
getByIDsCalled bool
......@@ -22,6 +23,8 @@ type accountRepoStubForBulkUpdate struct {
getByIDAccounts map[int64]*Account
getByIDErrByID map[int64]error
getByIDCalled []int64
listByGroupData map[int64][]Account
listByGroupErr map[int64]error
}
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
......@@ -33,6 +36,7 @@ func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64
}
func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error {
s.bindGroupsCalls = append(s.bindGroupsCalls, accountID)
if err, ok := s.bindGroupErrByID[accountID]; ok {
return err
}
......@@ -59,6 +63,16 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
return nil, errors.New("account not found")
}
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
if err, ok := s.listByGroupErr[groupID]; ok {
return nil, err
}
if rows, ok := s.listByGroupData[groupID]; ok {
return rows, nil
}
return nil, nil
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}
......@@ -86,7 +100,10 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
2: errors.New("bind failed"),
},
}
svc := &adminServiceImpl{accountRepo: repo}
svc := &adminServiceImpl{
accountRepo: repo,
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "g10"}},
}
groupIDs := []int64{10}
schedulable := false
......@@ -105,3 +122,51 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
require.Len(t, result.Results, 3)
}
func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}
svc := &adminServiceImpl{accountRepo: repo}
groupIDs := []int64{10}
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1},
GroupIDs: &groupIDs,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.Nil(t, result)
require.Error(t, err)
require.Contains(t, err.Error(), "group repository not configured")
}
func TestAdminService_BulkUpdateAccounts_MixedChannelCheckUsesUpdatedSnapshot(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{
getByIDsAccounts: []*Account{
{ID: 1, Platform: PlatformAnthropic},
{ID: 2, Platform: PlatformAntigravity},
},
listByGroupData: map[int64][]Account{
10: {},
},
}
svc := &adminServiceImpl{
accountRepo: repo,
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "目标分组"}},
}
groupIDs := []int64{10}
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1, 2},
GroupIDs: &groupIDs,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.NoError(t, err)
require.Equal(t, 1, result.Success)
require.Equal(t, 1, result.Failed)
require.ElementsMatch(t, []int64{1}, result.SuccessIDs)
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
require.Len(t, result.Results, 2)
require.Contains(t, result.Results[1].Error, "mixed channel")
require.Equal(t, []int64{1}, repo.bindGroupsCalls)
}
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type userRepoStubForListUsers struct {
userRepoStub
users []User
err error
}
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
if s.err != nil {
return nil, nil, s.err
}
out := make([]User, len(s.users))
copy(out, s.users)
return out, &pagination.PaginationResult{
Total: int64(len(out)),
Page: params.Page,
PageSize: params.PageSize,
}, nil
}
type userGroupRateRepoStubForListUsers struct {
batchCalls int
singleCall []int64
batchErr error
batchData map[int64]map[int64]float64
singleErr map[int64]error
singleData map[int64]map[int64]float64
}
func (s *userGroupRateRepoStubForListUsers) GetByUserIDs(_ context.Context, _ []int64) (map[int64]map[int64]float64, error) {
s.batchCalls++
if s.batchErr != nil {
return nil, s.batchErr
}
return s.batchData, nil
}
func (s *userGroupRateRepoStubForListUsers) GetByUserID(_ context.Context, userID int64) (map[int64]float64, error) {
s.singleCall = append(s.singleCall, userID)
if err, ok := s.singleErr[userID]; ok {
return nil, err
}
if rates, ok := s.singleData[userID]; ok {
return rates, nil
}
return map[int64]float64{}, nil
}
func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, userID, groupID int64) (*float64, error) {
panic("unexpected GetByUserAndGroup call")
}
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
panic("unexpected SyncUserGroupRates call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
panic("unexpected DeleteByGroupID call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByUserID(_ context.Context, userID int64) error {
panic("unexpected DeleteByUserID call")
}
func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
userRepo := &userRepoStubForListUsers{
users: []User{
{ID: 101, Username: "u1"},
{ID: 202, Username: "u2"},
},
}
rateRepo := &userGroupRateRepoStubForListUsers{
batchErr: errors.New("batch unavailable"),
singleData: map[int64]map[int64]float64{
101: {11: 1.1},
202: {22: 2.2},
},
}
svc := &adminServiceImpl{
userRepo: userRepo,
userGroupRateRepo: rateRepo,
}
users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{})
require.NoError(t, err)
require.Equal(t, int64(2), total)
require.Len(t, users, 2)
require.Equal(t, 1, rateRepo.batchCalls)
require.ElementsMatch(t, []int64{101, 202}, rateRepo.singleCall)
require.Equal(t, 1.1, users[0].GroupRates[11])
require.Equal(t, 2.2, users[1].GroupRates[22])
}
......@@ -21,7 +21,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
......@@ -2291,7 +2290,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
// isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记
func isSingleAccountRetry(ctx context.Context) bool {
v, _ := ctx.Value(ctxkey.SingleAccountRetry).(bool)
v, _ := SingleAccountRetryFromContext(ctx)
return v
}
......
package service
import "time"
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
)
// API Key status constants
const (
......@@ -19,6 +23,9 @@ type APIKey struct {
Status string
IPWhitelist []string
IPBlacklist []string
// 预编译的 IP 规则,用于认证热路径避免重复 ParseIP/ParseCIDR。
CompiledIPWhitelist *ip.CompiledIPRules `json:"-"`
CompiledIPBlacklist *ip.CompiledIPRules `json:"-"`
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment