"backend/internal/vscode:/vscode.git/clone" did not exist on "4905e7193af9cd74d305998bf889b304ff8061f6"
Commit 1a641392 authored by cyhhao's avatar cyhhao
Browse files

Merge up/main

parents 36b817d0 24d19a5f
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/web" "github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
) )
// SetupRouter 配置路由器中间件和路由 // SetupRouter 配置路由器中间件和路由
...@@ -21,6 +22,7 @@ func SetupRouter( ...@@ -21,6 +22,7 @@ func SetupRouter(
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
redisClient *redis.Client,
) *gin.Engine { ) *gin.Engine {
// 应用中间件 // 应用中间件
r.Use(middleware2.Logger()) r.Use(middleware2.Logger())
...@@ -33,7 +35,7 @@ func SetupRouter( ...@@ -33,7 +35,7 @@ func SetupRouter(
} }
// 注册路由 // 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg) registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg, redisClient)
return r return r
} }
...@@ -48,6 +50,7 @@ func registerRoutes( ...@@ -48,6 +50,7 @@ func registerRoutes(
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
redisClient *redis.Client,
) { ) {
// 通用路由(健康检查、状态等) // 通用路由(健康检查、状态等)
routes.RegisterCommonRoutes(r) routes.RegisterCommonRoutes(r)
...@@ -56,7 +59,7 @@ func registerRoutes( ...@@ -56,7 +59,7 @@ func registerRoutes(
v1 := r.Group("/api/v1") v1 := r.Group("/api/v1")
// 注册各模块路由 // 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth) routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
routes.RegisterUserRoutes(v1, h, jwtAuth) routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg)
......
...@@ -44,6 +44,9 @@ func RegisterAdminRoutes( ...@@ -44,6 +44,9 @@ func RegisterAdminRoutes(
// 卡密管理 // 卡密管理
registerRedeemCodeRoutes(admin, h) registerRedeemCodeRoutes(admin, h)
// 优惠码管理
registerPromoCodeRoutes(admin, h)
// 系统设置 // 系统设置
registerSettingsRoutes(admin, h) registerSettingsRoutes(admin, h)
...@@ -201,6 +204,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -201,6 +204,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
} }
} }
func registerPromoCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
promoCodes := admin.Group("/promo-codes")
{
promoCodes.GET("", h.Admin.Promo.List)
promoCodes.GET("/:id", h.Admin.Promo.GetByID)
promoCodes.POST("", h.Admin.Promo.Create)
promoCodes.PUT("/:id", h.Admin.Promo.Update)
promoCodes.DELETE("/:id", h.Admin.Promo.Delete)
promoCodes.GET("/:id/usages", h.Admin.Promo.GetUsages)
}
}
func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings := admin.Group("/settings") adminSettings := admin.Group("/settings")
{ {
......
package routes package routes
import ( import (
"time"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/middleware"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
) )
// RegisterAuthRoutes 注册认证相关路由 // RegisterAuthRoutes 注册认证相关路由
func RegisterAuthRoutes( func RegisterAuthRoutes(
v1 *gin.RouterGroup, v1 *gin.RouterGroup,
h *handler.Handlers, h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware, jwtAuth servermiddleware.JWTAuthMiddleware,
redisClient *redis.Client,
) { ) {
// 创建速率限制器
rateLimiter := middleware.NewRateLimiter(redisClient)
// 公开接口 // 公开接口
auth := v1.Group("/auth") auth := v1.Group("/auth")
{ {
auth.POST("/register", h.Auth.Register) auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login) auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode) auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// 优惠码验证接口添加速率限制:每分钟最多 10 次
auth.POST("/validate-promo-code", rateLimiter.Limit("validate-promo", 10, time.Minute), h.Auth.ValidatePromoCode)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
} }
// 公开设置(无需认证) // 公开设置(无需认证)
......
...@@ -49,10 +49,12 @@ type AccountRepository interface { ...@@ -49,10 +49,12 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error ClearRateLimit(ctx context.Context, id int64) error
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
...@@ -66,6 +68,7 @@ type AccountBulkUpdate struct { ...@@ -66,6 +68,7 @@ type AccountBulkUpdate struct {
Concurrency *int Concurrency *int
Priority *int Priority *int
Status *string Status *string
Schedulable *bool
Credentials map[string]any Credentials map[string]any
Extra map[string]any Extra map[string]any
} }
......
...@@ -139,6 +139,10 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt ...@@ -139,6 +139,10 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
panic("unexpected SetRateLimited call") panic("unexpected SetRateLimited call")
} }
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
panic("unexpected SetAntigravityQuotaScopeLimit call")
}
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
panic("unexpected SetOverloaded call") panic("unexpected SetOverloaded call")
} }
...@@ -155,6 +159,10 @@ func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error { ...@@ -155,6 +159,10 @@ func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic("unexpected ClearRateLimit call") panic("unexpected ClearRateLimit call")
} }
func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
panic("unexpected ClearAntigravityQuotaScopes call")
}
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
panic("unexpected UpdateSessionWindow call") panic("unexpected UpdateSessionWindow call")
} }
......
...@@ -24,7 +24,7 @@ type AdminService interface { ...@@ -24,7 +24,7 @@ type AdminService interface {
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management // Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]Group, error) GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*Group, error) GetGroup(ctx context.Context, id int64) (*Group, error)
...@@ -168,6 +168,7 @@ type BulkUpdateAccountsInput struct { ...@@ -168,6 +168,7 @@ type BulkUpdateAccountsInput struct {
Concurrency *int Concurrency *int
Priority *int Priority *int
Status string Status string
Schedulable *bool
GroupIDs *[]int64 GroupIDs *[]int64
Credentials map[string]any Credentials map[string]any
Extra map[string]any Extra map[string]any
...@@ -478,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, ...@@ -478,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
} }
// Group management implementations // Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) { func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive) groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
...@@ -575,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro ...@@ -575,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
return fmt.Errorf("cannot set self as fallback group") return fmt.Errorf("cannot set self as fallback group")
} }
// 检查降级分组是否存在 visited := map[int64]struct{}{}
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID) nextID := fallbackGroupID
if err != nil { for {
return fmt.Errorf("fallback group not found: %w", err) if _, seen := visited[nextID]; seen {
} return fmt.Errorf("fallback group cycle detected")
}
visited[nextID] = struct{}{}
if currentGroupID > 0 && nextID == currentGroupID {
return fmt.Errorf("fallback group cycle detected")
}
// 降级分组不能启用 claude_code_only,否则会造成死循环 // 检查降级分组是否存在
if fallbackGroup.ClaudeCodeOnly { fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID)
return fmt.Errorf("fallback group cannot have claude_code_only enabled") if err != nil {
} return fmt.Errorf("fallback group not found: %w", err)
}
return nil // 降级分组不能启用 claude_code_only,否则会造成死循环
if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly {
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
}
if fallbackGroup.FallbackGroupID == nil {
return nil
}
nextID = *fallbackGroup.FallbackGroupID
}
} }
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
...@@ -910,6 +926,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -910,6 +926,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.Status != "" { if input.Status != "" {
repoUpdates.Status = &input.Status repoUpdates.Status = &input.Status
} }
if input.Schedulable != nil {
repoUpdates.Schedulable = input.Schedulable
}
// Run bulk update for column/jsonb fields first. // Run bulk update for column/jsonb fields first.
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil { if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
......
...@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) { ...@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
panic("unexpected GetByID call") panic("unexpected GetByID call")
} }
func (s *groupRepoStub) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
panic("unexpected GetByIDLite call")
}
func (s *groupRepoStub) Update(ctx context.Context, group *Group) error { func (s *groupRepoStub) Update(ctx context.Context, group *Group) error {
panic("unexpected Update call") panic("unexpected Update call")
} }
...@@ -124,7 +128,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa ...@@ -124,7 +128,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa
panic("unexpected List call") panic("unexpected List call")
} }
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call") panic("unexpected ListWithFilters call")
} }
......
...@@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct { ...@@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct {
updated *Group // 记录 Update 调用的参数 updated *Group // 记录 Update 调用的参数
getByID *Group // GetByID 返回值 getByID *Group // GetByID 返回值
getErr error // GetByID 返回的错误 getErr error // GetByID 返回的错误
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersPlatform string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersIsExclusive *bool
listWithFiltersGroups []Group
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
} }
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error { func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
...@@ -35,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err ...@@ -35,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
return s.getByID, nil return s.getByID, nil
} }
func (s *groupRepoStubForAdmin) GetByIDLite(_ context.Context, _ int64) (*Group, error) {
if s.getErr != nil {
return nil, s.getErr
}
return s.getByID, nil
}
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error { func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call") panic("unexpected Delete call")
} }
...@@ -47,8 +64,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP ...@@ -47,8 +64,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP
panic("unexpected List call") panic("unexpected List call")
} }
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) { func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call") s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
s.listWithFiltersIsExclusive = isExclusive
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersGroups)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersGroups, result, nil
} }
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) { func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
...@@ -195,3 +232,149 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { ...@@ -195,3 +232,149 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持 require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
require.Nil(t, repo.updated.ImagePrice4K) require.Nil(t, repo.updated.ImagePrice4K)
} }
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
// 测试:
// 1. search 参数正常传递到 repository 层
// 2. search 为空字符串时的行为
// 3. search 与其他过滤条件组合使用
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil)
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, "alpha", repo.listWithFiltersSearch)
require.Nil(t, repo.listWithFiltersIsExclusive)
})
t.Run("search 为空字符串时传递空字符串", func(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{},
listWithFiltersResult: &pagination.PaginationResult{Total: 0},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil)
require.NoError(t, err)
require.Empty(t, groups)
require.Equal(t, int64(0), total)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams)
require.Equal(t, "", repo.listWithFiltersSearch)
require.Nil(t, repo.listWithFiltersIsExclusive)
})
t.Run("search 与其他过滤条件组合使用", func(t *testing.T) {
isExclusive := true
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 42},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive)
require.NoError(t, err)
require.Equal(t, int64(42), total)
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "beta", repo.listWithFiltersSearch)
require.NotNil(t, repo.listWithFiltersIsExclusive)
require.True(t, *repo.listWithFiltersIsExclusive)
})
}
func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) {
groupID := int64(1)
fallbackID := int64(2)
repo := &groupRepoStubForFallbackCycle{
groups: map[int64]*Group{
groupID: {
ID: groupID,
FallbackGroupID: &fallbackID,
},
fallbackID: {
ID: fallbackID,
FallbackGroupID: &groupID,
},
},
}
svc := &adminServiceImpl{groupRepo: repo}
err := svc.validateFallbackGroup(context.Background(), groupID, fallbackID)
require.Error(t, err)
require.Contains(t, err.Error(), "fallback group cycle")
}
type groupRepoStubForFallbackCycle struct {
groups map[int64]*Group
}
func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error {
panic("unexpected Create call")
}
func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error {
panic("unexpected Update call")
}
func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) {
return s.GetByIDLite(ctx, id)
}
func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) {
if g, ok := s.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call")
}
func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
panic("unexpected DeleteCascade call")
}
func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) {
panic("unexpected ListActive call")
}
func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
panic("unexpected ListActiveByPlatform call")
}
func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
panic("unexpected GetAccountCount call")
}
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type accountRepoStubForAdminList struct {
accountRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersPlatform string
listWithFiltersType string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersAccounts []Account
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
}
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform
s.listWithFiltersType = accountType
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersAccounts)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersAccounts, result, nil
}
type proxyRepoStubForAdminList struct {
proxyRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersProtocol string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersProxies []Proxy
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
listWithFiltersAndAccountCountCalls int
listWithFiltersAndAccountCountParams pagination.PaginationParams
listWithFiltersAndAccountCountProtocol string
listWithFiltersAndAccountCountStatus string
listWithFiltersAndAccountCountSearch string
listWithFiltersAndAccountCountProxies []ProxyWithAccountCount
listWithFiltersAndAccountCountResult *pagination.PaginationResult
listWithFiltersAndAccountCountErr error
}
func (s *proxyRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersProtocol = protocol
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersProxies)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersProxies, result, nil
}
func (s *proxyRepoStubForAdminList) ListWithFiltersAndAccountCount(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
s.listWithFiltersAndAccountCountCalls++
s.listWithFiltersAndAccountCountParams = params
s.listWithFiltersAndAccountCountProtocol = protocol
s.listWithFiltersAndAccountCountStatus = status
s.listWithFiltersAndAccountCountSearch = search
if s.listWithFiltersAndAccountCountErr != nil {
return nil, nil, s.listWithFiltersAndAccountCountErr
}
result := s.listWithFiltersAndAccountCountResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersAndAccountCountProxies)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersAndAccountCountProxies, result, nil
}
type redeemRepoStubForAdminList struct {
redeemRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersType string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersCodes []RedeemCode
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
}
func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersType = codeType
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersCodes)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersCodes, result, nil
}
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &accountRepoStubForAdminList{
listWithFiltersAccounts: []Account{{ID: 1, Name: "acc"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 10},
}
svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
require.NoError(t, err)
require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "acc", repo.listWithFiltersSearch)
})
}
func TestAdminService_ListProxies_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &proxyRepoStubForAdminList{
listWithFiltersProxies: []Proxy{{ID: 2, Name: "p1"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 7},
}
svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1")
require.NoError(t, err)
require.Equal(t, int64(7), total)
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
require.Equal(t, "http", repo.listWithFiltersProtocol)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "p1", repo.listWithFiltersSearch)
})
}
func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &proxyRepoStubForAdminList{
listWithFiltersAndAccountCountProxies: []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}},
listWithFiltersAndAccountCountResult: &pagination.PaginationResult{Total: 9},
}
svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2")
require.NoError(t, err)
require.Equal(t, int64(9), total)
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams)
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
})
}
func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &redeemRepoStubForAdminList{
listWithFiltersCodes: []RedeemCode{{ID: 4, Code: "ABC"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 3},
}
svc := &adminServiceImpl{redeemCodeRepo: repo}
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC")
require.NoError(t, err)
require.Equal(t, int64(3), total)
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
require.Equal(t, "ABC", repo.listWithFiltersSearch)
})
}
...@@ -93,6 +93,7 @@ var antigravityPrefixMapping = []struct { ...@@ -93,6 +93,7 @@ var antigravityPrefixMapping = []struct {
// 长前缀优先 // 长前缀优先
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image {"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等 {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet {"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
...@@ -502,6 +503,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -502,6 +503,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
originalModel := claudeReq.Model originalModel := claudeReq.Model
mappedModel := s.getMappedModel(account, claudeReq.Model) mappedModel := s.getMappedModel(account, claudeReq.Model)
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
// 获取 access_token // 获取 access_token
if s.tokenProvider == nil { if s.tokenProvider == nil {
...@@ -603,7 +605,7 @@ urlFallbackLoop: ...@@ -603,7 +605,7 @@ urlFallbackLoop:
} }
// 所有重试都失败,标记限流状态 // 所有重试都失败,标记限流状态
if resp.StatusCode == 429 { if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
} }
// 最后一次尝试也失败 // 最后一次尝试也失败
resp = &http.Response{ resp = &http.Response{
...@@ -696,7 +698,7 @@ urlFallbackLoop: ...@@ -696,7 +698,7 @@ urlFallbackLoop:
// 处理错误响应(重试后仍失败或不触发重试) // 处理错误响应(重试后仍失败或不触发重试)
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
if s.shouldFailoverUpstreamError(resp.StatusCode) { if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
...@@ -1021,6 +1023,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -1021,6 +1023,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if len(body) == 0 { if len(body) == 0 {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
} }
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
// 解析请求以获取 image_size(用于图片计费) // 解析请求以获取 image_size(用于图片计费)
imageSize := s.extractImageSize(body) imageSize := s.extractImageSize(body)
...@@ -1146,7 +1149,7 @@ urlFallbackLoop: ...@@ -1146,7 +1149,7 @@ urlFallbackLoop:
} }
// 所有重试都失败,标记限流状态 // 所有重试都失败,标记限流状态
if resp.StatusCode == 429 { if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
} }
resp = &http.Response{ resp = &http.Response{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
...@@ -1200,7 +1203,7 @@ urlFallbackLoop: ...@@ -1200,7 +1203,7 @@ urlFallbackLoop:
goto handleSuccess goto handleSuccess
} }
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
if s.shouldFailoverUpstreamError(resp.StatusCode) { if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
...@@ -1314,7 +1317,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { ...@@ -1314,7 +1317,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
} }
} }
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) { func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
// 429 使用 Gemini 格式解析(从 body 解析重置时间) // 429 使用 Gemini 格式解析(从 body 解析重置时间)
if statusCode == 429 { if statusCode == 429 {
resetAt := ParseGeminiRateLimitResetTime(body) resetAt := ParseGeminiRateLimitResetTime(body)
...@@ -1325,13 +1328,23 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre ...@@ -1325,13 +1328,23 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
defaultDur = 5 * time.Minute defaultDur = 5 * time.Minute
} }
ra := time.Now().Add(defaultDur) ra := time.Now().Add(defaultDur)
log.Printf("%s status=429 rate_limited reset_in=%v (fallback)", prefix, defaultDur) log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra) if quotaScope == "" {
return
}
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
}
return return
} }
resetTime := time.Unix(*resetAt, 0) resetTime := time.Unix(*resetAt, 0)
log.Printf("%s status=429 rate_limited reset_at=%v reset_in=%v", prefix, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime) if quotaScope == "" {
return
}
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
}
return return
} }
// 其他错误码继续使用 rateLimitService // 其他错误码继续使用 rateLimitService
......
package service
import (
"strings"
"time"
)
const antigravityQuotaScopesKey = "antigravity_quota_scopes"
// AntigravityQuotaScope 表示 Antigravity 的配额域
type AntigravityQuotaScope string
const (
AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude"
AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text"
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
)
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
model := normalizeAntigravityModelName(requestedModel)
if model == "" {
return "", false
}
switch {
case strings.HasPrefix(model, "claude-"):
return AntigravityQuotaScopeClaude, true
case strings.HasPrefix(model, "gemini-"):
if isImageGenerationModel(model) {
return AntigravityQuotaScopeGeminiImage, true
}
return AntigravityQuotaScopeGeminiText, true
default:
return "", false
}
}
func normalizeAntigravityModelName(model string) string {
normalized := strings.ToLower(strings.TrimSpace(model))
normalized = strings.TrimPrefix(normalized, "models/")
return normalized
}
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
if a == nil {
return false
}
if !a.IsSchedulable() {
return false
}
if a.Platform != PlatformAntigravity {
return true
}
scope, ok := resolveAntigravityQuotaScope(requestedModel)
if !ok {
return true
}
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt == nil {
return true
}
now := time.Now()
return !now.Before(*resetAt)
}
func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time {
if a == nil || a.Extra == nil || scope == "" {
return nil
}
rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any)
if !ok {
return nil
}
rawScope, ok := rawScopes[string(scope)].(map[string]any)
if !ok {
return nil
}
resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string)
if !ok || strings.TrimSpace(resetAtRaw) == "" {
return nil
}
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
if err != nil {
return nil
}
return &resetAt
}
...@@ -3,16 +3,18 @@ package service ...@@ -3,16 +3,18 @@ package service
import "time" import "time"
type APIKey struct { type APIKey struct {
ID int64 ID int64
UserID int64 UserID int64
Key string Key string
Name string Name string
GroupID *int64 GroupID *int64
Status string Status string
CreatedAt time.Time IPWhitelist []string
UpdatedAt time.Time IPBlacklist []string
User *User CreatedAt time.Time
Group *Group UpdatedAt time.Time
User *User
Group *Group
} }
func (k *APIKey) IsActive() bool { func (k *APIKey) IsActive() bool {
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
) )
...@@ -20,6 +21,7 @@ var ( ...@@ -20,6 +21,7 @@ var (
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters") ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens") ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later") ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
) )
const ( const (
...@@ -57,16 +59,20 @@ type APIKeyCache interface { ...@@ -57,16 +59,20 @@ type APIKeyCache interface {
// CreateAPIKeyRequest 创建API Key请求 // CreateAPIKeyRequest 创建API Key请求
type CreateAPIKeyRequest struct { type CreateAPIKeyRequest struct {
Name string `json:"name"` Name string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
} }
// UpdateAPIKeyRequest 更新API Key请求 // UpdateAPIKeyRequest 更新API Key请求
type UpdateAPIKeyRequest struct { type UpdateAPIKeyRequest struct {
Name *string `json:"name"` Name *string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
Status *string `json:"status"` Status *string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
} }
// APIKeyService API Key服务 // APIKeyService API Key服务
...@@ -186,6 +192,20 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK ...@@ -186,6 +192,20 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
// 验证 IP 白名单格式
if len(req.IPWhitelist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证 IP 黑名单格式
if len(req.IPBlacklist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证分组权限(如果指定了分组) // 验证分组权限(如果指定了分组)
if req.GroupID != nil { if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID) group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
...@@ -236,11 +256,13 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK ...@@ -236,11 +256,13 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
// 创建API Key记录 // 创建API Key记录
apiKey := &APIKey{ apiKey := &APIKey{
UserID: userID, UserID: userID,
Key: key, Key: key,
Name: req.Name, Name: req.Name,
GroupID: req.GroupID, GroupID: req.GroupID,
Status: StatusActive, Status: StatusActive,
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
} }
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil { if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
...@@ -312,6 +334,20 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req ...@@ -312,6 +334,20 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
return nil, ErrInsufficientPerms return nil, ErrInsufficientPerms
} }
// 验证 IP 白名单格式
if len(req.IPWhitelist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证 IP 黑名单格式
if len(req.IPBlacklist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 更新字段 // 更新字段
if req.Name != nil { if req.Name != nil {
apiKey.Name = *req.Name apiKey.Name = *req.Name
...@@ -344,6 +380,10 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req ...@@ -344,6 +380,10 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
} }
} }
// 更新 IP 限制(空数组会清空设置)
apiKey.IPWhitelist = req.IPWhitelist
apiKey.IPBlacklist = req.IPBlacklist
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err) return nil, fmt.Errorf("update api key: %w", err)
} }
......
...@@ -2,9 +2,13 @@ package service ...@@ -2,9 +2,13 @@ package service
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"log" "log"
"net/mail"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
...@@ -18,6 +22,7 @@ var ( ...@@ -18,6 +22,7 @@ var (
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
...@@ -47,6 +52,7 @@ type AuthService struct { ...@@ -47,6 +52,7 @@ type AuthService struct {
emailService *EmailService emailService *EmailService
turnstileService *TurnstileService turnstileService *TurnstileService
emailQueueService *EmailQueueService emailQueueService *EmailQueueService
promoService *PromoService
} }
// NewAuthService 创建认证服务实例 // NewAuthService 创建认证服务实例
...@@ -57,6 +63,7 @@ func NewAuthService( ...@@ -57,6 +63,7 @@ func NewAuthService(
emailService *EmailService, emailService *EmailService,
turnstileService *TurnstileService, turnstileService *TurnstileService,
emailQueueService *EmailQueueService, emailQueueService *EmailQueueService,
promoService *PromoService,
) *AuthService { ) *AuthService {
return &AuthService{ return &AuthService{
userRepo: userRepo, userRepo: userRepo,
...@@ -65,21 +72,27 @@ func NewAuthService( ...@@ -65,21 +72,27 @@ func NewAuthService(
emailService: emailService, emailService: emailService,
turnstileService: turnstileService, turnstileService: turnstileService,
emailQueueService: emailQueueService, emailQueueService: emailQueueService,
promoService: promoService,
} }
} }
// Register 用户注册,返回token和用户 // Register 用户注册,返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "") return s.RegisterWithVerification(ctx, email, password, "", "")
} }
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户 // RegisterWithVerification 用户注册(支持邮件验证和优惠码),返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) { func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
// 检查是否开放注册 // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled return "", nil, ErrRegDisabled
} }
// 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。
if isReservedEmail(email) {
return "", nil, ErrEmailReserved
}
// 检查是否需要邮件验证 // 检查是否需要邮件验证
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
// 如果邮件验证已开启但邮件服务未配置,拒绝注册 // 如果邮件验证已开启但邮件服务未配置,拒绝注册
...@@ -132,10 +145,27 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -132,10 +145,27 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
} }
if err := s.userRepo.Create(ctx, user); err != nil { if err := s.userRepo.Create(ctx, user); err != nil {
// 优先检查邮箱冲突错误(竞态条件下可能发生)
if errors.Is(err, ErrEmailExists) {
return "", nil, ErrEmailExists
}
log.Printf("[Auth] Database error creating user: %v", err) log.Printf("[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable return "", nil, ErrServiceUnavailable
} }
// 应用优惠码(如果提供)
if promoCode != "" && s.promoService != nil {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
// 优惠码应用失败不影响注册,只记录日志
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
} else {
// 重新获取用户信息以获取更新后的余额
if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil {
user = updatedUser
}
}
}
// 生成token // 生成token
token, err := s.GenerateToken(user) token, err := s.GenerateToken(user)
if err != nil { if err != nil {
...@@ -152,11 +182,15 @@ type SendVerifyCodeResult struct { ...@@ -152,11 +182,15 @@ type SendVerifyCodeResult struct {
// SendVerifyCode 发送邮箱验证码(同步方式) // SendVerifyCode 发送邮箱验证码(同步方式)
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
// 检查是否开放注册 // 检查是否开放注册(默认关闭)
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return ErrRegDisabled return ErrRegDisabled
} }
if isReservedEmail(email) {
return ErrEmailReserved
}
// 检查邮箱是否已存在 // 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil { if err != nil {
...@@ -185,12 +219,16 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { ...@@ -185,12 +219,16 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email) log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email)
// 检查是否开放注册 // 检查是否开放注册(默认关闭)
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
log.Println("[Auth] Registration is disabled") log.Println("[Auth] Registration is disabled")
return nil, ErrRegDisabled return nil, ErrRegDisabled
} }
if isReservedEmail(email) {
return nil, ErrEmailReserved
}
// 检查邮箱是否已存在 // 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil { if err != nil {
...@@ -270,7 +308,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool { ...@@ -270,7 +308,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
// IsRegistrationEnabled 检查是否开放注册 // IsRegistrationEnabled 检查是否开放注册
func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool { func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool {
if s.settingService == nil { if s.settingService == nil {
return true return false // 安全默认:settingService 未配置时关闭注册
} }
return s.settingService.IsRegistrationEnabled(ctx) return s.settingService.IsRegistrationEnabled(ctx)
} }
...@@ -315,6 +353,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string ...@@ -315,6 +353,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
return token, user, nil return token, user, nil
} }
// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录:
// - 如果邮箱已存在:直接登录(不需要本地密码)
// - 如果邮箱不存在:创建新用户并登录
//
// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) {
email = strings.TrimSpace(email)
if email == "" || len(email) > 255 {
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
if _, err := mail.ParseAddress(email); err != nil {
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
username = strings.TrimSpace(username)
if len([]rune(username)) > 100 {
username = string([]rune(username)[:100])
}
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// OAuth 首次登录视为注册。
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
}
randomPassword, err := randomHexString(32)
if err != nil {
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
return "", nil, ErrServiceUnavailable
}
hashedPassword, err := s.HashPassword(randomPassword)
if err != nil {
return "", nil, fmt.Errorf("hash password: %w", err)
}
// 新用户默认值。
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Status: StatusActive,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
// 并发场景:GetByEmail 与 Create 之间用户被创建。
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
log.Printf("[Auth] Database error getting user after conflict: %v", err)
return "", nil, ErrServiceUnavailable
}
} else {
log.Printf("[Auth] Database error creating oauth user: %v", err)
return "", nil, ErrServiceUnavailable
}
} else {
user = newUser
}
} else {
log.Printf("[Auth] Database error during oauth login: %v", err)
return "", nil, ErrServiceUnavailable
}
}
if !user.IsActive() {
return "", nil, ErrUserNotActive
}
// 尽力补全:当用户名为空时,使用第三方返回的用户名回填。
if user.Username == "" && username != "" {
user.Username = username
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
}
}
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
}
return token, user, nil
}
// ValidateToken 验证JWT token并返回用户声明 // ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
...@@ -357,6 +491,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { ...@@ -357,6 +491,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
return nil, ErrInvalidToken return nil, ErrInvalidToken
} }
func randomHexString(byteLength int) (string, error) {
if byteLength <= 0 {
byteLength = 16
}
buf := make([]byte, byteLength)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT token // GenerateToken 生成JWT token
func (s *AuthService) GenerateToken(user *User) (string, error) { func (s *AuthService) GenerateToken(user *User) (string, error) {
now := time.Now() now := time.Now()
......
...@@ -100,6 +100,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E ...@@ -100,6 +100,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
emailService, emailService,
nil, nil,
nil, nil,
nil, // promoService
) )
} }
...@@ -113,6 +114,15 @@ func TestAuthService_Register_Disabled(t *testing.T) { ...@@ -113,6 +114,15 @@ func TestAuthService_Register_Disabled(t *testing.T) {
require.ErrorIs(t, err, ErrRegDisabled) require.ErrorIs(t, err, ErrRegDisabled)
} }
func TestAuthService_Register_DisabledByDefault(t *testing.T) {
// 当 settings 为 nil(设置项不存在)时,注册应该默认关闭
repo := &userRepoStub{}
service := newAuthService(repo, nil, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrRegDisabled)
}
func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) { func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) {
repo := &userRepoStub{} repo := &userRepoStub{}
// 邮件验证开启但 emailCache 为 nil(emailService 未配置) // 邮件验证开启但 emailCache 为 nil(emailService 未配置)
...@@ -122,7 +132,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi ...@@ -122,7 +132,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
}, nil) }, nil)
// 应返回服务不可用错误,而不是允许绕过验证 // 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code") _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "")
require.ErrorIs(t, err, ErrServiceUnavailable) require.ErrorIs(t, err, ErrServiceUnavailable)
} }
...@@ -134,7 +144,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { ...@@ -134,7 +144,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true", SettingKeyEmailVerifyEnabled: "true",
}, cache) }, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "") _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired) require.ErrorIs(t, err, ErrEmailVerifyRequired)
} }
...@@ -148,14 +158,16 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) { ...@@ -148,14 +158,16 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true", SettingKeyEmailVerifyEnabled: "true",
}, cache) }, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong") _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode) require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code") require.ErrorContains(t, err, "verify code")
} }
func TestAuthService_Register_EmailExists(t *testing.T) { func TestAuthService_Register_EmailExists(t *testing.T) {
repo := &userRepoStub{exists: true} repo := &userRepoStub{exists: true}
service := newAuthService(repo, nil, nil) service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password") _, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrEmailExists) require.ErrorIs(t, err, ErrEmailExists)
...@@ -163,23 +175,50 @@ func TestAuthService_Register_EmailExists(t *testing.T) { ...@@ -163,23 +175,50 @@ func TestAuthService_Register_EmailExists(t *testing.T) {
func TestAuthService_Register_CheckEmailError(t *testing.T) { func TestAuthService_Register_CheckEmailError(t *testing.T) {
repo := &userRepoStub{existsErr: errors.New("db down")} repo := &userRepoStub{existsErr: errors.New("db down")}
service := newAuthService(repo, nil, nil) service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password") _, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrServiceUnavailable) require.ErrorIs(t, err, ErrServiceUnavailable)
} }
func TestAuthService_Register_ReservedEmail(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password")
require.ErrorIs(t, err, ErrEmailReserved)
}
func TestAuthService_Register_CreateError(t *testing.T) { func TestAuthService_Register_CreateError(t *testing.T) {
repo := &userRepoStub{createErr: errors.New("create failed")} repo := &userRepoStub{createErr: errors.New("create failed")}
service := newAuthService(repo, nil, nil) service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password") _, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrServiceUnavailable) require.ErrorIs(t, err, ErrServiceUnavailable)
} }
func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
// 模拟竞态条件:ExistsByEmail 返回 false,但 Create 时因唯一约束失败
repo := &userRepoStub{createErr: ErrEmailExists}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrEmailExists)
}
func TestAuthService_Register_Success(t *testing.T) { func TestAuthService_Register_Success(t *testing.T) {
repo := &userRepoStub{nextID: 5} repo := &userRepoStub{nextID: 5}
service := newAuthService(repo, nil, nil) service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
token, user, err := service.Register(context.Background(), "user@test.com", "password") token, user, err := service.Register(context.Background(), "user@test.com", "password")
require.NoError(t, err) require.NoError(t, err)
......
...@@ -38,6 +38,12 @@ const ( ...@@ -38,6 +38,12 @@ const (
RedeemTypeSubscription = "subscription" RedeemTypeSubscription = "subscription"
) )
// PromoCode status constants
const (
PromoCodeStatusActive = "active"
PromoCodeStatusDisabled = "disabled"
)
// Admin adjustment type constants // Admin adjustment type constants
const ( const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额 AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
...@@ -105,7 +111,17 @@ const ( ...@@ -105,7 +111,17 @@ const (
// Request identity patch (Claude -> Gemini systemInstruction injection) // Request identity patch (Claude -> Gemini systemInstruction injection)
SettingKeyEnableIdentityPatch = "enable_identity_patch" SettingKeyEnableIdentityPatch = "enable_identity_patch"
SettingKeyIdentityPatchPrompt = "identity_patch_prompt" SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
// LinuxDo Connect OAuth 登录(终端用户 SSO)
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
) )
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
const AdminAPIKeyPrefix = "admin-" const AdminAPIKeyPrefix = "admin-"
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"log"
"math/big" "math/big"
"net/smtp" "net/smtp"
"strconv" "strconv"
...@@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error ...@@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配 // 验证码不匹配
if data.Code != code { if data.Code != code {
data.Attempts++ data.Attempts++
_ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL) if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
log.Printf("[Email] Failed to update verification attempt count: %v", err)
}
if data.Attempts >= maxVerifyCodeAttempts { if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts return ErrVerifyCodeMaxAttempts
} }
...@@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error ...@@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
} }
// 验证成功,删除验证码 // 验证成功,删除验证码
_ = s.cache.DeleteVerificationCode(ctx, email) if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
log.Printf("[Email] Failed to delete verification code after success: %v", err)
}
return nil return nil
} }
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -23,9 +24,11 @@ type mockAccountRepoForPlatform struct { ...@@ -23,9 +24,11 @@ type mockAccountRepoForPlatform struct {
accounts []Account accounts []Account
accountsByID map[int64]*Account accountsByID map[int64]*Account
listPlatformFunc func(ctx context.Context, platform string) ([]Account, error) listPlatformFunc func(ctx context.Context, platform string) ([]Account, error)
getByIDCalls int
} }
func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Account, error) { func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Account, error) {
m.getByIDCalls++
if acc, ok := m.accountsByID[id]; ok { if acc, ok := m.accountsByID[id]; ok {
return acc, nil return acc, nil
} }
...@@ -136,6 +139,9 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co ...@@ -136,6 +139,9 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil return nil
} }
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil return nil
} }
...@@ -148,6 +154,9 @@ func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context, ...@@ -148,6 +154,9 @@ func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context,
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error { func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
return nil return nil
} }
func (m *mockAccountRepoForPlatform) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil return nil
} }
...@@ -185,6 +194,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro ...@@ -185,6 +194,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
return nil return nil
} }
type mockGroupRepoForGateway struct {
groups map[int64]*Group
getByIDCalls int
getByIDLiteCalls int
}
func (m *mockGroupRepoForGateway) GetByID(ctx context.Context, id int64) (*Group, error) {
m.getByIDCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (m *mockGroupRepoForGateway) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
m.getByIDLiteCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (m *mockGroupRepoForGateway) Create(ctx context.Context, group *Group) error { return nil }
func (m *mockGroupRepoForGateway) Update(ctx context.Context, group *Group) error { return nil }
func (m *mockGroupRepoForGateway) Delete(ctx context.Context, id int64) error { return nil }
func (m *mockGroupRepoForGateway) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockGroupRepoForGateway) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockGroupRepoForGateway) ListActive(ctx context.Context) ([]Group, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func ptr[T any](v T) *T { func ptr[T any](v T) *T {
return &v return &v
} }
...@@ -891,6 +950,74 @@ func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, acc ...@@ -891,6 +950,74 @@ func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, acc
return m.accountWaitCounts[accountID], nil return m.accountWaitCounts[accountID], nil
} }
type mockConcurrencyCache struct {
acquireAccountCalls int
loadBatchCalls int
}
func (m *mockConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
m.acquireAccountCalls++
return true, nil
}
func (m *mockConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
return nil
}
func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
func (m *mockConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
return nil
}
func (m *mockConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (m *mockConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
}
func (m *mockConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
return nil
}
func (m *mockConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (m *mockConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return true, nil
}
func (m *mockConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
return nil
}
func (m *mockConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
m.loadBatchCalls++
result := make(map[int64]*AccountLoadInfo, len(accounts))
for _, acc := range accounts {
result[acc.ID] = &AccountLoadInfo{
AccountID: acc.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
}
}
return result, nil
}
func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
return nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection // TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background() ctx := context.Background()
...@@ -989,6 +1116,78 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -989,6 +1116,78 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号") require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
}) })
t.Run("粘性命中-不调用GetByID", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"sticky": 1},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID)
require.Equal(t, 0, repo.getByIDCalls, "粘性命中不应调用GetByID")
require.Equal(t, 0, concurrencyCache.loadBatchCalls, "粘性命中应在负载批量查询前返回")
})
t.Run("粘性账号不在候选集-回退负载感知选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"sticky": 1},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "粘性账号不在候选集时应回退到可用账号")
require.Equal(t, 0, repo.getByIDCalls, "粘性账号缺失不应回退到GetByID")
require.Equal(t, 1, concurrencyCache.loadBatchCalls, "应继续进行负载批量查询")
})
t.Run("无可用账号-返回错误", func(t *testing.T) { t.Run("无可用账号-返回错误", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{ repo := &mockAccountRepoForPlatform{
accounts: []Account{}, accounts: []Account{},
...@@ -1013,3 +1212,190 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1013,3 +1212,190 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require.Contains(t, err.Error(), "no available accounts") require.Contains(t, err.Error(), "no available accounts")
}) })
} }
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(42)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
}
ctx = context.WithValue(ctx, ctxkey.Group, group)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{groupID: group},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(42)
ctxGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
ctx = context.WithValue(ctx, ctxkey.Group, ctxGroup)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{groupID: group},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_GroupContext_OverwritesInvalidContextGroup(t *testing.T) {
groupID := int64(42)
invalidGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
hydratedGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
}
ctx := context.WithValue(context.Background(), ctxkey.Group, invalidGroup)
svc := &GatewayService{}
ctx = svc.withGroupContext(ctx, hydratedGroup)
got, ok := ctx.Value(ctxkey.Group).(*Group)
require.True(t, ok)
require.Same(t, hydratedGroup, got)
}
func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
ctx := context.Background()
groupID := int64(10)
fallbackID := int64(11)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
Hydrated: true,
}
fallbackGroup := &Group{
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
ctx = context.WithValue(ctx, ctxkey.Group, group)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{fallbackID: fallbackGroup},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle(t *testing.T) {
ctx := context.Background()
groupID := int64(10)
fallbackID := int64(11)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
}
fallbackGroup := &Group{
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &groupID,
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: group,
fallbackID: fallbackGroup,
},
}
svc := &GatewayService{
groupRepo: groupRepo,
}
gotGroup, gotID, err := svc.resolveGatewayGroup(ctx, &groupID)
require.Error(t, err)
require.Nil(t, gotGroup)
require.Nil(t, gotID)
require.Contains(t, err.Error(), "fallback group cycle")
}
...@@ -33,7 +33,7 @@ const ( ...@@ -33,7 +33,7 @@ const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 10 * 1024 * 1024 defaultMaxLineSize = 40 * 1024 * 1024
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
) )
...@@ -361,27 +361,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -361,27 +361,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if hasForcePlatform && forcePlatform != "" { if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform platform = forcePlatform
} else if groupID != nil { } else if groupID != nil {
// 根据分组 platform 决定查询哪种账号 group, resolvedGroupID, err := s.resolveGatewayGroup(ctx, groupID)
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get group failed: %w", err) return nil, err
} }
groupID = resolvedGroupID
ctx = s.withGroupContext(ctx, group)
platform = group.Platform platform = group.Platform
// 检查 Claude Code 客户端限制
if group.ClaudeCodeOnly {
isClaudeCode := IsClaudeCodeClient(ctx)
if !isClaudeCode {
// 非 Claude Code 客户端,检查是否有降级分组
if group.FallbackGroupID != nil {
// 使用降级分组重新调度
fallbackGroupID := *group.FallbackGroupID
return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs)
}
// 无降级分组,拒绝访问
return nil, ErrClaudeCodeOnly
}
}
} else { } else {
// 无分组时只使用原生 anthropic 平台 // 无分组时只使用原生 anthropic 平台
platform = PlatformAnthropic platform = PlatformAnthropic
...@@ -409,10 +395,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -409,10 +395,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组) // 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
groupID, err := s.checkClaudeCodeRestriction(ctx, groupID) group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx = s.withGroupContext(ctx, group)
if s.concurrencyService == nil || !cfg.LoadBatchEnabled { if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
...@@ -452,7 +439,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -452,7 +439,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}, nil }, nil
} }
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID) platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -478,10 +465,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -478,10 +465,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) { if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID) // 粘性命中仅在当前可调度候选集中生效。
if err == nil && s.isAccountInGroup(account, groupID) && accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
account, ok := accountByID[accountID]
if ok && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulable() && account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
...@@ -519,6 +511,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -519,6 +511,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue continue
} }
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue continue
} }
...@@ -652,51 +647,97 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { ...@@ -652,51 +647,97 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
} }
} }
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制 func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context {
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端: if !IsGroupContextValid(group) {
// - 有降级分组:返回降级分组的 ID return ctx
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) {
if groupID == nil {
return groupID, nil
} }
if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID && IsGroupContextValid(existing) {
return ctx
}
return context.WithValue(ctx, ctxkey.Group, group)
}
// 强制平台模式不检查 Claude Code 限制 func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group {
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform { if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(group) && group.ID == groupID {
return groupID, nil return group
} }
return nil
}
group, err := s.groupRepo.GetByID(ctx, *groupID) func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
if group := s.groupFromContext(ctx, groupID); group != nil {
return group, nil
}
group, err := s.groupRepo.GetByIDLite(ctx, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get group failed: %w", err) return nil, fmt.Errorf("get group failed: %w", err)
} }
return group, nil
}
func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) {
if groupID == nil {
return nil, nil, nil
}
currentID := *groupID
visited := map[int64]struct{}{}
for {
if _, seen := visited[currentID]; seen {
return nil, nil, fmt.Errorf("fallback group cycle detected")
}
visited[currentID] = struct{}{}
group, err := s.resolveGroupByID(ctx, currentID)
if err != nil {
return nil, nil, err
}
if !group.ClaudeCodeOnly || IsClaudeCodeClient(ctx) {
return group, &currentID, nil
}
if group.FallbackGroupID == nil {
return nil, nil, ErrClaudeCodeOnly
}
currentID = *group.FallbackGroupID
}
}
if !group.ClaudeCodeOnly { // checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
return groupID, nil // 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
// - 有降级分组:返回降级分组的 ID
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*Group, *int64, error) {
if groupID == nil {
return nil, groupID, nil
} }
// 分组启用了 Claude Code 限制 // 强制平台模式不检查 Claude Code 限制
if IsClaudeCodeClient(ctx) { if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
return groupID, nil return nil, groupID, nil
} }
// 非 Claude Code 客户端,检查降级分组 group, resolvedID, err := s.resolveGatewayGroup(ctx, groupID)
if group.FallbackGroupID != nil { if err != nil {
return group.FallbackGroupID, nil return nil, nil, err
} }
return nil, ErrClaudeCodeOnly return group, resolvedID, nil
} }
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) { func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, group *Group) (string, bool, error) {
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" { if hasForcePlatform && forcePlatform != "" {
return forcePlatform, true, nil return forcePlatform, true, nil
} }
if group != nil {
return group.Platform, false, nil
}
if groupID != nil { if groupID != nil {
group, err := s.groupRepo.GetByID(ctx, *groupID) group, err := s.resolveGroupByID(ctx, *groupID)
if err != nil { if err != nil {
return "", false, fmt.Errorf("get group failed: %w", err) return "", false, err
} }
return group.Platform, false, nil return group.Platform, false, nil
} }
...@@ -812,7 +853,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -812,7 +853,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
} }
...@@ -844,6 +885,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -844,6 +885,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[acc.ID]; excluded { if _, excluded := excludedIDs[acc.ID]; excluded {
continue continue
} }
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue continue
} }
...@@ -901,7 +945,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -901,7 +945,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
...@@ -936,6 +980,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -936,6 +980,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue continue
} }
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue continue
} }
...@@ -2247,6 +2294,7 @@ type RecordUsageInput struct { ...@@ -2247,6 +2294,7 @@ type RecordUsageInput struct {
Account *Account Account *Account
Subscription *UserSubscription // 可选:订阅信息 Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
} }
// RecordUsage 记录使用量并扣费(或更新订阅用量) // RecordUsage 记录使用量并扣费(或更新订阅用量)
...@@ -2337,6 +2385,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -2337,6 +2385,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog.UserAgent = &input.UserAgent usageLog.UserAgent = &input.UserAgent
} }
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
// 添加分组和订阅关联 // 添加分组和订阅关联
if apiKey.GroupID != nil { if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID usageLog.GroupID = apiKey.GroupID
......
...@@ -86,9 +86,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co ...@@ -86,9 +86,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
platform = forcePlatform platform = forcePlatform
} else if groupID != nil { } else if groupID != nil {
// 根据分组 platform 决定查询哪种账号 // 根据分组 platform 决定查询哪种账号
group, err := s.groupRepo.GetByID(ctx, *groupID) var group *Group
if err != nil { if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
return nil, fmt.Errorf("get group failed: %w", err) group = ctxGroup
} else {
var err error
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
} }
platform = group.Platform platform = group.Platform
} else { } else {
...@@ -114,7 +120,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co ...@@ -114,7 +120,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度 // 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
valid := false valid := false
if account.Platform == platform { if account.Platform == platform {
valid = true valid = true
...@@ -172,6 +178,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co ...@@ -172,6 +178,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue continue
} }
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue continue
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment