Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
1a641392
Commit
1a641392
authored
Jan 10, 2026
by
cyhhao
Browse files
Merge up/main
parents
36b817d0
24d19a5f
Changes
174
Show whitespace changes
Inline
Side-by-side
backend/internal/server/router.go
View file @
1a641392
...
...
@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// SetupRouter 配置路由器中间件和路由
...
...
@@ -21,6 +22,7 @@ func SetupRouter(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
,
redisClient
*
redis
.
Client
,
)
*
gin
.
Engine
{
// 应用中间件
r
.
Use
(
middleware2
.
Logger
())
...
...
@@ -33,7 +35,7 @@ func SetupRouter(
}
// 注册路由
registerRoutes
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
)
registerRoutes
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
,
redisClient
)
return
r
}
...
...
@@ -48,6 +50,7 @@ func registerRoutes(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
,
redisClient
*
redis
.
Client
,
)
{
// 通用路由(健康检查、状态等)
routes
.
RegisterCommonRoutes
(
r
)
...
...
@@ -56,7 +59,7 @@ func registerRoutes(
v1
:=
r
.
Group
(
"/api/v1"
)
// 注册各模块路由
routes
.
RegisterAuthRoutes
(
v1
,
h
,
jwtAuth
)
routes
.
RegisterAuthRoutes
(
v1
,
h
,
jwtAuth
,
redisClient
)
routes
.
RegisterUserRoutes
(
v1
,
h
,
jwtAuth
)
routes
.
RegisterAdminRoutes
(
v1
,
h
,
adminAuth
)
routes
.
RegisterGatewayRoutes
(
r
,
h
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
)
...
...
backend/internal/server/routes/admin.go
View file @
1a641392
...
...
@@ -44,6 +44,9 @@ func RegisterAdminRoutes(
// 卡密管理
registerRedeemCodeRoutes
(
admin
,
h
)
// 优惠码管理
registerPromoCodeRoutes
(
admin
,
h
)
// 系统设置
registerSettingsRoutes
(
admin
,
h
)
...
...
@@ -201,6 +204,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func
registerPromoCodeRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
promoCodes
:=
admin
.
Group
(
"/promo-codes"
)
{
promoCodes
.
GET
(
""
,
h
.
Admin
.
Promo
.
List
)
promoCodes
.
GET
(
"/:id"
,
h
.
Admin
.
Promo
.
GetByID
)
promoCodes
.
POST
(
""
,
h
.
Admin
.
Promo
.
Create
)
promoCodes
.
PUT
(
"/:id"
,
h
.
Admin
.
Promo
.
Update
)
promoCodes
.
DELETE
(
"/:id"
,
h
.
Admin
.
Promo
.
Delete
)
promoCodes
.
GET
(
"/:id/usages"
,
h
.
Admin
.
Promo
.
GetUsages
)
}
}
func
registerSettingsRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
adminSettings
:=
admin
.
Group
(
"/settings"
)
{
...
...
backend/internal/server/routes/auth.go
View file @
1a641392
package
routes
import
(
"time"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/middleware"
servermiddleware
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// RegisterAuthRoutes 注册认证相关路由
func
RegisterAuthRoutes
(
v1
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
,
jwtAuth
middleware
.
JWTAuthMiddleware
,
jwtAuth
servermiddleware
.
JWTAuthMiddleware
,
redisClient
*
redis
.
Client
,
)
{
// 创建速率限制器
rateLimiter
:=
middleware
.
NewRateLimiter
(
redisClient
)
// 公开接口
auth
:=
v1
.
Group
(
"/auth"
)
{
auth
.
POST
(
"/register"
,
h
.
Auth
.
Register
)
auth
.
POST
(
"/login"
,
h
.
Auth
.
Login
)
auth
.
POST
(
"/send-verify-code"
,
h
.
Auth
.
SendVerifyCode
)
// 优惠码验证接口添加速率限制:每分钟最多 10 次
auth
.
POST
(
"/validate-promo-code"
,
rateLimiter
.
Limit
(
"validate-promo"
,
10
,
time
.
Minute
),
h
.
Auth
.
ValidatePromoCode
)
auth
.
GET
(
"/oauth/linuxdo/start"
,
h
.
Auth
.
LinuxDoOAuthStart
)
auth
.
GET
(
"/oauth/linuxdo/callback"
,
h
.
Auth
.
LinuxDoOAuthCallback
)
}
// 公开设置(无需认证)
...
...
backend/internal/service/account_service.go
View file @
1a641392
...
...
@@ -49,10 +49,12 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
SetTempUnschedulable
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
,
reason
string
)
error
ClearTempUnschedulable
(
ctx
context
.
Context
,
id
int64
)
error
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
ClearAntigravityQuotaScopes
(
ctx
context
.
Context
,
id
int64
)
error
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
AccountBulkUpdate
)
(
int64
,
error
)
...
...
@@ -66,6 +68,7 @@ type AccountBulkUpdate struct {
Concurrency
*
int
Priority
*
int
Status
*
string
Schedulable
*
bool
Credentials
map
[
string
]
any
Extra
map
[
string
]
any
}
...
...
backend/internal/service/account_service_delete_test.go
View file @
1a641392
...
...
@@ -139,6 +139,10 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
panic
(
"unexpected SetRateLimited call"
)
}
func
(
s
*
accountRepoStub
)
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
{
panic
(
"unexpected SetAntigravityQuotaScopeLimit call"
)
}
func
(
s
*
accountRepoStub
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
panic
(
"unexpected SetOverloaded call"
)
}
...
...
@@ -155,6 +159,10 @@ func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic
(
"unexpected ClearRateLimit call"
)
}
func
(
s
*
accountRepoStub
)
ClearAntigravityQuotaScopes
(
ctx
context
.
Context
,
id
int64
)
error
{
panic
(
"unexpected ClearAntigravityQuotaScopes call"
)
}
func
(
s
*
accountRepoStub
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
panic
(
"unexpected UpdateSessionWindow call"
)
}
...
...
backend/internal/service/admin_service.go
View file @
1a641392
...
...
@@ -24,7 +24,7 @@ type AdminService interface {
GetUserUsageStats
(
ctx
context
.
Context
,
userID
int64
,
period
string
)
(
any
,
error
)
// Group management
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
Group
,
int64
,
error
)
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
int64
,
error
)
GetAllGroups
(
ctx
context
.
Context
)
([]
Group
,
error
)
GetAllGroupsByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Group
,
error
)
GetGroup
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
...
...
@@ -168,6 +168,7 @@ type BulkUpdateAccountsInput struct {
Concurrency
*
int
Priority
*
int
Status
string
Schedulable
*
bool
GroupIDs
*
[]
int64
Credentials
map
[
string
]
any
Extra
map
[
string
]
any
...
...
@@ -478,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
}
// Group management implementations
func
(
s
*
adminServiceImpl
)
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
Group
,
int64
,
error
)
{
func
(
s
*
adminServiceImpl
)
ListGroups
(
ctx
context
.
Context
,
page
,
pageSize
int
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
int64
,
error
)
{
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
groups
,
result
,
err
:=
s
.
groupRepo
.
ListWithFilters
(
ctx
,
params
,
platform
,
status
,
isExclusive
)
groups
,
result
,
err
:=
s
.
groupRepo
.
ListWithFilters
(
ctx
,
params
,
platform
,
status
,
search
,
isExclusive
)
if
err
!=
nil
{
return
nil
,
0
,
err
}
...
...
@@ -575,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
return
fmt
.
Errorf
(
"cannot set self as fallback group"
)
}
visited
:=
map
[
int64
]
struct
{}{}
nextID
:=
fallbackGroupID
for
{
if
_
,
seen
:=
visited
[
nextID
];
seen
{
return
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
visited
[
nextID
]
=
struct
{}{}
if
currentGroupID
>
0
&&
nextID
==
currentGroupID
{
return
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
// 检查降级分组是否存在
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
fallbackGroup
ID
)
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByID
Lite
(
ctx
,
next
ID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"fallback group not found: %w"
,
err
)
}
// 降级分组不能启用 claude_code_only,否则会造成死循环
if
fallbackGroup
.
ClaudeCodeOnly
{
if
nextID
==
fallbackGroupID
&&
fallbackGroup
.
ClaudeCodeOnly
{
return
fmt
.
Errorf
(
"fallback group cannot have claude_code_only enabled"
)
}
if
fallbackGroup
.
FallbackGroupID
==
nil
{
return
nil
}
nextID
=
*
fallbackGroup
.
FallbackGroupID
}
}
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
...
...
@@ -910,6 +926,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if
input
.
Status
!=
""
{
repoUpdates
.
Status
=
&
input
.
Status
}
if
input
.
Schedulable
!=
nil
{
repoUpdates
.
Schedulable
=
input
.
Schedulable
}
// Run bulk update for column/jsonb fields first.
if
_
,
err
:=
s
.
accountRepo
.
BulkUpdate
(
ctx
,
input
.
AccountIDs
,
repoUpdates
);
err
!=
nil
{
...
...
backend/internal/service/admin_service_delete_test.go
View file @
1a641392
...
...
@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
panic
(
"unexpected GetByID call"
)
}
func
(
s
*
groupRepoStub
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
panic
(
"unexpected GetByIDLite call"
)
}
func
(
s
*
groupRepoStub
)
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
{
panic
(
"unexpected Update call"
)
}
...
...
@@ -124,7 +128,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa
panic
(
"unexpected List call"
)
}
func
(
s
*
groupRepoStub
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
func
(
s
*
groupRepoStub
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
...
...
backend/internal/service/admin_service_group_test.go
View file @
1a641392
...
...
@@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct {
updated
*
Group
// 记录 Update 调用的参数
getByID
*
Group
// GetByID 返回值
getErr
error
// GetByID 返回的错误
listWithFiltersCalls
int
listWithFiltersParams
pagination
.
PaginationParams
listWithFiltersPlatform
string
listWithFiltersStatus
string
listWithFiltersSearch
string
listWithFiltersIsExclusive
*
bool
listWithFiltersGroups
[]
Group
listWithFiltersResult
*
pagination
.
PaginationResult
listWithFiltersErr
error
}
func
(
s
*
groupRepoStubForAdmin
)
Create
(
_
context
.
Context
,
g
*
Group
)
error
{
...
...
@@ -35,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
return
s
.
getByID
,
nil
}
func
(
s
*
groupRepoStubForAdmin
)
GetByIDLite
(
_
context
.
Context
,
_
int64
)
(
*
Group
,
error
)
{
if
s
.
getErr
!=
nil
{
return
nil
,
s
.
getErr
}
return
s
.
getByID
,
nil
}
func
(
s
*
groupRepoStubForAdmin
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
...
...
@@ -47,8 +64,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP
panic
(
"unexpected List call"
)
}
func
(
s
*
groupRepoStubForAdmin
)
ListWithFilters
(
_
context
.
Context
,
_
pagination
.
PaginationParams
,
_
,
_
string
,
_
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
func
(
s
*
groupRepoStubForAdmin
)
ListWithFilters
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
listWithFiltersCalls
++
s
.
listWithFiltersParams
=
params
s
.
listWithFiltersPlatform
=
platform
s
.
listWithFiltersStatus
=
status
s
.
listWithFiltersSearch
=
search
s
.
listWithFiltersIsExclusive
=
isExclusive
if
s
.
listWithFiltersErr
!=
nil
{
return
nil
,
nil
,
s
.
listWithFiltersErr
}
result
:=
s
.
listWithFiltersResult
if
result
==
nil
{
result
=
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
s
.
listWithFiltersGroups
)),
Page
:
params
.
Page
,
PageSize
:
params
.
PageSize
,
}
}
return
s
.
listWithFiltersGroups
,
result
,
nil
}
func
(
s
*
groupRepoStubForAdmin
)
ListActive
(
_
context
.
Context
)
([]
Group
,
error
)
{
...
...
@@ -195,3 +232,149 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require
.
InDelta
(
t
,
0.15
,
*
repo
.
updated
.
ImagePrice2K
,
0.0001
)
// 原值保持
require
.
Nil
(
t
,
repo
.
updated
.
ImagePrice4K
)
}
func
TestAdminService_ListGroups_WithSearch
(
t
*
testing
.
T
)
{
// 测试:
// 1. search 参数正常传递到 repository 层
// 2. search 为空字符串时的行为
// 3. search 与其他过滤条件组合使用
t
.
Run
(
"search 参数正常传递到 repository 层"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
groupRepoStubForAdmin
{
listWithFiltersGroups
:
[]
Group
{{
ID
:
1
,
Name
:
"alpha"
}},
listWithFiltersResult
:
&
pagination
.
PaginationResult
{
Total
:
1
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
groups
,
total
,
err
:=
svc
.
ListGroups
(
context
.
Background
(),
1
,
20
,
""
,
""
,
"alpha"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
total
)
require
.
Equal
(
t
,
[]
Group
{{
ID
:
1
,
Name
:
"alpha"
}},
groups
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
"alpha"
,
repo
.
listWithFiltersSearch
)
require
.
Nil
(
t
,
repo
.
listWithFiltersIsExclusive
)
})
t
.
Run
(
"search 为空字符串时传递空字符串"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
groupRepoStubForAdmin
{
listWithFiltersGroups
:
[]
Group
{},
listWithFiltersResult
:
&
pagination
.
PaginationResult
{
Total
:
0
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
groups
,
total
,
err
:=
svc
.
ListGroups
(
context
.
Background
(),
2
,
10
,
""
,
""
,
""
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
groups
)
require
.
Equal
(
t
,
int64
(
0
),
total
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
2
,
PageSize
:
10
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
""
,
repo
.
listWithFiltersSearch
)
require
.
Nil
(
t
,
repo
.
listWithFiltersIsExclusive
)
})
t
.
Run
(
"search 与其他过滤条件组合使用"
,
func
(
t
*
testing
.
T
)
{
isExclusive
:=
true
repo
:=
&
groupRepoStubForAdmin
{
listWithFiltersGroups
:
[]
Group
{{
ID
:
2
,
Name
:
"beta"
}},
listWithFiltersResult
:
&
pagination
.
PaginationResult
{
Total
:
42
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
groups
,
total
,
err
:=
svc
.
ListGroups
(
context
.
Background
(),
3
,
50
,
PlatformAntigravity
,
StatusActive
,
"beta"
,
&
isExclusive
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
42
),
total
)
require
.
Equal
(
t
,
[]
Group
{{
ID
:
2
,
Name
:
"beta"
}},
groups
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
3
,
PageSize
:
50
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
PlatformAntigravity
,
repo
.
listWithFiltersPlatform
)
require
.
Equal
(
t
,
StatusActive
,
repo
.
listWithFiltersStatus
)
require
.
Equal
(
t
,
"beta"
,
repo
.
listWithFiltersSearch
)
require
.
NotNil
(
t
,
repo
.
listWithFiltersIsExclusive
)
require
.
True
(
t
,
*
repo
.
listWithFiltersIsExclusive
)
})
}
func
TestAdminService_ValidateFallbackGroup_DetectsCycle
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
fallbackID
:=
int64
(
2
)
repo
:=
&
groupRepoStubForFallbackCycle
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
FallbackGroupID
:
&
fallbackID
,
},
fallbackID
:
{
ID
:
fallbackID
,
FallbackGroupID
:
&
groupID
,
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
err
:=
svc
.
validateFallbackGroup
(
context
.
Background
(),
groupID
,
fallbackID
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group cycle"
)
}
type
groupRepoStubForFallbackCycle
struct
{
groups
map
[
int64
]
*
Group
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Create
(
_
context
.
Context
,
_
*
Group
)
error
{
panic
(
"unexpected Create call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Update
(
_
context
.
Context
,
_
*
Group
)
error
{
panic
(
"unexpected Update call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
s
.
GetByIDLite
(
ctx
,
id
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetByIDLite
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
if
g
,
ok
:=
s
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
DeleteCascade
(
_
context
.
Context
,
_
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected DeleteCascade call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
List
(
_
context
.
Context
,
_
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected List call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListWithFilters
(
_
context
.
Context
,
_
pagination
.
PaginationParams
,
_
,
_
,
_
string
,
_
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListActive
(
_
context
.
Context
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActive call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListActiveByPlatform
(
_
context
.
Context
,
_
string
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActiveByPlatform call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ExistsByName
(
_
context
.
Context
,
_
string
)
(
bool
,
error
)
{
panic
(
"unexpected ExistsByName call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetAccountCount
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected GetAccountCount call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
DeleteAccountGroupsByGroupID
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
backend/internal/service/admin_service_search_test.go
0 → 100644
View file @
1a641392
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type
accountRepoStubForAdminList
struct
{
accountRepoStub
listWithFiltersCalls
int
listWithFiltersParams
pagination
.
PaginationParams
listWithFiltersPlatform
string
listWithFiltersType
string
listWithFiltersStatus
string
listWithFiltersSearch
string
listWithFiltersAccounts
[]
Account
listWithFiltersResult
*
pagination
.
PaginationResult
listWithFiltersErr
error
}
func
(
s
*
accountRepoStubForAdminList
)
ListWithFilters
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
accountType
,
status
,
search
string
)
([]
Account
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
listWithFiltersCalls
++
s
.
listWithFiltersParams
=
params
s
.
listWithFiltersPlatform
=
platform
s
.
listWithFiltersType
=
accountType
s
.
listWithFiltersStatus
=
status
s
.
listWithFiltersSearch
=
search
if
s
.
listWithFiltersErr
!=
nil
{
return
nil
,
nil
,
s
.
listWithFiltersErr
}
result
:=
s
.
listWithFiltersResult
if
result
==
nil
{
result
=
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
s
.
listWithFiltersAccounts
)),
Page
:
params
.
Page
,
PageSize
:
params
.
PageSize
,
}
}
return
s
.
listWithFiltersAccounts
,
result
,
nil
}
type
proxyRepoStubForAdminList
struct
{
proxyRepoStub
listWithFiltersCalls
int
listWithFiltersParams
pagination
.
PaginationParams
listWithFiltersProtocol
string
listWithFiltersStatus
string
listWithFiltersSearch
string
listWithFiltersProxies
[]
Proxy
listWithFiltersResult
*
pagination
.
PaginationResult
listWithFiltersErr
error
listWithFiltersAndAccountCountCalls
int
listWithFiltersAndAccountCountParams
pagination
.
PaginationParams
listWithFiltersAndAccountCountProtocol
string
listWithFiltersAndAccountCountStatus
string
listWithFiltersAndAccountCountSearch
string
listWithFiltersAndAccountCountProxies
[]
ProxyWithAccountCount
listWithFiltersAndAccountCountResult
*
pagination
.
PaginationResult
listWithFiltersAndAccountCountErr
error
}
func
(
s
*
proxyRepoStubForAdminList
)
ListWithFilters
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
protocol
,
status
,
search
string
)
([]
Proxy
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
listWithFiltersCalls
++
s
.
listWithFiltersParams
=
params
s
.
listWithFiltersProtocol
=
protocol
s
.
listWithFiltersStatus
=
status
s
.
listWithFiltersSearch
=
search
if
s
.
listWithFiltersErr
!=
nil
{
return
nil
,
nil
,
s
.
listWithFiltersErr
}
result
:=
s
.
listWithFiltersResult
if
result
==
nil
{
result
=
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
s
.
listWithFiltersProxies
)),
Page
:
params
.
Page
,
PageSize
:
params
.
PageSize
,
}
}
return
s
.
listWithFiltersProxies
,
result
,
nil
}
func
(
s
*
proxyRepoStubForAdminList
)
ListWithFiltersAndAccountCount
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
protocol
,
status
,
search
string
)
([]
ProxyWithAccountCount
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
listWithFiltersAndAccountCountCalls
++
s
.
listWithFiltersAndAccountCountParams
=
params
s
.
listWithFiltersAndAccountCountProtocol
=
protocol
s
.
listWithFiltersAndAccountCountStatus
=
status
s
.
listWithFiltersAndAccountCountSearch
=
search
if
s
.
listWithFiltersAndAccountCountErr
!=
nil
{
return
nil
,
nil
,
s
.
listWithFiltersAndAccountCountErr
}
result
:=
s
.
listWithFiltersAndAccountCountResult
if
result
==
nil
{
result
=
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
s
.
listWithFiltersAndAccountCountProxies
)),
Page
:
params
.
Page
,
PageSize
:
params
.
PageSize
,
}
}
return
s
.
listWithFiltersAndAccountCountProxies
,
result
,
nil
}
type
redeemRepoStubForAdminList
struct
{
redeemRepoStub
listWithFiltersCalls
int
listWithFiltersParams
pagination
.
PaginationParams
listWithFiltersType
string
listWithFiltersStatus
string
listWithFiltersSearch
string
listWithFiltersCodes
[]
RedeemCode
listWithFiltersResult
*
pagination
.
PaginationResult
listWithFiltersErr
error
}
func
(
s
*
redeemRepoStubForAdminList
)
ListWithFilters
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
codeType
,
status
,
search
string
)
([]
RedeemCode
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
listWithFiltersCalls
++
s
.
listWithFiltersParams
=
params
s
.
listWithFiltersType
=
codeType
s
.
listWithFiltersStatus
=
status
s
.
listWithFiltersSearch
=
search
if
s
.
listWithFiltersErr
!=
nil
{
return
nil
,
nil
,
s
.
listWithFiltersErr
}
result
:=
s
.
listWithFiltersResult
if
result
==
nil
{
result
=
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
s
.
listWithFiltersCodes
)),
Page
:
params
.
Page
,
PageSize
:
params
.
PageSize
,
}
}
return
s
.
listWithFiltersCodes
,
result
,
nil
}
func
TestAdminService_ListAccounts_WithSearch
(
t
*
testing
.
T
)
{
t
.
Run
(
"search 参数正常传递到 repository 层"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
accountRepoStubForAdminList
{
listWithFiltersAccounts
:
[]
Account
{{
ID
:
1
,
Name
:
"acc"
}},
listWithFiltersResult
:
&
pagination
.
PaginationResult
{
Total
:
10
},
}
svc
:=
&
adminServiceImpl
{
accountRepo
:
repo
}
accounts
,
total
,
err
:=
svc
.
ListAccounts
(
context
.
Background
(),
1
,
20
,
PlatformGemini
,
AccountTypeOAuth
,
StatusActive
,
"acc"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
10
),
total
)
require
.
Equal
(
t
,
[]
Account
{{
ID
:
1
,
Name
:
"acc"
}},
accounts
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
PlatformGemini
,
repo
.
listWithFiltersPlatform
)
require
.
Equal
(
t
,
AccountTypeOAuth
,
repo
.
listWithFiltersType
)
require
.
Equal
(
t
,
StatusActive
,
repo
.
listWithFiltersStatus
)
require
.
Equal
(
t
,
"acc"
,
repo
.
listWithFiltersSearch
)
})
}
func
TestAdminService_ListProxies_WithSearch
(
t
*
testing
.
T
)
{
t
.
Run
(
"search 参数正常传递到 repository 层"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
proxyRepoStubForAdminList
{
listWithFiltersProxies
:
[]
Proxy
{{
ID
:
2
,
Name
:
"p1"
}},
listWithFiltersResult
:
&
pagination
.
PaginationResult
{
Total
:
7
},
}
svc
:=
&
adminServiceImpl
{
proxyRepo
:
repo
}
proxies
,
total
,
err
:=
svc
.
ListProxies
(
context
.
Background
(),
3
,
50
,
"http"
,
StatusActive
,
"p1"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
7
),
total
)
require
.
Equal
(
t
,
[]
Proxy
{{
ID
:
2
,
Name
:
"p1"
}},
proxies
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
3
,
PageSize
:
50
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
"http"
,
repo
.
listWithFiltersProtocol
)
require
.
Equal
(
t
,
StatusActive
,
repo
.
listWithFiltersStatus
)
require
.
Equal
(
t
,
"p1"
,
repo
.
listWithFiltersSearch
)
})
}
func
TestAdminService_ListProxiesWithAccountCount_WithSearch
(
t
*
testing
.
T
)
{
t
.
Run
(
"search 参数正常传递到 repository 层"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
proxyRepoStubForAdminList
{
listWithFiltersAndAccountCountProxies
:
[]
ProxyWithAccountCount
{{
Proxy
:
Proxy
{
ID
:
3
,
Name
:
"p2"
},
AccountCount
:
5
}},
listWithFiltersAndAccountCountResult
:
&
pagination
.
PaginationResult
{
Total
:
9
},
}
svc
:=
&
adminServiceImpl
{
proxyRepo
:
repo
}
proxies
,
total
,
err
:=
svc
.
ListProxiesWithAccountCount
(
context
.
Background
(),
2
,
10
,
"socks5"
,
StatusDisabled
,
"p2"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
9
),
total
)
require
.
Equal
(
t
,
[]
ProxyWithAccountCount
{{
Proxy
:
Proxy
{
ID
:
3
,
Name
:
"p2"
},
AccountCount
:
5
}},
proxies
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersAndAccountCountCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
2
,
PageSize
:
10
},
repo
.
listWithFiltersAndAccountCountParams
)
require
.
Equal
(
t
,
"socks5"
,
repo
.
listWithFiltersAndAccountCountProtocol
)
require
.
Equal
(
t
,
StatusDisabled
,
repo
.
listWithFiltersAndAccountCountStatus
)
require
.
Equal
(
t
,
"p2"
,
repo
.
listWithFiltersAndAccountCountSearch
)
})
}
func
TestAdminService_ListRedeemCodes_WithSearch
(
t
*
testing
.
T
)
{
t
.
Run
(
"search 参数正常传递到 repository 层"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
redeemRepoStubForAdminList
{
listWithFiltersCodes
:
[]
RedeemCode
{{
ID
:
4
,
Code
:
"ABC"
}},
listWithFiltersResult
:
&
pagination
.
PaginationResult
{
Total
:
3
},
}
svc
:=
&
adminServiceImpl
{
redeemCodeRepo
:
repo
}
codes
,
total
,
err
:=
svc
.
ListRedeemCodes
(
context
.
Background
(),
1
,
20
,
RedeemTypeBalance
,
StatusUnused
,
"ABC"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
3
),
total
)
require
.
Equal
(
t
,
[]
RedeemCode
{{
ID
:
4
,
Code
:
"ABC"
}},
codes
)
require
.
Equal
(
t
,
1
,
repo
.
listWithFiltersCalls
)
require
.
Equal
(
t
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
20
},
repo
.
listWithFiltersParams
)
require
.
Equal
(
t
,
RedeemTypeBalance
,
repo
.
listWithFiltersType
)
require
.
Equal
(
t
,
StatusUnused
,
repo
.
listWithFiltersStatus
)
require
.
Equal
(
t
,
"ABC"
,
repo
.
listWithFiltersSearch
)
})
}
backend/internal/service/antigravity_gateway_service.go
View file @
1a641392
...
...
@@ -93,6 +93,7 @@ var antigravityPrefixMapping = []struct {
// 长前缀优先
{
"gemini-2.5-flash-image"
,
"gemini-3-pro-image"
},
// gemini-2.5-flash-image → 3-pro-image
{
"gemini-3-pro-image"
,
"gemini-3-pro-image"
},
// gemini-3-pro-image-preview 等
{
"gemini-3-flash"
,
"gemini-3-flash"
},
// gemini-3-flash-preview 等 → gemini-3-flash
{
"claude-3-5-sonnet"
,
"claude-sonnet-4-5"
},
// 旧版 claude-3-5-sonnet-xxx
{
"claude-sonnet-4-5"
,
"claude-sonnet-4-5"
},
// claude-sonnet-4-5-xxx
{
"claude-haiku-4-5"
,
"claude-sonnet-4-5"
},
// claude-haiku-4-5-xxx → sonnet
...
...
@@ -502,6 +503,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
originalModel
:=
claudeReq
.
Model
mappedModel
:=
s
.
getMappedModel
(
account
,
claudeReq
.
Model
)
quotaScope
,
_
:=
resolveAntigravityQuotaScope
(
originalModel
)
// 获取 access_token
if
s
.
tokenProvider
==
nil
{
...
...
@@ -603,7 +605,7 @@ urlFallbackLoop:
}
// 所有重试都失败,标记限流状态
if
resp
.
StatusCode
==
429
{
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
)
}
// 最后一次尝试也失败
resp
=
&
http
.
Response
{
...
...
@@ -696,7 +698,7 @@ urlFallbackLoop:
// 处理错误响应(重试后仍失败或不触发重试)
if
resp
.
StatusCode
>=
400
{
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
)
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
...
...
@@ -1021,6 +1023,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if
len
(
body
)
==
0
{
return
nil
,
s
.
writeGoogleError
(
c
,
http
.
StatusBadRequest
,
"Request body is empty"
)
}
quotaScope
,
_
:=
resolveAntigravityQuotaScope
(
originalModel
)
// 解析请求以获取 image_size(用于图片计费)
imageSize
:=
s
.
extractImageSize
(
body
)
...
...
@@ -1146,7 +1149,7 @@ urlFallbackLoop:
}
// 所有重试都失败,标记限流状态
if
resp
.
StatusCode
==
429
{
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
)
}
resp
=
&
http
.
Response
{
StatusCode
:
resp
.
StatusCode
,
...
...
@@ -1200,7 +1203,7 @@ urlFallbackLoop:
goto
handleSuccess
}
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
)
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
...
...
@@ -1314,7 +1317,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
}
}
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
)
{
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
{
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if
statusCode
==
429
{
resetAt
:=
ParseGeminiRateLimitResetTime
(
body
)
...
...
@@ -1325,13 +1328,23 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
defaultDur
=
5
*
time
.
Minute
}
ra
:=
time
.
Now
()
.
Add
(
defaultDur
)
log
.
Printf
(
"%s status=429 rate_limited reset_in=%v (fallback)"
,
prefix
,
defaultDur
)
_
=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
ra
)
log
.
Printf
(
"%s status=429 rate_limited scope=%s reset_in=%v (fallback)"
,
prefix
,
quotaScope
,
defaultDur
)
if
quotaScope
==
""
{
return
}
if
err
:=
s
.
accountRepo
.
SetAntigravityQuotaScopeLimit
(
ctx
,
account
.
ID
,
quotaScope
,
ra
);
err
!=
nil
{
log
.
Printf
(
"%s status=429 rate_limit_set_failed scope=%s error=%v"
,
prefix
,
quotaScope
,
err
)
}
return
}
resetTime
:=
time
.
Unix
(
*
resetAt
,
0
)
log
.
Printf
(
"%s status=429 rate_limited reset_at=%v reset_in=%v"
,
prefix
,
resetTime
.
Format
(
"15:04:05"
),
time
.
Until
(
resetTime
)
.
Truncate
(
time
.
Second
))
_
=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetTime
)
log
.
Printf
(
"%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v"
,
prefix
,
quotaScope
,
resetTime
.
Format
(
"15:04:05"
),
time
.
Until
(
resetTime
)
.
Truncate
(
time
.
Second
))
if
quotaScope
==
""
{
return
}
if
err
:=
s
.
accountRepo
.
SetAntigravityQuotaScopeLimit
(
ctx
,
account
.
ID
,
quotaScope
,
resetTime
);
err
!=
nil
{
log
.
Printf
(
"%s status=429 rate_limit_set_failed scope=%s error=%v"
,
prefix
,
quotaScope
,
err
)
}
return
}
// 其他错误码继续使用 rateLimitService
...
...
backend/internal/service/antigravity_quota_scope.go
0 → 100644
View file @
1a641392
package
service
import
(
"strings"
"time"
)
const
antigravityQuotaScopesKey
=
"antigravity_quota_scopes"
// AntigravityQuotaScope 表示 Antigravity 的配额域
type
AntigravityQuotaScope
string
const
(
AntigravityQuotaScopeClaude
AntigravityQuotaScope
=
"claude"
AntigravityQuotaScopeGeminiText
AntigravityQuotaScope
=
"gemini_text"
AntigravityQuotaScopeGeminiImage
AntigravityQuotaScope
=
"gemini_image"
)
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func
resolveAntigravityQuotaScope
(
requestedModel
string
)
(
AntigravityQuotaScope
,
bool
)
{
model
:=
normalizeAntigravityModelName
(
requestedModel
)
if
model
==
""
{
return
""
,
false
}
switch
{
case
strings
.
HasPrefix
(
model
,
"claude-"
)
:
return
AntigravityQuotaScopeClaude
,
true
case
strings
.
HasPrefix
(
model
,
"gemini-"
)
:
if
isImageGenerationModel
(
model
)
{
return
AntigravityQuotaScopeGeminiImage
,
true
}
return
AntigravityQuotaScopeGeminiText
,
true
default
:
return
""
,
false
}
}
func
normalizeAntigravityModelName
(
model
string
)
string
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
model
))
normalized
=
strings
.
TrimPrefix
(
normalized
,
"models/"
)
return
normalized
}
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度
func
(
a
*
Account
)
IsSchedulableForModel
(
requestedModel
string
)
bool
{
if
a
==
nil
{
return
false
}
if
!
a
.
IsSchedulable
()
{
return
false
}
if
a
.
Platform
!=
PlatformAntigravity
{
return
true
}
scope
,
ok
:=
resolveAntigravityQuotaScope
(
requestedModel
)
if
!
ok
{
return
true
}
resetAt
:=
a
.
antigravityQuotaScopeResetAt
(
scope
)
if
resetAt
==
nil
{
return
true
}
now
:=
time
.
Now
()
return
!
now
.
Before
(
*
resetAt
)
}
func
(
a
*
Account
)
antigravityQuotaScopeResetAt
(
scope
AntigravityQuotaScope
)
*
time
.
Time
{
if
a
==
nil
||
a
.
Extra
==
nil
||
scope
==
""
{
return
nil
}
rawScopes
,
ok
:=
a
.
Extra
[
antigravityQuotaScopesKey
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
nil
}
rawScope
,
ok
:=
rawScopes
[
string
(
scope
)]
.
(
map
[
string
]
any
)
if
!
ok
{
return
nil
}
resetAtRaw
,
ok
:=
rawScope
[
"rate_limit_reset_at"
]
.
(
string
)
if
!
ok
||
strings
.
TrimSpace
(
resetAtRaw
)
==
""
{
return
nil
}
resetAt
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
resetAtRaw
)
if
err
!=
nil
{
return
nil
}
return
&
resetAt
}
backend/internal/service/api_key.go
View file @
1a641392
...
...
@@ -9,6 +9,8 @@ type APIKey struct {
Name
string
GroupID
*
int64
Status
string
IPWhitelist
[]
string
IPBlacklist
[]
string
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
User
*
User
...
...
backend/internal/service/api_key_service.go
View file @
1a641392
...
...
@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
)
...
...
@@ -20,6 +21,7 @@ var (
ErrAPIKeyTooShort
=
infraerrors
.
BadRequest
(
"API_KEY_TOO_SHORT"
,
"api key must be at least 16 characters"
)
ErrAPIKeyInvalidChars
=
infraerrors
.
BadRequest
(
"API_KEY_INVALID_CHARS"
,
"api key can only contain letters, numbers, underscores, and hyphens"
)
ErrAPIKeyRateLimited
=
infraerrors
.
TooManyRequests
(
"API_KEY_RATE_LIMITED"
,
"too many failed attempts, please try again later"
)
ErrInvalidIPPattern
=
infraerrors
.
BadRequest
(
"INVALID_IP_PATTERN"
,
"invalid IP or CIDR pattern"
)
)
const
(
...
...
@@ -60,6 +62,8 @@ type CreateAPIKeyRequest struct {
Name
string
`json:"name"`
GroupID
*
int64
`json:"group_id"`
CustomKey
*
string
`json:"custom_key"`
// 可选的自定义key
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
}
// UpdateAPIKeyRequest 更新API Key请求
...
...
@@ -67,6 +71,8 @@ type UpdateAPIKeyRequest struct {
Name
*
string
`json:"name"`
GroupID
*
int64
`json:"group_id"`
Status
*
string
`json:"status"`
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单(空数组清空)
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单(空数组清空)
}
// APIKeyService API Key服务
...
...
@@ -186,6 +192,20 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
return
nil
,
fmt
.
Errorf
(
"get user: %w"
,
err
)
}
// 验证 IP 白名单格式
if
len
(
req
.
IPWhitelist
)
>
0
{
if
invalid
:=
ip
.
ValidateIPPatterns
(
req
.
IPWhitelist
);
len
(
invalid
)
>
0
{
return
nil
,
fmt
.
Errorf
(
"%w: %v"
,
ErrInvalidIPPattern
,
invalid
)
}
}
// 验证 IP 黑名单格式
if
len
(
req
.
IPBlacklist
)
>
0
{
if
invalid
:=
ip
.
ValidateIPPatterns
(
req
.
IPBlacklist
);
len
(
invalid
)
>
0
{
return
nil
,
fmt
.
Errorf
(
"%w: %v"
,
ErrInvalidIPPattern
,
invalid
)
}
}
// 验证分组权限(如果指定了分组)
if
req
.
GroupID
!=
nil
{
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
req
.
GroupID
)
...
...
@@ -241,6 +261,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
Name
:
req
.
Name
,
GroupID
:
req
.
GroupID
,
Status
:
StatusActive
,
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
}
if
err
:=
s
.
apiKeyRepo
.
Create
(
ctx
,
apiKey
);
err
!=
nil
{
...
...
@@ -312,6 +334,20 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
return
nil
,
ErrInsufficientPerms
}
// 验证 IP 白名单格式
if
len
(
req
.
IPWhitelist
)
>
0
{
if
invalid
:=
ip
.
ValidateIPPatterns
(
req
.
IPWhitelist
);
len
(
invalid
)
>
0
{
return
nil
,
fmt
.
Errorf
(
"%w: %v"
,
ErrInvalidIPPattern
,
invalid
)
}
}
// 验证 IP 黑名单格式
if
len
(
req
.
IPBlacklist
)
>
0
{
if
invalid
:=
ip
.
ValidateIPPatterns
(
req
.
IPBlacklist
);
len
(
invalid
)
>
0
{
return
nil
,
fmt
.
Errorf
(
"%w: %v"
,
ErrInvalidIPPattern
,
invalid
)
}
}
// 更新字段
if
req
.
Name
!=
nil
{
apiKey
.
Name
=
*
req
.
Name
...
...
@@ -344,6 +380,10 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
}
}
// 更新 IP 限制(空数组会清空设置)
apiKey
.
IPWhitelist
=
req
.
IPWhitelist
apiKey
.
IPBlacklist
=
req
.
IPBlacklist
if
err
:=
s
.
apiKeyRepo
.
Update
(
ctx
,
apiKey
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update api key: %w"
,
err
)
}
...
...
backend/internal/service/auth_service.go
View file @
1a641392
...
...
@@ -2,9 +2,13 @@ package service
import
(
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"log"
"net/mail"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
...
...
@@ -18,6 +22,7 @@ var (
ErrInvalidCredentials
=
infraerrors
.
Unauthorized
(
"INVALID_CREDENTIALS"
,
"invalid email or password"
)
ErrUserNotActive
=
infraerrors
.
Forbidden
(
"USER_NOT_ACTIVE"
,
"user is not active"
)
ErrEmailExists
=
infraerrors
.
Conflict
(
"EMAIL_EXISTS"
,
"email already exists"
)
ErrEmailReserved
=
infraerrors
.
BadRequest
(
"EMAIL_RESERVED"
,
"email is reserved"
)
ErrInvalidToken
=
infraerrors
.
Unauthorized
(
"INVALID_TOKEN"
,
"invalid token"
)
ErrTokenExpired
=
infraerrors
.
Unauthorized
(
"TOKEN_EXPIRED"
,
"token has expired"
)
ErrTokenTooLarge
=
infraerrors
.
BadRequest
(
"TOKEN_TOO_LARGE"
,
"token too large"
)
...
...
@@ -47,6 +52,7 @@ type AuthService struct {
emailService
*
EmailService
turnstileService
*
TurnstileService
emailQueueService
*
EmailQueueService
promoService
*
PromoService
}
// NewAuthService 创建认证服务实例
...
...
@@ -57,6 +63,7 @@ func NewAuthService(
emailService
*
EmailService
,
turnstileService
*
TurnstileService
,
emailQueueService
*
EmailQueueService
,
promoService
*
PromoService
,
)
*
AuthService
{
return
&
AuthService
{
userRepo
:
userRepo
,
...
...
@@ -65,21 +72,27 @@ func NewAuthService(
emailService
:
emailService
,
turnstileService
:
turnstileService
,
emailQueueService
:
emailQueueService
,
promoService
:
promoService
,
}
}
// Register 用户注册,返回token和用户
func
(
s
*
AuthService
)
Register
(
ctx
context
.
Context
,
email
,
password
string
)
(
string
,
*
User
,
error
)
{
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
)
return
s
.
RegisterWithVerification
(
ctx
,
email
,
password
,
""
,
""
)
}
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
string
)
(
string
,
*
User
,
error
)
{
// 检查是否开放注册
if
s
.
settingService
!
=
nil
&&
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
// RegisterWithVerification 用户注册(支持邮件验证
和优惠码
),返回token和用户
func
(
s
*
AuthService
)
RegisterWithVerification
(
ctx
context
.
Context
,
email
,
password
,
verifyCode
,
promoCode
string
)
(
string
,
*
User
,
error
)
{
// 检查是否开放注册
(默认关闭:settingService 未配置时不允许注册)
if
s
.
settingService
=
=
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
""
,
nil
,
ErrRegDisabled
}
// 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。
if
isReservedEmail
(
email
)
{
return
""
,
nil
,
ErrEmailReserved
}
// 检查是否需要邮件验证
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
...
...
@@ -132,10 +145,27 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
}
if
err
:=
s
.
userRepo
.
Create
(
ctx
,
user
);
err
!=
nil
{
// 优先检查邮箱冲突错误(竞态条件下可能发生)
if
errors
.
Is
(
err
,
ErrEmailExists
)
{
return
""
,
nil
,
ErrEmailExists
}
log
.
Printf
(
"[Auth] Database error creating user: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
}
// 应用优惠码(如果提供)
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
// 优惠码应用失败不影响注册,只记录日志
log
.
Printf
(
"[Auth] Failed to apply promo code for user %d: %v"
,
user
.
ID
,
err
)
}
else
{
// 重新获取用户信息以获取更新后的余额
if
updatedUser
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
user
.
ID
);
err
==
nil
{
user
=
updatedUser
}
}
}
// 生成token
token
,
err
:=
s
.
GenerateToken
(
user
)
if
err
!=
nil
{
...
...
@@ -152,11 +182,15 @@ type SendVerifyCodeResult struct {
// SendVerifyCode 发送邮箱验证码(同步方式)
func
(
s
*
AuthService
)
SendVerifyCode
(
ctx
context
.
Context
,
email
string
)
error
{
// 检查是否开放注册
if
s
.
settingService
!
=
nil
&&
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
// 检查是否开放注册
(默认关闭)
if
s
.
settingService
=
=
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
ErrRegDisabled
}
if
isReservedEmail
(
email
)
{
return
ErrEmailReserved
}
// 检查邮箱是否已存在
existsEmail
,
err
:=
s
.
userRepo
.
ExistsByEmail
(
ctx
,
email
)
if
err
!=
nil
{
...
...
@@ -185,12 +219,16 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
func
(
s
*
AuthService
)
SendVerifyCodeAsync
(
ctx
context
.
Context
,
email
string
)
(
*
SendVerifyCodeResult
,
error
)
{
log
.
Printf
(
"[Auth] SendVerifyCodeAsync called for email: %s"
,
email
)
// 检查是否开放注册
if
s
.
settingService
!
=
nil
&&
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
// 检查是否开放注册
(默认关闭)
if
s
.
settingService
=
=
nil
||
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
log
.
Println
(
"[Auth] Registration is disabled"
)
return
nil
,
ErrRegDisabled
}
if
isReservedEmail
(
email
)
{
return
nil
,
ErrEmailReserved
}
// 检查邮箱是否已存在
existsEmail
,
err
:=
s
.
userRepo
.
ExistsByEmail
(
ctx
,
email
)
if
err
!=
nil
{
...
...
@@ -270,7 +308,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
// IsRegistrationEnabled 检查是否开放注册
func
(
s
*
AuthService
)
IsRegistrationEnabled
(
ctx
context
.
Context
)
bool
{
if
s
.
settingService
==
nil
{
return
true
return
false
// 安全默认:settingService 未配置时关闭注册
}
return
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
}
...
...
@@ -315,6 +353,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
return
token
,
user
,
nil
}
// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录:
// - 如果邮箱已存在:直接登录(不需要本地密码)
// - 如果邮箱不存在:创建新用户并登录
//
// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
func
(
s
*
AuthService
)
LoginOrRegisterOAuth
(
ctx
context
.
Context
,
email
,
username
string
)
(
string
,
*
User
,
error
)
{
email
=
strings
.
TrimSpace
(
email
)
if
email
==
""
||
len
(
email
)
>
255
{
return
""
,
nil
,
infraerrors
.
BadRequest
(
"INVALID_EMAIL"
,
"invalid email"
)
}
if
_
,
err
:=
mail
.
ParseAddress
(
email
);
err
!=
nil
{
return
""
,
nil
,
infraerrors
.
BadRequest
(
"INVALID_EMAIL"
,
"invalid email"
)
}
username
=
strings
.
TrimSpace
(
username
)
if
len
([]
rune
(
username
))
>
100
{
username
=
string
([]
rune
(
username
)[
:
100
])
}
user
,
err
:=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
// OAuth 首次登录视为注册。
if
s
.
settingService
!=
nil
&&
!
s
.
settingService
.
IsRegistrationEnabled
(
ctx
)
{
return
""
,
nil
,
ErrRegDisabled
}
randomPassword
,
err
:=
randomHexString
(
32
)
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to generate random password for oauth signup: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
}
hashedPassword
,
err
:=
s
.
HashPassword
(
randomPassword
)
if
err
!=
nil
{
return
""
,
nil
,
fmt
.
Errorf
(
"hash password: %w"
,
err
)
}
// 新用户默认值。
defaultBalance
:=
s
.
cfg
.
Default
.
UserBalance
defaultConcurrency
:=
s
.
cfg
.
Default
.
UserConcurrency
if
s
.
settingService
!=
nil
{
defaultBalance
=
s
.
settingService
.
GetDefaultBalance
(
ctx
)
defaultConcurrency
=
s
.
settingService
.
GetDefaultConcurrency
(
ctx
)
}
newUser
:=
&
User
{
Email
:
email
,
Username
:
username
,
PasswordHash
:
hashedPassword
,
Role
:
RoleUser
,
Balance
:
defaultBalance
,
Concurrency
:
defaultConcurrency
,
Status
:
StatusActive
,
}
if
err
:=
s
.
userRepo
.
Create
(
ctx
,
newUser
);
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrEmailExists
)
{
// 并发场景:GetByEmail 与 Create 之间用户被创建。
user
,
err
=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
log
.
Printf
(
"[Auth] Database error getting user after conflict: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
}
}
else
{
log
.
Printf
(
"[Auth] Database error creating oauth user: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
}
}
else
{
user
=
newUser
}
}
else
{
log
.
Printf
(
"[Auth] Database error during oauth login: %v"
,
err
)
return
""
,
nil
,
ErrServiceUnavailable
}
}
if
!
user
.
IsActive
()
{
return
""
,
nil
,
ErrUserNotActive
}
// 尽力补全:当用户名为空时,使用第三方返回的用户名回填。
if
user
.
Username
==
""
&&
username
!=
""
{
user
.
Username
=
username
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to update username after oauth login: %v"
,
err
)
}
}
token
,
err
:=
s
.
GenerateToken
(
user
)
if
err
!=
nil
{
return
""
,
nil
,
fmt
.
Errorf
(
"generate token: %w"
,
err
)
}
return
token
,
user
,
nil
}
// ValidateToken 验证JWT token并返回用户声明
func
(
s
*
AuthService
)
ValidateToken
(
tokenString
string
)
(
*
JWTClaims
,
error
)
{
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
...
...
@@ -357,6 +491,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
return
nil
,
ErrInvalidToken
}
func
randomHexString
(
byteLength
int
)
(
string
,
error
)
{
if
byteLength
<=
0
{
byteLength
=
16
}
buf
:=
make
([]
byte
,
byteLength
)
if
_
,
err
:=
rand
.
Read
(
buf
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
buf
),
nil
}
func
isReservedEmail
(
email
string
)
bool
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
email
))
return
strings
.
HasSuffix
(
normalized
,
LinuxDoConnectSyntheticEmailDomain
)
}
// GenerateToken 生成JWT token
func
(
s
*
AuthService
)
GenerateToken
(
user
*
User
)
(
string
,
error
)
{
now
:=
time
.
Now
()
...
...
backend/internal/service/auth_service_register_test.go
View file @
1a641392
...
...
@@ -100,6 +100,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
emailService
,
nil
,
nil
,
nil
,
// promoService
)
}
...
...
@@ -113,6 +114,15 @@ func TestAuthService_Register_Disabled(t *testing.T) {
require
.
ErrorIs
(
t
,
err
,
ErrRegDisabled
)
}
func
TestAuthService_Register_DisabledByDefault
(
t
*
testing
.
T
)
{
// 当 settings 为 nil(设置项不存在)时,注册应该默认关闭
repo
:=
&
userRepoStub
{}
service
:=
newAuthService
(
repo
,
nil
,
nil
)
_
,
_
,
err
:=
service
.
Register
(
context
.
Background
(),
"user@test.com"
,
"password"
)
require
.
ErrorIs
(
t
,
err
,
ErrRegDisabled
)
}
func
TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured
(
t
*
testing
.
T
)
{
repo
:=
&
userRepoStub
{}
// 邮件验证开启但 emailCache 为 nil(emailService 未配置)
...
...
@@ -122,7 +132,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
},
nil
)
// 应返回服务不可用错误,而不是允许绕过验证
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"any-code"
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"any-code"
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrServiceUnavailable
)
}
...
...
@@ -134,7 +144,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled
:
"true"
,
},
cache
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
""
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
""
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailVerifyRequired
)
}
...
...
@@ -148,14 +158,16 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled
:
"true"
,
},
cache
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"wrong"
)
_
,
_
,
err
:=
service
.
RegisterWithVerification
(
context
.
Background
(),
"user@test.com"
,
"password"
,
"wrong"
,
""
)
require
.
ErrorIs
(
t
,
err
,
ErrInvalidVerifyCode
)
require
.
ErrorContains
(
t
,
err
,
"verify code"
)
}
func
TestAuthService_Register_EmailExists
(
t
*
testing
.
T
)
{
repo
:=
&
userRepoStub
{
exists
:
true
}
service
:=
newAuthService
(
repo
,
nil
,
nil
)
service
:=
newAuthService
(
repo
,
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
},
nil
)
_
,
_
,
err
:=
service
.
Register
(
context
.
Background
(),
"user@test.com"
,
"password"
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailExists
)
...
...
@@ -163,23 +175,50 @@ func TestAuthService_Register_EmailExists(t *testing.T) {
func
TestAuthService_Register_CheckEmailError
(
t
*
testing
.
T
)
{
repo
:=
&
userRepoStub
{
existsErr
:
errors
.
New
(
"db down"
)}
service
:=
newAuthService
(
repo
,
nil
,
nil
)
service
:=
newAuthService
(
repo
,
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
},
nil
)
_
,
_
,
err
:=
service
.
Register
(
context
.
Background
(),
"user@test.com"
,
"password"
)
require
.
ErrorIs
(
t
,
err
,
ErrServiceUnavailable
)
}
func
TestAuthService_Register_ReservedEmail
(
t
*
testing
.
T
)
{
repo
:=
&
userRepoStub
{}
service
:=
newAuthService
(
repo
,
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
},
nil
)
_
,
_
,
err
:=
service
.
Register
(
context
.
Background
(),
"linuxdo-123@linuxdo-connect.invalid"
,
"password"
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailReserved
)
}
func
TestAuthService_Register_CreateError
(
t
*
testing
.
T
)
{
repo
:=
&
userRepoStub
{
createErr
:
errors
.
New
(
"create failed"
)}
service
:=
newAuthService
(
repo
,
nil
,
nil
)
service
:=
newAuthService
(
repo
,
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
},
nil
)
_
,
_
,
err
:=
service
.
Register
(
context
.
Background
(),
"user@test.com"
,
"password"
)
require
.
ErrorIs
(
t
,
err
,
ErrServiceUnavailable
)
}
func
TestAuthService_Register_CreateEmailExistsRace
(
t
*
testing
.
T
)
{
// 模拟竞态条件:ExistsByEmail 返回 false,但 Create 时因唯一约束失败
repo
:=
&
userRepoStub
{
createErr
:
ErrEmailExists
}
service
:=
newAuthService
(
repo
,
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
},
nil
)
_
,
_
,
err
:=
service
.
Register
(
context
.
Background
(),
"user@test.com"
,
"password"
)
require
.
ErrorIs
(
t
,
err
,
ErrEmailExists
)
}
func
TestAuthService_Register_Success
(
t
*
testing
.
T
)
{
repo
:=
&
userRepoStub
{
nextID
:
5
}
service
:=
newAuthService
(
repo
,
nil
,
nil
)
service
:=
newAuthService
(
repo
,
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
},
nil
)
token
,
user
,
err
:=
service
.
Register
(
context
.
Background
(),
"user@test.com"
,
"password"
)
require
.
NoError
(
t
,
err
)
...
...
backend/internal/service/domain_constants.go
View file @
1a641392
...
...
@@ -38,6 +38,12 @@ const (
RedeemTypeSubscription
=
"subscription"
)
// PromoCode status constants
const
(
PromoCodeStatusActive
=
"active"
PromoCodeStatusDisabled
=
"disabled"
)
// Admin adjustment type constants
const
(
AdjustmentTypeAdminBalance
=
"admin_balance"
// 管理员调整余额
...
...
@@ -105,7 +111,17 @@ const (
// Request identity patch (Claude -> Gemini systemInstruction injection)
SettingKeyEnableIdentityPatch
=
"enable_identity_patch"
SettingKeyIdentityPatchPrompt
=
"identity_patch_prompt"
// LinuxDo Connect OAuth 登录(终端用户 SSO)
SettingKeyLinuxDoConnectEnabled
=
"linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID
=
"linuxdo_connect_client_id"
SettingKeyLinuxDoConnectClientSecret
=
"linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL
=
"linuxdo_connect_redirect_url"
)
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。
const
LinuxDoConnectSyntheticEmailDomain
=
"@linuxdo-connect.invalid"
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
const
AdminAPIKeyPrefix
=
"admin-"
backend/internal/service/email_service.go
View file @
1a641392
...
...
@@ -5,6 +5,7 @@ import (
"crypto/rand"
"crypto/tls"
"fmt"
"log"
"math/big"
"net/smtp"
"strconv"
...
...
@@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配
if
data
.
Code
!=
code
{
data
.
Attempts
++
_
=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
)
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to update verification attempt count: %v"
,
err
)
}
if
data
.
Attempts
>=
maxVerifyCodeAttempts
{
return
ErrVerifyCodeMaxAttempts
}
...
...
@@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
}
// 验证成功,删除验证码
_
=
s
.
cache
.
DeleteVerificationCode
(
ctx
,
email
)
if
err
:=
s
.
cache
.
DeleteVerificationCode
(
ctx
,
email
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to delete verification code after success: %v"
,
err
)
}
return
nil
}
...
...
backend/internal/service/gateway_multiplatform_test.go
View file @
1a641392
...
...
@@ -9,6 +9,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
...
...
@@ -23,9 +24,11 @@ type mockAccountRepoForPlatform struct {
accounts
[]
Account
accountsByID
map
[
int64
]
*
Account
listPlatformFunc
func
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
getByIDCalls
int
}
func
(
m
*
mockAccountRepoForPlatform
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
m
.
getByIDCalls
++
if
acc
,
ok
:=
m
.
accountsByID
[
id
];
ok
{
return
acc
,
nil
}
...
...
@@ -136,6 +139,9 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
func
(
m
*
mockAccountRepoForPlatform
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
return
nil
}
...
...
@@ -148,6 +154,9 @@ func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context,
func
(
m
*
mockAccountRepoForPlatform
)
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ClearAntigravityQuotaScopes
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
return
nil
}
...
...
@@ -185,6 +194,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
return
nil
}
type
mockGroupRepoForGateway
struct
{
groups
map
[
int64
]
*
Group
getByIDCalls
int
getByIDLiteCalls
int
}
func
(
m
*
mockGroupRepoForGateway
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
m
.
getByIDCalls
++
if
g
,
ok
:=
m
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
m
*
mockGroupRepoForGateway
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
m
.
getByIDLiteCalls
++
if
g
,
ok
:=
m
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
m
*
mockGroupRepoForGateway
)
Create
(
ctx
context
.
Context
,
group
*
Group
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGateway
)
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGateway
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGateway
)
DeleteCascade
(
ctx
context
.
Context
,
id
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
ListActive
(
ctx
context
.
Context
)
([]
Group
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Group
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
ptr
[
T
any
](
v
T
)
*
T
{
return
&
v
}
...
...
@@ -891,6 +950,74 @@ func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, acc
return
m
.
accountWaitCounts
[
accountID
],
nil
}
type
mockConcurrencyCache
struct
{
acquireAccountCalls
int
loadBatchCalls
int
}
func
(
m
*
mockConcurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
m
.
acquireAccountCalls
++
return
true
,
nil
}
func
(
m
*
mockConcurrencyCache
)
ReleaseAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
requestID
string
)
error
{
return
nil
}
func
(
m
*
mockConcurrencyCache
)
GetAccountConcurrency
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
m
*
mockConcurrencyCache
)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
m
*
mockConcurrencyCache
)
DecrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
)
error
{
return
nil
}
func
(
m
*
mockConcurrencyCache
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
m
*
mockConcurrencyCache
)
AcquireUserSlot
(
ctx
context
.
Context
,
userID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
m
*
mockConcurrencyCache
)
ReleaseUserSlot
(
ctx
context
.
Context
,
userID
int64
,
requestID
string
)
error
{
return
nil
}
func
(
m
*
mockConcurrencyCache
)
GetUserConcurrency
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
m
*
mockConcurrencyCache
)
IncrementWaitCount
(
ctx
context
.
Context
,
userID
int64
,
maxWait
int
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
m
*
mockConcurrencyCache
)
DecrementWaitCount
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
nil
}
func
(
m
*
mockConcurrencyCache
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
{
m
.
loadBatchCalls
++
result
:=
make
(
map
[
int64
]
*
AccountLoadInfo
,
len
(
accounts
))
for
_
,
acc
:=
range
accounts
{
result
[
acc
.
ID
]
=
&
AccountLoadInfo
{
AccountID
:
acc
.
ID
,
CurrentConcurrency
:
0
,
WaitingCount
:
0
,
LoadRate
:
0
,
}
}
return
result
,
nil
}
func
(
m
*
mockConcurrencyCache
)
CleanupExpiredAccountSlots
(
ctx
context
.
Context
,
accountID
int64
)
error
{
return
nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func
TestGatewayService_SelectAccountWithLoadAwareness
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
...
...
@@ -989,6 +1116,78 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"不应选择被排除的账号"
)
})
t
.
Run
(
"粘性命中-不调用GetByID"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"sticky"
:
1
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
)
require
.
Equal
(
t
,
0
,
repo
.
getByIDCalls
,
"粘性命中不应调用GetByID"
)
require
.
Equal
(
t
,
0
,
concurrencyCache
.
loadBatchCalls
,
"粘性命中应在负载批量查询前返回"
)
})
t
.
Run
(
"粘性账号不在候选集-回退负载感知选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"sticky"
:
1
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"粘性账号不在候选集时应回退到可用账号"
)
require
.
Equal
(
t
,
0
,
repo
.
getByIDCalls
,
"粘性账号缺失不应回退到GetByID"
)
require
.
Equal
(
t
,
1
,
concurrencyCache
.
loadBatchCalls
,
"应继续进行负载批量查询"
)
})
t
.
Run
(
"无可用账号-返回错误"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{},
...
...
@@ -1013,3 +1212,190 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
})
}
func
TestGatewayService_GroupResolution_ReusesContextGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
42
)
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
}
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
group
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
group
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cfg
:
testConfig
(),
}
account
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
account
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDLiteCalls
)
}
func
TestGatewayService_GroupResolution_IgnoresInvalidContextGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
42
)
ctxGroup
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
}
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
ctxGroup
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
group
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cfg
:
testConfig
(),
}
account
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
account
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
1
,
groupRepo
.
getByIDLiteCalls
)
}
func
TestGatewayService_GroupContext_OverwritesInvalidContextGroup
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
42
)
invalidGroup
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
}
hydratedGroup
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
}
ctx
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
Group
,
invalidGroup
)
svc
:=
&
GatewayService
{}
ctx
=
svc
.
withGroupContext
(
ctx
,
hydratedGroup
)
got
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
)
require
.
True
(
t
,
ok
)
require
.
Same
(
t
,
hydratedGroup
,
got
)
}
func
TestGatewayService_GroupResolution_FallbackUsesLiteOnce
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10
)
fallbackID
:=
int64
(
11
)
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
ClaudeCodeOnly
:
true
,
FallbackGroupID
:
&
fallbackID
,
Hydrated
:
true
,
}
fallbackGroup
:=
&
Group
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
}
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
group
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
fallbackGroup
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cfg
:
testConfig
(),
}
account
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
account
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
1
,
groupRepo
.
getByIDLiteCalls
)
}
func
TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10
)
fallbackID
:=
int64
(
11
)
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
ClaudeCodeOnly
:
true
,
FallbackGroupID
:
&
fallbackID
,
}
fallbackGroup
:=
&
Group
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
ClaudeCodeOnly
:
true
,
FallbackGroupID
:
&
groupID
,
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
group
,
fallbackID
:
fallbackGroup
,
},
}
svc
:=
&
GatewayService
{
groupRepo
:
groupRepo
,
}
gotGroup
,
gotID
,
err
:=
svc
.
resolveGatewayGroup
(
ctx
,
&
groupID
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
gotGroup
)
require
.
Nil
(
t
,
gotID
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group cycle"
)
}
backend/internal/service/gateway_service.go
View file @
1a641392
...
...
@@ -33,7 +33,7 @@ const (
claudeAPIURL
=
"https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL
=
"https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
defaultMaxLineSize
=
1
0
*
1024
*
1024
defaultMaxLineSize
=
4
0
*
1024
*
1024
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
)
...
...
@@ -361,27 +361,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if
hasForcePlatform
&&
forcePlatform
!=
""
{
platform
=
forcePlatform
}
else
if
groupID
!=
nil
{
// 根据分组 platform 决定查询哪种账号
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
groupID
)
group
,
resolvedGroupID
,
err
:=
s
.
resolveGatewayGroup
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
return
nil
,
err
}
groupID
=
resolvedGroupID
ctx
=
s
.
withGroupContext
(
ctx
,
group
)
platform
=
group
.
Platform
// 检查 Claude Code 客户端限制
if
group
.
ClaudeCodeOnly
{
isClaudeCode
:=
IsClaudeCodeClient
(
ctx
)
if
!
isClaudeCode
{
// 非 Claude Code 客户端,检查是否有降级分组
if
group
.
FallbackGroupID
!=
nil
{
// 使用降级分组重新调度
fallbackGroupID
:=
*
group
.
FallbackGroupID
return
s
.
SelectAccountForModelWithExclusions
(
ctx
,
&
fallbackGroupID
,
sessionHash
,
requestedModel
,
excludedIDs
)
}
// 无降级分组,拒绝访问
return
nil
,
ErrClaudeCodeOnly
}
}
}
else
{
// 无分组时只使用原生 anthropic 平台
platform
=
PlatformAnthropic
...
...
@@ -409,10 +395,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
groupID
,
err
:=
s
.
checkClaudeCodeRestriction
(
ctx
,
groupID
)
group
,
groupID
,
err
:=
s
.
checkClaudeCodeRestriction
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
err
}
ctx
=
s
.
withGroupContext
(
ctx
,
group
)
if
s
.
concurrencyService
==
nil
||
!
cfg
.
LoadBatchEnabled
{
account
,
err
:=
s
.
SelectAccountForModelWithExclusions
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
)
...
...
@@ -452,7 +439,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
},
nil
}
platform
,
hasForcePlatform
,
err
:=
s
.
resolvePlatform
(
ctx
,
groupID
)
platform
,
hasForcePlatform
,
err
:=
s
.
resolvePlatform
(
ctx
,
groupID
,
group
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -478,10 +465,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
==
nil
&&
accountID
>
0
&&
!
isExcluded
(
accountID
)
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
// 粘性命中仅在当前可调度候选集中生效。
accountByID
:=
make
(
map
[
int64
]
*
Account
,
len
(
accounts
))
for
i
:=
range
accounts
{
accountByID
[
accounts
[
i
]
.
ID
]
=
&
accounts
[
i
]
}
account
,
ok
:=
accountByID
[
accountID
]
if
ok
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
s
.
isAccountAllowedForPlatform
(
account
,
platform
,
useMixed
)
&&
account
.
IsSchedulable
(
)
&&
account
.
IsSchedulable
ForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
...
...
@@ -519,6 +511,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
!
s
.
isAccountAllowedForPlatform
(
acc
,
platform
,
useMixed
)
{
continue
}
if
!
acc
.
IsSchedulableForModel
(
requestedModel
)
{
continue
}
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
...
...
@@ -652,51 +647,97 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
}
}
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
// - 有降级分组:返回降级分组的 ID
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func
(
s
*
GatewayService
)
checkClaudeCodeRestriction
(
ctx
context
.
Context
,
groupID
*
int64
)
(
*
int64
,
error
)
{
if
groupID
==
nil
{
return
groupID
,
nil
func
(
s
*
GatewayService
)
withGroupContext
(
ctx
context
.
Context
,
group
*
Group
)
context
.
Context
{
if
!
IsGroupContextValid
(
group
)
{
return
ctx
}
if
existing
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
existing
!=
nil
&&
existing
.
ID
==
group
.
ID
&&
IsGroupContextValid
(
existing
)
{
return
ctx
}
return
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
group
)
}
// 强制平台模式不检查 Claude Code 限制
if
_
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
);
hasForcePlatform
{
return
group
ID
,
nil
func
(
s
*
GatewayService
)
groupFromContext
(
ctx
context
.
Context
,
groupID
int64
)
*
Group
{
if
group
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
IsGroupContextValid
(
group
)
&&
group
.
ID
==
groupID
{
return
group
}
return
nil
}
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
groupID
)
func
(
s
*
GatewayService
)
resolveGroupByID
(
ctx
context
.
Context
,
groupID
int64
)
(
*
Group
,
error
)
{
if
group
:=
s
.
groupFromContext
(
ctx
,
groupID
);
group
!=
nil
{
return
group
,
nil
}
group
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
}
return
group
,
nil
}
if
!
group
.
ClaudeCodeOnly
{
return
groupID
,
nil
func
(
s
*
GatewayService
)
resolveGatewayGroup
(
ctx
context
.
Context
,
groupID
*
int64
)
(
*
Group
,
*
int64
,
error
)
{
if
groupID
==
nil
{
return
nil
,
nil
,
nil
}
// 分组启用了 Claude Code 限制
if
IsClaudeCodeClient
(
ctx
)
{
return
groupID
,
nil
currentID
:=
*
groupID
visited
:=
map
[
int64
]
struct
{}{}
for
{
if
_
,
seen
:=
visited
[
currentID
];
seen
{
return
nil
,
nil
,
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
visited
[
currentID
]
=
struct
{}{}
// 非 Claude Code 客户端,检查降级分组
if
group
.
FallbackGroupID
!=
nil
{
return
group
.
FallbackGroupID
,
nil
group
,
err
:=
s
.
resolveGroupByID
(
ctx
,
currentID
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
return
nil
,
ErrClaudeCodeOnly
if
!
group
.
ClaudeCodeOnly
||
IsClaudeCodeClient
(
ctx
)
{
return
group
,
&
currentID
,
nil
}
if
group
.
FallbackGroupID
==
nil
{
return
nil
,
nil
,
ErrClaudeCodeOnly
}
currentID
=
*
group
.
FallbackGroupID
}
}
func
(
s
*
GatewayService
)
resolvePlatform
(
ctx
context
.
Context
,
groupID
*
int64
)
(
string
,
bool
,
error
)
{
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
// - 有降级分组:返回降级分组的 ID
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func
(
s
*
GatewayService
)
checkClaudeCodeRestriction
(
ctx
context
.
Context
,
groupID
*
int64
)
(
*
Group
,
*
int64
,
error
)
{
if
groupID
==
nil
{
return
nil
,
groupID
,
nil
}
// 强制平台模式不检查 Claude Code 限制
if
_
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
);
hasForcePlatform
{
return
nil
,
groupID
,
nil
}
group
,
resolvedID
,
err
:=
s
.
resolveGatewayGroup
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
return
group
,
resolvedID
,
nil
}
func
(
s
*
GatewayService
)
resolvePlatform
(
ctx
context
.
Context
,
groupID
*
int64
,
group
*
Group
)
(
string
,
bool
,
error
)
{
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
if
hasForcePlatform
&&
forcePlatform
!=
""
{
return
forcePlatform
,
true
,
nil
}
if
group
!=
nil
{
return
group
.
Platform
,
false
,
nil
}
if
groupID
!=
nil
{
group
,
err
:=
s
.
groupRepo
.
Get
ByID
(
ctx
,
*
groupID
)
group
,
err
:=
s
.
resolveGroup
ByID
(
ctx
,
*
groupID
)
if
err
!=
nil
{
return
""
,
false
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
return
""
,
false
,
err
}
return
group
.
Platform
,
false
,
nil
}
...
...
@@ -812,7 +853,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
account
.
IsSchedulable
(
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
account
.
IsSchedulable
ForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
...
...
@@ -844,6 +885,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
}
if
!
acc
.
IsSchedulableForModel
(
requestedModel
)
{
continue
}
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
...
...
@@ -901,7 +945,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
IsSchedulable
(
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
IsSchedulable
ForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
account
.
Platform
==
nativePlatform
||
(
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
())
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
...
...
@@ -936,6 +980,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if
acc
.
Platform
==
PlatformAntigravity
&&
!
acc
.
IsMixedSchedulingEnabled
()
{
continue
}
if
!
acc
.
IsSchedulableForModel
(
requestedModel
)
{
continue
}
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
...
...
@@ -2247,6 +2294,7 @@ type RecordUsageInput struct {
Account
*
Account
Subscription
*
UserSubscription
// 可选:订阅信息
UserAgent
string
// 请求的 User-Agent
IPAddress
string
// 请求的客户端 IP 地址
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
...
...
@@ -2337,6 +2385,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog
.
UserAgent
=
&
input
.
UserAgent
}
// 添加 IPAddress
if
input
.
IPAddress
!=
""
{
usageLog
.
IPAddress
=
&
input
.
IPAddress
}
// 添加分组和订阅关联
if
apiKey
.
GroupID
!=
nil
{
usageLog
.
GroupID
=
apiKey
.
GroupID
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
1a641392
...
...
@@ -86,10 +86,16 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
platform
=
forcePlatform
}
else
if
groupID
!=
nil
{
// 根据分组 platform 决定查询哪种账号
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
groupID
)
var
group
*
Group
if
ctxGroup
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
IsGroupContextValid
(
ctxGroup
)
&&
ctxGroup
.
ID
==
*
groupID
{
group
=
ctxGroup
}
else
{
var
err
error
group
,
err
=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
*
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
}
}
platform
=
group
.
Platform
}
else
{
// 无分组时只使用原生 gemini 平台
...
...
@@ -114,7 +120,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
if
err
==
nil
&&
account
.
IsSchedulable
(
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
==
nil
&&
account
.
IsSchedulable
ForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
valid
:=
false
if
account
.
Platform
==
platform
{
valid
=
true
...
...
@@ -172,6 +178,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if
useMixedScheduling
&&
acc
.
Platform
==
PlatformAntigravity
&&
!
acc
.
IsMixedSchedulingEnabled
()
{
continue
}
if
!
acc
.
IsSchedulableForModel
(
requestedModel
)
{
continue
}
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
...
...
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment