Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
bb664d9b
Commit
bb664d9b
authored
Feb 28, 2026
by
yangjianbo
Browse files
feat(sync): full code sync from release
parent
bfc7b339
Changes
244
Hide whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
244 of 244+
files are displayed.
Plain diff
Email patch
backend/internal/server/middleware/api_key_auth.go
View file @
bb664d9b
...
...
@@ -97,7 +97,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if
len
(
apiKey
.
IPWhitelist
)
>
0
||
len
(
apiKey
.
IPBlacklist
)
>
0
{
clientIP
:=
ip
.
GetTrustedClientIP
(
c
)
allowed
,
_
:=
ip
.
CheckIPRestriction
(
clientIP
,
apiKey
.
IPWhitelist
,
apiKey
.
IPBlacklist
)
allowed
,
_
:=
ip
.
CheckIPRestriction
WithCompiledRules
(
clientIP
,
apiKey
.
Compiled
IPWhitelist
,
apiKey
.
Compiled
IPBlacklist
)
if
!
allowed
{
AbortWithError
(
c
,
403
,
"ACCESS_DENIED"
,
"Access denied"
)
return
...
...
backend/internal/server/middleware/api_key_auth_google.go
View file @
bb664d9b
...
...
@@ -80,17 +80,25 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
abortWithGoogleError
(
c
,
403
,
"No active subscription found for this group"
)
return
}
if
err
:=
subscriptionService
.
ValidateSubscription
(
c
.
Request
.
Context
(),
subscription
);
err
!=
nil
{
abortWithGoogleError
(
c
,
403
,
err
.
Error
())
return
}
_
=
subscriptionService
.
CheckAndActivateWindow
(
c
.
Request
.
Context
(),
subscription
)
_
=
subscriptionService
.
CheckAndResetWindows
(
c
.
Request
.
Context
(),
subscription
)
if
err
:=
subscriptionService
.
CheckUsageLimits
(
c
.
Request
.
Context
(),
subscription
,
apiKey
.
Group
,
0
);
err
!=
nil
{
abortWithGoogleError
(
c
,
429
,
err
.
Error
())
needsMaintenance
,
err
:=
subscriptionService
.
ValidateAndCheckLimits
(
subscription
,
apiKey
.
Group
)
if
err
!=
nil
{
status
:=
403
if
errors
.
Is
(
err
,
service
.
ErrDailyLimitExceeded
)
||
errors
.
Is
(
err
,
service
.
ErrWeeklyLimitExceeded
)
||
errors
.
Is
(
err
,
service
.
ErrMonthlyLimitExceeded
)
{
status
=
429
}
abortWithGoogleError
(
c
,
status
,
err
.
Error
())
return
}
c
.
Set
(
string
(
ContextKeySubscription
),
subscription
)
if
needsMaintenance
{
maintenanceCopy
:=
*
subscription
subscriptionService
.
DoWindowMaintenance
(
&
maintenanceCopy
)
}
}
else
{
if
apiKey
.
User
.
Balance
<=
0
{
abortWithGoogleError
(
c
,
403
,
"Insufficient account balance"
)
...
...
backend/internal/server/middleware/api_key_auth_google_test.go
View file @
bb664d9b
...
...
@@ -23,6 +23,15 @@ type fakeAPIKeyRepo struct {
updateLastUsed
func
(
ctx
context
.
Context
,
id
int64
,
usedAt
time
.
Time
)
error
}
type
fakeGoogleSubscriptionRepo
struct
{
getActive
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
updateStatus
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
activateWindow
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
resetDaily
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
resetWeekly
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
resetMonthly
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
}
func
(
f
fakeAPIKeyRepo
)
Create
(
ctx
context
.
Context
,
key
*
service
.
APIKey
)
error
{
return
errors
.
New
(
"not implemented"
)
}
...
...
@@ -87,6 +96,85 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim
return
nil
}
func
(
f
fakeGoogleSubscriptionRepo
)
Create
(
ctx
context
.
Context
,
sub
*
service
.
UserSubscription
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
UserSubscription
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
GetByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
GetActiveByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
{
if
f
.
getActive
!=
nil
{
return
f
.
getActive
(
ctx
,
userID
,
groupID
)
}
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
Update
(
ctx
context
.
Context
,
sub
*
service
.
UserSubscription
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
ListByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
ListActiveByUserID
(
ctx
context
.
Context
,
userID
int64
)
([]
service
.
UserSubscription
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
ListByGroupID
(
ctx
context
.
Context
,
groupID
int64
,
params
pagination
.
PaginationParams
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
userID
,
groupID
*
int64
,
status
,
sortBy
,
sortOrder
string
)
([]
service
.
UserSubscription
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
ExistsByUserIDAndGroupID
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
bool
,
error
)
{
return
false
,
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
ExtendExpiry
(
ctx
context
.
Context
,
subscriptionID
int64
,
newExpiresAt
time
.
Time
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
UpdateStatus
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
{
if
f
.
updateStatus
!=
nil
{
return
f
.
updateStatus
(
ctx
,
subscriptionID
,
status
)
}
return
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
UpdateNotes
(
ctx
context
.
Context
,
subscriptionID
int64
,
notes
string
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
ActivateWindows
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
if
f
.
activateWindow
!=
nil
{
return
f
.
activateWindow
(
ctx
,
id
,
start
)
}
return
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
ResetDailyUsage
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
if
f
.
resetDaily
!=
nil
{
return
f
.
resetDaily
(
ctx
,
id
,
start
)
}
return
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
ResetWeeklyUsage
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
if
f
.
resetWeekly
!=
nil
{
return
f
.
resetWeekly
(
ctx
,
id
,
start
)
}
return
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
ResetMonthlyUsage
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
if
f
.
resetMonthly
!=
nil
{
return
f
.
resetMonthly
(
ctx
,
id
,
start
)
}
return
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
IncrementUsage
(
ctx
context
.
Context
,
id
int64
,
costUSD
float64
)
error
{
return
errors
.
New
(
"not implemented"
)
}
func
(
f
fakeGoogleSubscriptionRepo
)
BatchUpdateExpiredStatus
(
ctx
context
.
Context
)
(
int64
,
error
)
{
return
0
,
errors
.
New
(
"not implemented"
)
}
type
googleErrorResponse
struct
{
Error
struct
{
Code
int
`json:"code"`
...
...
@@ -505,3 +593,85 @@ func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testi
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
1
,
touchCalls
)
}
func
TestApiKeyAuthWithSubscriptionGoogle_SubscriptionLimitExceededReturns429
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
limit
:=
1.0
group
:=
&
service
.
Group
{
ID
:
77
,
Name
:
"gemini-sub"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformGemini
,
Hydrated
:
true
,
SubscriptionType
:
service
.
SubscriptionTypeSubscription
,
DailyLimitUSD
:
&
limit
,
}
user
:=
&
service
.
User
{
ID
:
999
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
501
,
UserID
:
user
.
ID
,
Key
:
"google-sub-limit"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyService
:=
newTestAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
})
now
:=
time
.
Now
()
sub
:=
&
service
.
UserSubscription
{
ID
:
601
,
UserID
:
user
.
ID
,
GroupID
:
group
.
ID
,
Status
:
service
.
SubscriptionStatusActive
,
ExpiresAt
:
now
.
Add
(
24
*
time
.
Hour
),
DailyWindowStart
:
&
now
,
DailyUsageUSD
:
10
,
}
subscriptionService
:=
service
.
NewSubscriptionService
(
nil
,
fakeGoogleSubscriptionRepo
{
getActive
:
func
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
service
.
UserSubscription
,
error
)
{
if
userID
!=
user
.
ID
||
groupID
!=
group
.
ID
{
return
nil
,
service
.
ErrSubscriptionNotFound
}
clone
:=
*
sub
return
&
clone
,
nil
},
updateStatus
:
func
(
ctx
context
.
Context
,
subscriptionID
int64
,
status
string
)
error
{
return
nil
},
activateWindow
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetDaily
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetWeekly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
resetMonthly
:
func
(
ctx
context
.
Context
,
id
int64
,
start
time
.
Time
)
error
{
return
nil
},
},
nil
,
nil
,
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
})
r
:=
gin
.
New
()
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
subscriptionService
,
&
config
.
Config
{
RunMode
:
config
.
RunModeStandard
}))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
200
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test"
,
nil
)
req
.
Header
.
Set
(
"x-goog-api-key"
,
apiKey
.
Key
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
rec
.
Code
)
var
resp
googleErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
http
.
StatusTooManyRequests
,
resp
.
Error
.
Code
)
require
.
Equal
(
t
,
"RESOURCE_EXHAUSTED"
,
resp
.
Error
.
Status
)
require
.
Contains
(
t
,
resp
.
Error
.
Message
,
"daily usage limit exceeded"
)
}
backend/internal/server/middleware/security_headers.go
View file @
bb664d9b
...
...
@@ -54,6 +54,10 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
c
.
Header
(
"X-Content-Type-Options"
,
"nosniff"
)
c
.
Header
(
"X-Frame-Options"
,
"DENY"
)
c
.
Header
(
"Referrer-Policy"
,
"strict-origin-when-cross-origin"
)
if
isAPIRoutePath
(
c
)
{
c
.
Next
()
return
}
if
cfg
.
Enabled
{
// Generate nonce for this request
...
...
@@ -73,6 +77,18 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
}
}
func
isAPIRoutePath
(
c
*
gin
.
Context
)
bool
{
if
c
==
nil
||
c
.
Request
==
nil
||
c
.
Request
.
URL
==
nil
{
return
false
}
path
:=
c
.
Request
.
URL
.
Path
return
strings
.
HasPrefix
(
path
,
"/v1/"
)
||
strings
.
HasPrefix
(
path
,
"/v1beta/"
)
||
strings
.
HasPrefix
(
path
,
"/antigravity/"
)
||
strings
.
HasPrefix
(
path
,
"/sora/"
)
||
strings
.
HasPrefix
(
path
,
"/responses"
)
}
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
// This allows the application to work correctly even if the config file has an older CSP policy.
func
enhanceCSPPolicy
(
policy
string
)
string
{
...
...
backend/internal/server/middleware/security_headers_test.go
View file @
bb664d9b
...
...
@@ -131,6 +131,26 @@ func TestSecurityHeaders(t *testing.T) {
assert
.
Contains
(
t
,
csp
,
CloudflareInsightsDomain
)
})
t
.
Run
(
"api_route_skips_csp_nonce_generation"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CSPConfig
{
Enabled
:
true
,
Policy
:
"default-src 'self'; script-src 'self' __CSP_NONCE__"
,
}
middleware
:=
SecurityHeaders
(
cfg
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
middleware
(
c
)
assert
.
Equal
(
t
,
"nosniff"
,
w
.
Header
()
.
Get
(
"X-Content-Type-Options"
))
assert
.
Equal
(
t
,
"DENY"
,
w
.
Header
()
.
Get
(
"X-Frame-Options"
))
assert
.
Equal
(
t
,
"strict-origin-when-cross-origin"
,
w
.
Header
()
.
Get
(
"Referrer-Policy"
))
assert
.
Empty
(
t
,
w
.
Header
()
.
Get
(
"Content-Security-Policy"
))
assert
.
Empty
(
t
,
GetNonceFromContext
(
c
))
})
t
.
Run
(
"csp_enabled_with_nonce_placeholder"
,
func
(
t
*
testing
.
T
)
{
cfg
:=
config
.
CSPConfig
{
Enabled
:
true
,
...
...
backend/internal/server/router.go
View file @
bb664d9b
...
...
@@ -75,6 +75,7 @@ func registerRoutes(
// 注册各模块路由
routes
.
RegisterAuthRoutes
(
v1
,
h
,
jwtAuth
,
redisClient
)
routes
.
RegisterUserRoutes
(
v1
,
h
,
jwtAuth
)
routes
.
RegisterSoraClientRoutes
(
v1
,
h
,
jwtAuth
)
routes
.
RegisterAdminRoutes
(
v1
,
h
,
adminAuth
)
routes
.
RegisterGatewayRoutes
(
r
,
h
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
opsService
,
cfg
)
}
backend/internal/server/routes/admin.go
View file @
bb664d9b
...
...
@@ -55,6 +55,9 @@ func RegisterAdminRoutes(
// 系统设置
registerSettingsRoutes
(
admin
,
h
)
// 数据管理
registerDataManagementRoutes
(
admin
,
h
)
// 运维监控(Ops)
registerOpsRoutes
(
admin
,
h
)
...
...
@@ -231,6 +234,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts
.
POST
(
"/:id/clear-error"
,
h
.
Admin
.
Account
.
ClearError
)
accounts
.
GET
(
"/:id/usage"
,
h
.
Admin
.
Account
.
GetUsage
)
accounts
.
GET
(
"/:id/today-stats"
,
h
.
Admin
.
Account
.
GetTodayStats
)
accounts
.
POST
(
"/today-stats/batch"
,
h
.
Admin
.
Account
.
GetBatchTodayStats
)
accounts
.
POST
(
"/:id/clear-rate-limit"
,
h
.
Admin
.
Account
.
ClearRateLimit
)
accounts
.
GET
(
"/:id/temp-unschedulable"
,
h
.
Admin
.
Account
.
GetTempUnschedulable
)
accounts
.
DELETE
(
"/:id/temp-unschedulable"
,
h
.
Admin
.
Account
.
ClearTempUnschedulable
)
...
...
@@ -370,6 +374,38 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// 流超时处理配置
adminSettings
.
GET
(
"/stream-timeout"
,
h
.
Admin
.
Setting
.
GetStreamTimeoutSettings
)
adminSettings
.
PUT
(
"/stream-timeout"
,
h
.
Admin
.
Setting
.
UpdateStreamTimeoutSettings
)
// Sora S3 存储配置
adminSettings
.
GET
(
"/sora-s3"
,
h
.
Admin
.
Setting
.
GetSoraS3Settings
)
adminSettings
.
PUT
(
"/sora-s3"
,
h
.
Admin
.
Setting
.
UpdateSoraS3Settings
)
adminSettings
.
POST
(
"/sora-s3/test"
,
h
.
Admin
.
Setting
.
TestSoraS3Connection
)
adminSettings
.
GET
(
"/sora-s3/profiles"
,
h
.
Admin
.
Setting
.
ListSoraS3Profiles
)
adminSettings
.
POST
(
"/sora-s3/profiles"
,
h
.
Admin
.
Setting
.
CreateSoraS3Profile
)
adminSettings
.
PUT
(
"/sora-s3/profiles/:profile_id"
,
h
.
Admin
.
Setting
.
UpdateSoraS3Profile
)
adminSettings
.
DELETE
(
"/sora-s3/profiles/:profile_id"
,
h
.
Admin
.
Setting
.
DeleteSoraS3Profile
)
adminSettings
.
POST
(
"/sora-s3/profiles/:profile_id/activate"
,
h
.
Admin
.
Setting
.
SetActiveSoraS3Profile
)
}
}
func
registerDataManagementRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
dataManagement
:=
admin
.
Group
(
"/data-management"
)
{
dataManagement
.
GET
(
"/agent/health"
,
h
.
Admin
.
DataManagement
.
GetAgentHealth
)
dataManagement
.
GET
(
"/config"
,
h
.
Admin
.
DataManagement
.
GetConfig
)
dataManagement
.
PUT
(
"/config"
,
h
.
Admin
.
DataManagement
.
UpdateConfig
)
dataManagement
.
GET
(
"/sources/:source_type/profiles"
,
h
.
Admin
.
DataManagement
.
ListSourceProfiles
)
dataManagement
.
POST
(
"/sources/:source_type/profiles"
,
h
.
Admin
.
DataManagement
.
CreateSourceProfile
)
dataManagement
.
PUT
(
"/sources/:source_type/profiles/:profile_id"
,
h
.
Admin
.
DataManagement
.
UpdateSourceProfile
)
dataManagement
.
DELETE
(
"/sources/:source_type/profiles/:profile_id"
,
h
.
Admin
.
DataManagement
.
DeleteSourceProfile
)
dataManagement
.
POST
(
"/sources/:source_type/profiles/:profile_id/activate"
,
h
.
Admin
.
DataManagement
.
SetActiveSourceProfile
)
dataManagement
.
POST
(
"/s3/test"
,
h
.
Admin
.
DataManagement
.
TestS3
)
dataManagement
.
GET
(
"/s3/profiles"
,
h
.
Admin
.
DataManagement
.
ListS3Profiles
)
dataManagement
.
POST
(
"/s3/profiles"
,
h
.
Admin
.
DataManagement
.
CreateS3Profile
)
dataManagement
.
PUT
(
"/s3/profiles/:profile_id"
,
h
.
Admin
.
DataManagement
.
UpdateS3Profile
)
dataManagement
.
DELETE
(
"/s3/profiles/:profile_id"
,
h
.
Admin
.
DataManagement
.
DeleteS3Profile
)
dataManagement
.
POST
(
"/s3/profiles/:profile_id/activate"
,
h
.
Admin
.
DataManagement
.
SetActiveS3Profile
)
dataManagement
.
POST
(
"/backups"
,
h
.
Admin
.
DataManagement
.
CreateBackupJob
)
dataManagement
.
GET
(
"/backups"
,
h
.
Admin
.
DataManagement
.
ListBackupJobs
)
dataManagement
.
GET
(
"/backups/:job_id"
,
h
.
Admin
.
DataManagement
.
GetBackupJob
)
}
}
...
...
backend/internal/server/routes/gateway.go
View file @
bb664d9b
...
...
@@ -43,6 +43,7 @@ func RegisterGatewayRoutes(
gateway
.
GET
(
"/usage"
,
h
.
Gateway
.
Usage
)
// OpenAI Responses API
gateway
.
POST
(
"/responses"
,
h
.
OpenAIGateway
.
Responses
)
gateway
.
GET
(
"/responses"
,
h
.
OpenAIGateway
.
ResponsesWebSocket
)
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
gateway
.
POST
(
"/chat/completions"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
...
...
@@ -69,6 +70,7 @@ func RegisterGatewayRoutes(
// OpenAI Responses API(不带v1前缀的别名)
r
.
POST
(
"/responses"
,
bodyLimit
,
clientRequestID
,
opsErrorLogger
,
gin
.
HandlerFunc
(
apiKeyAuth
),
h
.
OpenAIGateway
.
Responses
)
r
.
GET
(
"/responses"
,
bodyLimit
,
clientRequestID
,
opsErrorLogger
,
gin
.
HandlerFunc
(
apiKeyAuth
),
h
.
OpenAIGateway
.
ResponsesWebSocket
)
// Antigravity 模型列表
r
.
GET
(
"/antigravity/models"
,
gin
.
HandlerFunc
(
apiKeyAuth
),
h
.
Gateway
.
AntigravityModels
)
...
...
backend/internal/server/routes/sora_client.go
0 → 100644
View file @
bb664d9b
package
routes
import
(
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。
func
RegisterSoraClientRoutes
(
v1
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
,
jwtAuth
middleware
.
JWTAuthMiddleware
,
)
{
if
h
.
SoraClient
==
nil
{
return
}
authenticated
:=
v1
.
Group
(
"/sora"
)
authenticated
.
Use
(
gin
.
HandlerFunc
(
jwtAuth
))
{
authenticated
.
POST
(
"/generate"
,
h
.
SoraClient
.
Generate
)
authenticated
.
GET
(
"/generations"
,
h
.
SoraClient
.
ListGenerations
)
authenticated
.
GET
(
"/generations/:id"
,
h
.
SoraClient
.
GetGeneration
)
authenticated
.
DELETE
(
"/generations/:id"
,
h
.
SoraClient
.
DeleteGeneration
)
authenticated
.
POST
(
"/generations/:id/cancel"
,
h
.
SoraClient
.
CancelGeneration
)
authenticated
.
POST
(
"/generations/:id/save"
,
h
.
SoraClient
.
SaveToStorage
)
authenticated
.
GET
(
"/quota"
,
h
.
SoraClient
.
GetQuota
)
authenticated
.
GET
(
"/models"
,
h
.
SoraClient
.
GetModels
)
authenticated
.
GET
(
"/storage-status"
,
h
.
SoraClient
.
GetStorageStatus
)
}
}
backend/internal/service/account.go
View file @
bb664d9b
...
...
@@ -3,6 +3,8 @@ package service
import
(
"encoding/json"
"hash/fnv"
"reflect"
"sort"
"strconv"
"strings"
...
...
@@ -50,6 +52,14 @@ type Account struct {
AccountGroups
[]
AccountGroup
GroupIDs
[]
int64
Groups
[]
*
Group
// model_mapping 热路径缓存(非持久化字段)
modelMappingCache
map
[
string
]
string
modelMappingCacheReady
bool
modelMappingCacheCredentialsPtr
uintptr
modelMappingCacheRawPtr
uintptr
modelMappingCacheRawLen
int
modelMappingCacheRawSig
uint64
}
type
TempUnschedulableRule
struct
{
...
...
@@ -349,6 +359,39 @@ func parseTempUnschedInt(value any) int {
}
func
(
a
*
Account
)
GetModelMapping
()
map
[
string
]
string
{
credentialsPtr
:=
mapPtr
(
a
.
Credentials
)
rawMapping
,
_
:=
a
.
Credentials
[
"model_mapping"
]
.
(
map
[
string
]
any
)
rawPtr
:=
mapPtr
(
rawMapping
)
rawLen
:=
len
(
rawMapping
)
rawSig
:=
uint64
(
0
)
rawSigReady
:=
false
if
a
.
modelMappingCacheReady
&&
a
.
modelMappingCacheCredentialsPtr
==
credentialsPtr
&&
a
.
modelMappingCacheRawPtr
==
rawPtr
&&
a
.
modelMappingCacheRawLen
==
rawLen
{
rawSig
=
modelMappingSignature
(
rawMapping
)
rawSigReady
=
true
if
a
.
modelMappingCacheRawSig
==
rawSig
{
return
a
.
modelMappingCache
}
}
mapping
:=
a
.
resolveModelMapping
(
rawMapping
)
if
!
rawSigReady
{
rawSig
=
modelMappingSignature
(
rawMapping
)
}
a
.
modelMappingCache
=
mapping
a
.
modelMappingCacheReady
=
true
a
.
modelMappingCacheCredentialsPtr
=
credentialsPtr
a
.
modelMappingCacheRawPtr
=
rawPtr
a
.
modelMappingCacheRawLen
=
rawLen
a
.
modelMappingCacheRawSig
=
rawSig
return
mapping
}
func
(
a
*
Account
)
resolveModelMapping
(
rawMapping
map
[
string
]
any
)
map
[
string
]
string
{
if
a
.
Credentials
==
nil
{
// Antigravity 平台使用默认映射
if
a
.
Platform
==
domain
.
PlatformAntigravity
{
...
...
@@ -356,32 +399,31 @@ func (a *Account) GetModelMapping() map[string]string {
}
return
nil
}
raw
,
ok
:=
a
.
Credentials
[
"model_mapping"
]
if
!
ok
||
raw
==
nil
{
if
len
(
rawMapping
)
==
0
{
// Antigravity 平台使用默认映射
if
a
.
Platform
==
domain
.
PlatformAntigravity
{
return
domain
.
DefaultAntigravityModelMapping
}
return
nil
}
if
m
,
ok
:=
raw
.
(
map
[
string
]
any
);
ok
{
result
:=
make
(
map
[
string
]
string
)
for
k
,
v
:=
range
m
{
if
s
,
ok
:=
v
.
(
string
);
ok
{
result
[
k
]
=
s
}
result
:=
make
(
map
[
string
]
string
)
for
k
,
v
:=
range
rawMapping
{
if
s
,
ok
:=
v
.
(
string
);
ok
{
result
[
k
]
=
s
}
if
len
(
result
)
>
0
{
if
a
.
Platform
==
domain
.
PlatformAntigravity
{
ensureAntigravityDefaultPassthroughs
(
result
,
[]
string
{
"gemini-3-flash"
,
"gemini-3.1-pro-high"
,
"gemini-3.1-pro-low"
,
})
}
return
result
}
if
len
(
result
)
>
0
{
if
a
.
Platform
==
domain
.
PlatformAntigravity
{
ensureAntigravityDefaultPassthroughs
(
result
,
[]
string
{
"gemini-3-flash"
,
"gemini-3.1-pro-high"
,
"gemini-3.1-pro-low"
,
})
}
return
result
}
// Antigravity 平台使用默认映射
if
a
.
Platform
==
domain
.
PlatformAntigravity
{
return
domain
.
DefaultAntigravityModelMapping
...
...
@@ -389,6 +431,37 @@ func (a *Account) GetModelMapping() map[string]string {
return
nil
}
func
mapPtr
(
m
map
[
string
]
any
)
uintptr
{
if
m
==
nil
{
return
0
}
return
reflect
.
ValueOf
(
m
)
.
Pointer
()
}
func
modelMappingSignature
(
rawMapping
map
[
string
]
any
)
uint64
{
if
len
(
rawMapping
)
==
0
{
return
0
}
keys
:=
make
([]
string
,
0
,
len
(
rawMapping
))
for
k
:=
range
rawMapping
{
keys
=
append
(
keys
,
k
)
}
sort
.
Strings
(
keys
)
h
:=
fnv
.
New64a
()
for
_
,
k
:=
range
keys
{
_
,
_
=
h
.
Write
([]
byte
(
k
))
_
,
_
=
h
.
Write
([]
byte
{
0
})
if
v
,
ok
:=
rawMapping
[
k
]
.
(
string
);
ok
{
_
,
_
=
h
.
Write
([]
byte
(
v
))
}
else
{
_
,
_
=
h
.
Write
([]
byte
{
1
})
}
_
,
_
=
h
.
Write
([]
byte
{
0xff
})
}
return
h
.
Sum64
()
}
func
ensureAntigravityDefaultPassthrough
(
mapping
map
[
string
]
string
,
model
string
)
{
if
mapping
==
nil
||
model
==
""
{
return
...
...
@@ -742,6 +815,159 @@ func (a *Account) IsOpenAIPassthroughEnabled() bool {
return
false
}
// IsOpenAIResponsesWebSocketV2Enabled 返回 OpenAI 账号是否开启 Responses WebSocket v2。
//
// 分类型新字段:
// - OAuth 账号:accounts.extra.openai_oauth_responses_websockets_v2_enabled
// - API Key 账号:accounts.extra.openai_apikey_responses_websockets_v2_enabled
//
// 兼容字段:
// - accounts.extra.responses_websockets_v2_enabled
// - accounts.extra.openai_ws_enabled(历史开关)
//
// 优先级:
// 1. 按账号类型读取分类型字段
// 2. 分类型字段缺失时,回退兼容字段
func
(
a
*
Account
)
IsOpenAIResponsesWebSocketV2Enabled
()
bool
{
if
a
==
nil
||
!
a
.
IsOpenAI
()
||
a
.
Extra
==
nil
{
return
false
}
if
a
.
IsOpenAIOAuth
()
{
if
enabled
,
ok
:=
a
.
Extra
[
"openai_oauth_responses_websockets_v2_enabled"
]
.
(
bool
);
ok
{
return
enabled
}
}
if
a
.
IsOpenAIApiKey
()
{
if
enabled
,
ok
:=
a
.
Extra
[
"openai_apikey_responses_websockets_v2_enabled"
]
.
(
bool
);
ok
{
return
enabled
}
}
if
enabled
,
ok
:=
a
.
Extra
[
"responses_websockets_v2_enabled"
]
.
(
bool
);
ok
{
return
enabled
}
if
enabled
,
ok
:=
a
.
Extra
[
"openai_ws_enabled"
]
.
(
bool
);
ok
{
return
enabled
}
return
false
}
const
(
OpenAIWSIngressModeOff
=
"off"
OpenAIWSIngressModeShared
=
"shared"
OpenAIWSIngressModeDedicated
=
"dedicated"
)
func
normalizeOpenAIWSIngressMode
(
mode
string
)
string
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
mode
))
{
case
OpenAIWSIngressModeOff
:
return
OpenAIWSIngressModeOff
case
OpenAIWSIngressModeShared
:
return
OpenAIWSIngressModeShared
case
OpenAIWSIngressModeDedicated
:
return
OpenAIWSIngressModeDedicated
default
:
return
""
}
}
func
normalizeOpenAIWSIngressDefaultMode
(
mode
string
)
string
{
if
normalized
:=
normalizeOpenAIWSIngressMode
(
mode
);
normalized
!=
""
{
return
normalized
}
return
OpenAIWSIngressModeShared
}
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。
//
// 优先级:
// 1. 分类型 mode 新字段(string)
// 2. 分类型 enabled 旧字段(bool)
// 3. 兼容 enabled 旧字段(bool)
// 4. defaultMode(非法时回退 shared)
func
(
a
*
Account
)
ResolveOpenAIResponsesWebSocketV2Mode
(
defaultMode
string
)
string
{
resolvedDefault
:=
normalizeOpenAIWSIngressDefaultMode
(
defaultMode
)
if
a
==
nil
||
!
a
.
IsOpenAI
()
{
return
OpenAIWSIngressModeOff
}
if
a
.
Extra
==
nil
{
return
resolvedDefault
}
resolveModeString
:=
func
(
key
string
)
(
string
,
bool
)
{
raw
,
ok
:=
a
.
Extra
[
key
]
if
!
ok
{
return
""
,
false
}
mode
,
ok
:=
raw
.
(
string
)
if
!
ok
{
return
""
,
false
}
normalized
:=
normalizeOpenAIWSIngressMode
(
mode
)
if
normalized
==
""
{
return
""
,
false
}
return
normalized
,
true
}
resolveBoolMode
:=
func
(
key
string
)
(
string
,
bool
)
{
raw
,
ok
:=
a
.
Extra
[
key
]
if
!
ok
{
return
""
,
false
}
enabled
,
ok
:=
raw
.
(
bool
)
if
!
ok
{
return
""
,
false
}
if
enabled
{
return
OpenAIWSIngressModeShared
,
true
}
return
OpenAIWSIngressModeOff
,
true
}
if
a
.
IsOpenAIOAuth
()
{
if
mode
,
ok
:=
resolveModeString
(
"openai_oauth_responses_websockets_v2_mode"
);
ok
{
return
mode
}
if
mode
,
ok
:=
resolveBoolMode
(
"openai_oauth_responses_websockets_v2_enabled"
);
ok
{
return
mode
}
}
if
a
.
IsOpenAIApiKey
()
{
if
mode
,
ok
:=
resolveModeString
(
"openai_apikey_responses_websockets_v2_mode"
);
ok
{
return
mode
}
if
mode
,
ok
:=
resolveBoolMode
(
"openai_apikey_responses_websockets_v2_enabled"
);
ok
{
return
mode
}
}
if
mode
,
ok
:=
resolveBoolMode
(
"responses_websockets_v2_enabled"
);
ok
{
return
mode
}
if
mode
,
ok
:=
resolveBoolMode
(
"openai_ws_enabled"
);
ok
{
return
mode
}
return
resolvedDefault
}
// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。
// 字段:accounts.extra.openai_ws_force_http。
func
(
a
*
Account
)
IsOpenAIWSForceHTTPEnabled
()
bool
{
if
a
==
nil
||
!
a
.
IsOpenAI
()
||
a
.
Extra
==
nil
{
return
false
}
enabled
,
ok
:=
a
.
Extra
[
"openai_ws_force_http"
]
.
(
bool
)
return
ok
&&
enabled
}
// IsOpenAIWSAllowStoreRecoveryEnabled 返回账号级 store 恢复开关。
// 字段:accounts.extra.openai_ws_allow_store_recovery。
func
(
a
*
Account
)
IsOpenAIWSAllowStoreRecoveryEnabled
()
bool
{
if
a
==
nil
||
!
a
.
IsOpenAI
()
||
a
.
Extra
==
nil
{
return
false
}
enabled
,
ok
:=
a
.
Extra
[
"openai_ws_allow_store_recovery"
]
.
(
bool
)
return
ok
&&
enabled
}
// IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。
func
(
a
*
Account
)
IsOpenAIOAuthPassthroughEnabled
()
bool
{
return
a
!=
nil
&&
a
.
IsOpenAIOAuth
()
&&
a
.
IsOpenAIPassthroughEnabled
()
...
...
backend/internal/service/account_openai_passthrough_test.go
View file @
bb664d9b
...
...
@@ -134,3 +134,161 @@ func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) {
require
.
False
(
t
,
otherPlatform
.
IsCodexCLIOnlyEnabled
())
})
}
func
TestAccount_IsOpenAIResponsesWebSocketV2Enabled
(
t
*
testing
.
T
)
{
t
.
Run
(
"OAuth使用OAuth专用开关"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_enabled"
:
true
,
},
}
require
.
True
(
t
,
account
.
IsOpenAIResponsesWebSocketV2Enabled
())
})
t
.
Run
(
"API Key使用API Key专用开关"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"openai_apikey_responses_websockets_v2_enabled"
:
true
,
},
}
require
.
True
(
t
,
account
.
IsOpenAIResponsesWebSocketV2Enabled
())
})
t
.
Run
(
"OAuth账号不会读取API Key专用开关"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_apikey_responses_websockets_v2_enabled"
:
true
,
},
}
require
.
False
(
t
,
account
.
IsOpenAIResponsesWebSocketV2Enabled
())
})
t
.
Run
(
"分类型新键优先于兼容键"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_enabled"
:
false
,
"responses_websockets_v2_enabled"
:
true
,
"openai_ws_enabled"
:
true
,
},
}
require
.
False
(
t
,
account
.
IsOpenAIResponsesWebSocketV2Enabled
())
})
t
.
Run
(
"分类型键缺失时回退兼容键"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
require
.
True
(
t
,
account
.
IsOpenAIResponsesWebSocketV2Enabled
())
})
t
.
Run
(
"非OpenAI账号默认关闭"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
require
.
False
(
t
,
account
.
IsOpenAIResponsesWebSocketV2Enabled
())
})
}
func
TestAccount_ResolveOpenAIResponsesWebSocketV2Mode
(
t
*
testing
.
T
)
{
t
.
Run
(
"default fallback to shared"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{},
}
require
.
Equal
(
t
,
OpenAIWSIngressModeShared
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
""
))
require
.
Equal
(
t
,
OpenAIWSIngressModeShared
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
"invalid"
))
})
t
.
Run
(
"oauth mode field has highest priority"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressModeDedicated
,
"openai_oauth_responses_websockets_v2_enabled"
:
false
,
"responses_websockets_v2_enabled"
:
false
,
},
}
require
.
Equal
(
t
,
OpenAIWSIngressModeDedicated
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressModeShared
))
})
t
.
Run
(
"legacy enabled maps to shared"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"responses_websockets_v2_enabled"
:
true
,
},
}
require
.
Equal
(
t
,
OpenAIWSIngressModeShared
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressModeOff
))
})
t
.
Run
(
"legacy disabled maps to off"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
,
Extra
:
map
[
string
]
any
{
"openai_apikey_responses_websockets_v2_enabled"
:
false
,
"responses_websockets_v2_enabled"
:
true
,
},
}
require
.
Equal
(
t
,
OpenAIWSIngressModeOff
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressModeShared
))
})
t
.
Run
(
"non openai always off"
,
func
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_oauth_responses_websockets_v2_mode"
:
OpenAIWSIngressModeDedicated
,
},
}
require
.
Equal
(
t
,
OpenAIWSIngressModeOff
,
account
.
ResolveOpenAIResponsesWebSocketV2Mode
(
OpenAIWSIngressModeDedicated
))
})
}
func
TestAccount_OpenAIWSExtraFlags
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_ws_force_http"
:
true
,
"openai_ws_allow_store_recovery"
:
true
,
},
}
require
.
True
(
t
,
account
.
IsOpenAIWSForceHTTPEnabled
())
require
.
True
(
t
,
account
.
IsOpenAIWSAllowStoreRecoveryEnabled
())
off
:=
&
Account
{
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{}}
require
.
False
(
t
,
off
.
IsOpenAIWSForceHTTPEnabled
())
require
.
False
(
t
,
off
.
IsOpenAIWSAllowStoreRecoveryEnabled
())
var
nilAccount
*
Account
require
.
False
(
t
,
nilAccount
.
IsOpenAIWSAllowStoreRecoveryEnabled
())
nonOpenAI
:=
&
Account
{
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"openai_ws_allow_store_recovery"
:
true
,
},
}
require
.
False
(
t
,
nonOpenAI
.
IsOpenAIWSAllowStoreRecoveryEnabled
())
}
backend/internal/service/account_service.go
View file @
bb664d9b
...
...
@@ -119,6 +119,10 @@ type AccountService struct {
groupRepo
GroupRepository
}
type
groupExistenceBatchChecker
interface
{
ExistsByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
(
map
[
int64
]
bool
,
error
)
}
// NewAccountService 创建账号服务实例
func
NewAccountService
(
accountRepo
AccountRepository
,
groupRepo
GroupRepository
)
*
AccountService
{
return
&
AccountService
{
...
...
@@ -131,11 +135,8 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository)
func
(
s
*
AccountService
)
Create
(
ctx
context
.
Context
,
req
CreateAccountRequest
)
(
*
Account
,
error
)
{
// 验证分组是否存在(如果指定了分组)
if
len
(
req
.
GroupIDs
)
>
0
{
for
_
,
groupID
:=
range
req
.
GroupIDs
{
_
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group: %w"
,
err
)
}
if
err
:=
s
.
validateGroupIDsExist
(
ctx
,
req
.
GroupIDs
);
err
!=
nil
{
return
nil
,
err
}
}
...
...
@@ -256,11 +257,8 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
// 先验证分组是否存在(在任何写操作之前)
if
req
.
GroupIDs
!=
nil
{
for
_
,
groupID
:=
range
*
req
.
GroupIDs
{
_
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group: %w"
,
err
)
}
if
err
:=
s
.
validateGroupIDsExist
(
ctx
,
*
req
.
GroupIDs
);
err
!=
nil
{
return
nil
,
err
}
}
...
...
@@ -300,6 +298,39 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
return
nil
}
func
(
s
*
AccountService
)
validateGroupIDsExist
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
error
{
if
len
(
groupIDs
)
==
0
{
return
nil
}
if
s
.
groupRepo
==
nil
{
return
fmt
.
Errorf
(
"group repository not configured"
)
}
if
batchChecker
,
ok
:=
s
.
groupRepo
.
(
groupExistenceBatchChecker
);
ok
{
existsByID
,
err
:=
batchChecker
.
ExistsByIDs
(
ctx
,
groupIDs
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"check groups exists: %w"
,
err
)
}
for
_
,
groupID
:=
range
groupIDs
{
if
groupID
<=
0
{
return
fmt
.
Errorf
(
"get group: %w"
,
ErrGroupNotFound
)
}
if
!
existsByID
[
groupID
]
{
return
fmt
.
Errorf
(
"get group: %w"
,
ErrGroupNotFound
)
}
}
return
nil
}
for
_
,
groupID
:=
range
groupIDs
{
_
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
groupID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"get group: %w"
,
err
)
}
}
return
nil
}
// UpdateStatus 更新账号状态
func
(
s
*
AccountService
)
UpdateStatus
(
ctx
context
.
Context
,
id
int64
,
status
string
,
errorMessage
string
)
error
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
id
)
...
...
backend/internal/service/account_test_service.go
View file @
bb664d9b
...
...
@@ -598,9 +598,102 @@ func ceilSeconds(d time.Duration) int {
return
sec
}
// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。
// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。
func
(
s
*
AccountTestService
)
testSoraAPIKeyAccountConnection
(
c
*
gin
.
Context
,
account
*
Account
)
error
{
ctx
:=
c
.
Request
.
Context
()
apiKey
:=
account
.
GetCredential
(
"api_key"
)
if
apiKey
==
""
{
return
s
.
sendErrorAndEnd
(
c
,
"Sora apikey 账号缺少 api_key 凭证"
)
}
baseURL
:=
account
.
GetBaseURL
()
if
baseURL
==
""
{
return
s
.
sendErrorAndEnd
(
c
,
"Sora apikey 账号缺少 base_url"
)
}
// 验证 base_url 格式
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"base_url 无效: %s"
,
err
.
Error
()))
}
upstreamURL
:=
strings
.
TrimSuffix
(
normalizedBaseURL
,
"/"
)
+
"/sora/v1/chat/completions"
// 设置 SSE 头
c
.
Writer
.
Header
()
.
Set
(
"Content-Type"
,
"text/event-stream"
)
c
.
Writer
.
Header
()
.
Set
(
"Cache-Control"
,
"no-cache"
)
c
.
Writer
.
Header
()
.
Set
(
"Connection"
,
"keep-alive"
)
c
.
Writer
.
Header
()
.
Set
(
"X-Accel-Buffering"
,
"no"
)
c
.
Writer
.
Flush
()
if
wait
,
ok
:=
s
.
acquireSoraTestPermit
(
account
.
ID
);
!
ok
{
msg
:=
fmt
.
Sprintf
(
"Sora 账号测试过于频繁,请 %d 秒后重试"
,
ceilSeconds
(
wait
))
return
s
.
sendErrorAndEnd
(
c
,
msg
)
}
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_start"
,
Model
:
"sora-upstream"
})
// 构建轻量级 prompt-enhance 请求作为连通性测试
testPayload
:=
map
[
string
]
any
{
"model"
:
"prompt-enhance-short-10s"
,
"messages"
:
[]
map
[
string
]
string
{{
"role"
:
"user"
,
"content"
:
"test"
}},
"stream"
:
false
,
}
payloadBytes
,
_
:=
json
.
Marshal
(
testPayload
)
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
upstreamURL
,
bytes
.
NewReader
(
payloadBytes
))
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
"构建测试请求失败"
)
}
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
apiKey
)
// 获取代理 URL
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
proxyURL
=
account
.
Proxy
.
URL
()
}
resp
,
err
:=
s
.
httpUpstream
.
Do
(
req
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"上游连接失败: %s"
,
err
.
Error
()))
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
64
*
1024
))
if
resp
.
StatusCode
==
http
.
StatusOK
{
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
fmt
.
Sprintf
(
"上游连接成功 (%s)"
,
upstreamURL
)})
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
fmt
.
Sprintf
(
"API Key 有效 (HTTP %d)"
,
resp
.
StatusCode
)})
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_complete"
,
Success
:
true
})
return
nil
}
if
resp
.
StatusCode
==
http
.
StatusUnauthorized
||
resp
.
StatusCode
==
http
.
StatusForbidden
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"上游认证失败 (HTTP %d),请检查 API Key 是否正确"
,
resp
.
StatusCode
))
}
// 其他错误但能连通(如 400 参数错误)也算连通性测试通过
if
resp
.
StatusCode
==
http
.
StatusBadRequest
{
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
fmt
.
Sprintf
(
"上游连接成功 (%s)"
,
upstreamURL
)})
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"content"
,
Text
:
fmt
.
Sprintf
(
"API Key 有效(上游返回 %d,参数校验错误属正常)"
,
resp
.
StatusCode
)})
s
.
sendEvent
(
c
,
TestEvent
{
Type
:
"test_complete"
,
Success
:
true
})
return
nil
}
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"上游返回异常 HTTP %d: %s"
,
resp
.
StatusCode
,
truncateSoraErrorBody
(
respBody
,
256
)))
}
// testSoraAccountConnection 测试 Sora 账号的连接
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token)
// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
func
(
s
*
AccountTestService
)
testSoraAccountConnection
(
c
*
gin
.
Context
,
account
*
Account
)
error
{
// apikey 类型走独立测试流程
if
account
.
Type
==
AccountTypeAPIKey
{
return
s
.
testSoraAPIKeyAccountConnection
(
c
,
account
)
}
ctx
:=
c
.
Request
.
Context
()
recorder
:=
&
soraProbeRecorder
{}
...
...
backend/internal/service/account_usage_service.go
View file @
bb664d9b
...
...
@@ -9,7 +9,9 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"golang.org/x/sync/errgroup"
)
type
UsageLogRepository
interface
{
...
...
@@ -33,8 +35,8 @@ type UsageLogRepository interface {
// Admin dashboard stats
GetDashboardStats
(
ctx
context
.
Context
)
(
*
usagestats
.
DashboardStats
,
error
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
GetAPIKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
APIKeyUsageTrendPoint
,
error
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
UserUsageTrendPoint
,
error
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
...
...
@@ -62,6 +64,10 @@ type UsageLogRepository interface {
GetDailyStatsAggregated
(
ctx
context
.
Context
,
userID
int64
,
startTime
,
endTime
time
.
Time
)
([]
map
[
string
]
any
,
error
)
}
type
accountWindowStatsBatchReader
interface
{
GetAccountWindowStatsBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
startTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
AccountStats
,
error
)
}
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
type
apiUsageCache
struct
{
response
*
ClaudeUsageResponse
...
...
@@ -297,7 +303,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
}
dayStart
:=
geminiDailyWindowStart
(
now
)
stats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
dayStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
stats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
dayStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get gemini usage stats failed: %w"
,
err
)
}
...
...
@@ -319,7 +325,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart
:=
now
.
Truncate
(
time
.
Minute
)
minuteResetAt
:=
minuteStart
.
Add
(
time
.
Minute
)
minuteStats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
minuteStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
minuteStats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
minuteStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get gemini minute usage stats failed: %w"
,
err
)
}
...
...
@@ -440,6 +446,78 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
},
nil
}
// GetTodayStatsBatch 批量获取账号今日统计,优先走批量 SQL,失败时回退单账号查询。
func
(
s
*
AccountUsageService
)
GetTodayStatsBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
*
WindowStats
,
error
)
{
uniqueIDs
:=
make
([]
int64
,
0
,
len
(
accountIDs
))
seen
:=
make
(
map
[
int64
]
struct
{},
len
(
accountIDs
))
for
_
,
accountID
:=
range
accountIDs
{
if
accountID
<=
0
{
continue
}
if
_
,
exists
:=
seen
[
accountID
];
exists
{
continue
}
seen
[
accountID
]
=
struct
{}{}
uniqueIDs
=
append
(
uniqueIDs
,
accountID
)
}
result
:=
make
(
map
[
int64
]
*
WindowStats
,
len
(
uniqueIDs
))
if
len
(
uniqueIDs
)
==
0
{
return
result
,
nil
}
startTime
:=
timezone
.
Today
()
if
batchReader
,
ok
:=
s
.
usageLogRepo
.
(
accountWindowStatsBatchReader
);
ok
{
statsByAccount
,
err
:=
batchReader
.
GetAccountWindowStatsBatch
(
ctx
,
uniqueIDs
,
startTime
)
if
err
==
nil
{
for
_
,
accountID
:=
range
uniqueIDs
{
result
[
accountID
]
=
windowStatsFromAccountStats
(
statsByAccount
[
accountID
])
}
return
result
,
nil
}
}
var
mu
sync
.
Mutex
g
,
gctx
:=
errgroup
.
WithContext
(
ctx
)
g
.
SetLimit
(
8
)
for
_
,
accountID
:=
range
uniqueIDs
{
id
:=
accountID
g
.
Go
(
func
()
error
{
stats
,
err
:=
s
.
usageLogRepo
.
GetAccountWindowStats
(
gctx
,
id
,
startTime
)
if
err
!=
nil
{
return
nil
}
mu
.
Lock
()
result
[
id
]
=
windowStatsFromAccountStats
(
stats
)
mu
.
Unlock
()
return
nil
})
}
_
=
g
.
Wait
()
for
_
,
accountID
:=
range
uniqueIDs
{
if
_
,
ok
:=
result
[
accountID
];
!
ok
{
result
[
accountID
]
=
&
WindowStats
{}
}
}
return
result
,
nil
}
func
windowStatsFromAccountStats
(
stats
*
usagestats
.
AccountStats
)
*
WindowStats
{
if
stats
==
nil
{
return
&
WindowStats
{}
}
return
&
WindowStats
{
Requests
:
stats
.
Requests
,
Tokens
:
stats
.
Tokens
,
Cost
:
stats
.
Cost
,
StandardCost
:
stats
.
StandardCost
,
UserCost
:
stats
.
UserCost
,
}
}
func
(
s
*
AccountUsageService
)
GetAccountUsageStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
(
*
usagestats
.
AccountUsageStatsResponse
,
error
)
{
stats
,
err
:=
s
.
usageLogRepo
.
GetAccountUsageStats
(
ctx
,
accountID
,
startTime
,
endTime
)
if
err
!=
nil
{
...
...
backend/internal/service/account_wildcard_test.go
View file @
bb664d9b
...
...
@@ -314,3 +314,72 @@ func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T
t
.
Fatalf
(
"expected wildcard mapping to stay effective, got: %q"
,
mapped
)
}
}
func
TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace
(
t
*
testing
.
T
)
{
account
:=
&
Account
{
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-3-5-sonnet"
:
"upstream-a"
,
},
},
}
first
:=
account
.
GetModelMapping
()
if
first
[
"claude-3-5-sonnet"
]
!=
"upstream-a"
{
t
.
Fatalf
(
"unexpected first mapping: %v"
,
first
)
}
account
.
Credentials
=
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-3-5-sonnet"
:
"upstream-b"
,
},
}
second
:=
account
.
GetModelMapping
()
if
second
[
"claude-3-5-sonnet"
]
!=
"upstream-b"
{
t
.
Fatalf
(
"expected cache invalidated after credentials replace, got: %v"
,
second
)
}
}
func
TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange
(
t
*
testing
.
T
)
{
rawMapping
:=
map
[
string
]
any
{
"claude-sonnet"
:
"sonnet-a"
,
}
account
:=
&
Account
{
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
rawMapping
,
},
}
first
:=
account
.
GetModelMapping
()
if
len
(
first
)
!=
1
{
t
.
Fatalf
(
"unexpected first mapping length: %d"
,
len
(
first
))
}
rawMapping
[
"claude-opus"
]
=
"opus-b"
second
:=
account
.
GetModelMapping
()
if
second
[
"claude-opus"
]
!=
"opus-b"
{
t
.
Fatalf
(
"expected cache invalidated after mapping len change, got: %v"
,
second
)
}
}
func
TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange
(
t
*
testing
.
T
)
{
rawMapping
:=
map
[
string
]
any
{
"claude-sonnet"
:
"sonnet-a"
,
}
account
:=
&
Account
{
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
rawMapping
,
},
}
first
:=
account
.
GetModelMapping
()
if
first
[
"claude-sonnet"
]
!=
"sonnet-a"
{
t
.
Fatalf
(
"unexpected first mapping: %v"
,
first
)
}
rawMapping
[
"claude-sonnet"
]
=
"sonnet-b"
second
:=
account
.
GetModelMapping
()
if
second
[
"claude-sonnet"
]
!=
"sonnet-b"
{
t
.
Fatalf
(
"expected cache invalidated after in-place value change, got: %v"
,
second
)
}
}
backend/internal/service/admin_service.go
View file @
bb664d9b
...
...
@@ -83,13 +83,14 @@ type AdminService interface {
// CreateUserInput represents input for creating a new user via admin operations.
type
CreateUserInput
struct
{
Email
string
Password
string
Username
string
Notes
string
Balance
float64
Concurrency
int
AllowedGroups
[]
int64
Email
string
Password
string
Username
string
Notes
string
Balance
float64
Concurrency
int
AllowedGroups
[]
int64
SoraStorageQuotaBytes
int64
}
type
UpdateUserInput
struct
{
...
...
@@ -103,7 +104,8 @@ type UpdateUserInput struct {
AllowedGroups
*
[]
int64
// 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
GroupRates
map
[
int64
]
*
float64
GroupRates
map
[
int64
]
*
float64
SoraStorageQuotaBytes
*
int64
}
type
CreateGroupInput
struct
{
...
...
@@ -135,6 +137,8 @@ type CreateGroupInput struct {
MCPXMLInject
*
bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
// Sora 存储配额
SoraStorageQuotaBytes
int64
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs
[]
int64
}
...
...
@@ -169,6 +173,8 @@ type UpdateGroupInput struct {
MCPXMLInject
*
bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
*
[]
string
// Sora 存储配额
SoraStorageQuotaBytes
*
int64
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
}
...
...
@@ -402,6 +408,14 @@ type adminServiceImpl struct {
authCacheInvalidator
APIKeyAuthCacheInvalidator
}
type
userGroupRateBatchReader
interface
{
GetByUserIDs
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
map
[
int64
]
float64
,
error
)
}
type
groupExistenceBatchReader
interface
{
ExistsByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
(
map
[
int64
]
bool
,
error
)
}
// NewAdminService creates a new AdminService
func
NewAdminService
(
userRepo
UserRepository
,
...
...
@@ -442,18 +456,43 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
}
// 批量加载用户专属分组倍率
if
s
.
userGroupRateRepo
!=
nil
&&
len
(
users
)
>
0
{
for
i
:=
range
users
{
rates
,
err
:=
s
.
userGroupRateRepo
.
GetByUserID
(
ctx
,
users
[
i
]
.
ID
)
if
batchRepo
,
ok
:=
s
.
userGroupRateRepo
.
(
userGroupRateBatchReader
);
ok
{
userIDs
:=
make
([]
int64
,
0
,
len
(
users
))
for
i
:=
range
users
{
userIDs
=
append
(
userIDs
,
users
[
i
]
.
ID
)
}
ratesByUser
,
err
:=
batchRepo
.
GetByUserIDs
(
ctx
,
userIDs
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.admin"
,
"failed to load user group rates: user_id=%d err=%v"
,
users
[
i
]
.
ID
,
err
)
continue
logger
.
LegacyPrintf
(
"service.admin"
,
"failed to load user group rates in batch: err=%v"
,
err
)
s
.
loadUserGroupRatesOneByOne
(
ctx
,
users
)
}
else
{
for
i
:=
range
users
{
if
rates
,
ok
:=
ratesByUser
[
users
[
i
]
.
ID
];
ok
{
users
[
i
]
.
GroupRates
=
rates
}
}
}
users
[
i
]
.
GroupRates
=
rates
}
else
{
s
.
loadUserGroupRatesOneByOne
(
ctx
,
users
)
}
}
return
users
,
result
.
Total
,
nil
}
func
(
s
*
adminServiceImpl
)
loadUserGroupRatesOneByOne
(
ctx
context
.
Context
,
users
[]
User
)
{
if
s
.
userGroupRateRepo
==
nil
{
return
}
for
i
:=
range
users
{
rates
,
err
:=
s
.
userGroupRateRepo
.
GetByUserID
(
ctx
,
users
[
i
]
.
ID
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.admin"
,
"failed to load user group rates: user_id=%d err=%v"
,
users
[
i
]
.
ID
,
err
)
continue
}
users
[
i
]
.
GroupRates
=
rates
}
}
func
(
s
*
adminServiceImpl
)
GetUser
(
ctx
context
.
Context
,
id
int64
)
(
*
User
,
error
)
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
...
...
@@ -473,14 +512,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
func
(
s
*
adminServiceImpl
)
CreateUser
(
ctx
context
.
Context
,
input
*
CreateUserInput
)
(
*
User
,
error
)
{
user
:=
&
User
{
Email
:
input
.
Email
,
Username
:
input
.
Username
,
Notes
:
input
.
Notes
,
Role
:
RoleUser
,
// Always create as regular user, never admin
Balance
:
input
.
Balance
,
Concurrency
:
input
.
Concurrency
,
Status
:
StatusActive
,
AllowedGroups
:
input
.
AllowedGroups
,
Email
:
input
.
Email
,
Username
:
input
.
Username
,
Notes
:
input
.
Notes
,
Role
:
RoleUser
,
// Always create as regular user, never admin
Balance
:
input
.
Balance
,
Concurrency
:
input
.
Concurrency
,
Status
:
StatusActive
,
AllowedGroups
:
input
.
AllowedGroups
,
SoraStorageQuotaBytes
:
input
.
SoraStorageQuotaBytes
,
}
if
err
:=
user
.
SetPassword
(
input
.
Password
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -534,6 +574,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user
.
AllowedGroups
=
*
input
.
AllowedGroups
}
if
input
.
SoraStorageQuotaBytes
!=
nil
{
user
.
SoraStorageQuotaBytes
=
*
input
.
SoraStorageQuotaBytes
}
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -820,6 +864,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ModelRouting
:
input
.
ModelRouting
,
MCPXMLInject
:
mcpXMLInject
,
SupportedModelScopes
:
input
.
SupportedModelScopes
,
SoraStorageQuotaBytes
:
input
.
SoraStorageQuotaBytes
,
}
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -982,6 +1027,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if
input
.
SoraVideoPricePerRequestHD
!=
nil
{
group
.
SoraVideoPricePerRequestHD
=
normalizePrice
(
input
.
SoraVideoPricePerRequestHD
)
}
if
input
.
SoraStorageQuotaBytes
!=
nil
{
group
.
SoraStorageQuotaBytes
=
*
input
.
SoraStorageQuotaBytes
}
// Claude Code 客户端限制
if
input
.
ClaudeCodeOnly
!=
nil
{
...
...
@@ -1188,6 +1236,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
}
// Sora apikey 账号的 base_url 必填校验
if
input
.
Platform
==
PlatformSora
&&
input
.
Type
==
AccountTypeAPIKey
{
baseURL
,
_
:=
input
.
Credentials
[
"base_url"
]
.
(
string
)
baseURL
=
strings
.
TrimSpace
(
baseURL
)
if
baseURL
==
""
{
return
nil
,
errors
.
New
(
"sora apikey 账号必须设置 base_url"
)
}
if
!
strings
.
HasPrefix
(
baseURL
,
"http://"
)
&&
!
strings
.
HasPrefix
(
baseURL
,
"https://"
)
{
return
nil
,
errors
.
New
(
"base_url 必须以 http:// 或 https:// 开头"
)
}
}
account
:=
&
Account
{
Name
:
input
.
Name
,
Notes
:
normalizeAccountNotes
(
input
.
Notes
),
...
...
@@ -1301,12 +1361,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account
.
AutoPauseOnExpired
=
*
input
.
AutoPauseOnExpired
}
// Sora apikey 账号的 base_url 必填校验
if
account
.
Platform
==
PlatformSora
&&
account
.
Type
==
AccountTypeAPIKey
{
baseURL
,
_
:=
account
.
Credentials
[
"base_url"
]
.
(
string
)
baseURL
=
strings
.
TrimSpace
(
baseURL
)
if
baseURL
==
""
{
return
nil
,
errors
.
New
(
"sora apikey 账号必须设置 base_url"
)
}
if
!
strings
.
HasPrefix
(
baseURL
,
"http://"
)
&&
!
strings
.
HasPrefix
(
baseURL
,
"https://"
)
{
return
nil
,
errors
.
New
(
"base_url 必须以 http:// 或 https:// 开头"
)
}
}
// 先验证分组是否存在(在任何写操作之前)
if
input
.
GroupIDs
!=
nil
{
for
_
,
groupID
:=
range
*
input
.
GroupIDs
{
if
_
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
groupID
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group: %w"
,
err
)
}
if
err
:=
s
.
validateGroupIDsExist
(
ctx
,
*
input
.
GroupIDs
);
err
!=
nil
{
return
nil
,
err
}
// 检查混合渠道风险(除非用户已确认)
...
...
@@ -1348,11 +1418,18 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if
len
(
input
.
AccountIDs
)
==
0
{
return
result
,
nil
}
if
input
.
GroupIDs
!=
nil
{
if
err
:=
s
.
validateGroupIDsExist
(
ctx
,
*
input
.
GroupIDs
);
err
!=
nil
{
return
nil
,
err
}
}
needMixedChannelCheck
:=
input
.
GroupIDs
!=
nil
&&
!
input
.
SkipMixedChannelCheck
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
platformByID
:=
map
[
int64
]
string
{}
groupAccountsByID
:=
map
[
int64
][]
Account
{}
groupNameByID
:=
map
[
int64
]
string
{}
if
needMixedChannelCheck
{
accounts
,
err
:=
s
.
accountRepo
.
GetByIDs
(
ctx
,
input
.
AccountIDs
)
if
err
!=
nil
{
...
...
@@ -1366,6 +1443,13 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
}
}
loadedAccounts
,
loadedNames
,
err
:=
s
.
preloadMixedChannelRiskData
(
ctx
,
*
input
.
GroupIDs
)
if
err
!=
nil
{
return
nil
,
err
}
groupAccountsByID
=
loadedAccounts
groupNameByID
=
loadedNames
}
if
input
.
RateMultiplier
!=
nil
{
...
...
@@ -1409,11 +1493,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
// Handle group bindings per account (requires individual operations).
for
_
,
accountID
:=
range
input
.
AccountIDs
{
entry
:=
BulkUpdateAccountResult
{
AccountID
:
accountID
}
platform
:=
""
if
input
.
GroupIDs
!=
nil
{
// 检查混合渠道风险(除非用户已确认)
if
!
input
.
SkipMixedChannelCheck
{
platform
:
=
platformByID
[
accountID
]
platform
=
platformByID
[
accountID
]
if
platform
==
""
{
account
,
err
:=
s
.
accountRepo
.
GetByID
(
ctx
,
accountID
)
if
err
!=
nil
{
...
...
@@ -1426,7 +1511,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
platform
=
account
.
Platform
}
if
err
:=
s
.
checkMixedChannelRisk
(
ctx
,
accountID
,
platform
,
*
input
.
GroupIDs
);
err
!=
nil
{
if
err
:=
s
.
checkMixedChannelRisk
WithPreloaded
(
accountID
,
platform
,
*
input
.
GroupIDs
,
groupAccountsByID
,
groupNameByID
);
err
!=
nil
{
entry
.
Success
=
false
entry
.
Error
=
err
.
Error
()
result
.
Failed
++
...
...
@@ -1444,6 +1529,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
result
.
Results
=
append
(
result
.
Results
,
entry
)
continue
}
if
!
input
.
SkipMixedChannelCheck
&&
platform
!=
""
{
updateMixedChannelPreloadedAccounts
(
groupAccountsByID
,
*
input
.
GroupIDs
,
accountID
,
platform
)
}
}
entry
.
Success
=
true
...
...
@@ -2115,6 +2203,135 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return
nil
}
func
(
s
*
adminServiceImpl
)
preloadMixedChannelRiskData
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
(
map
[
int64
][]
Account
,
map
[
int64
]
string
,
error
)
{
accountsByGroup
:=
make
(
map
[
int64
][]
Account
)
groupNameByID
:=
make
(
map
[
int64
]
string
)
if
len
(
groupIDs
)
==
0
{
return
accountsByGroup
,
groupNameByID
,
nil
}
seen
:=
make
(
map
[
int64
]
struct
{},
len
(
groupIDs
))
for
_
,
groupID
:=
range
groupIDs
{
if
groupID
<=
0
{
continue
}
if
_
,
ok
:=
seen
[
groupID
];
ok
{
continue
}
seen
[
groupID
]
=
struct
{}{}
accounts
,
err
:=
s
.
accountRepo
.
ListByGroup
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"get accounts in group %d: %w"
,
groupID
,
err
)
}
accountsByGroup
[
groupID
]
=
accounts
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
groupID
)
if
err
!=
nil
{
continue
}
if
group
!=
nil
{
groupNameByID
[
groupID
]
=
group
.
Name
}
}
return
accountsByGroup
,
groupNameByID
,
nil
}
func
(
s
*
adminServiceImpl
)
validateGroupIDsExist
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
error
{
if
len
(
groupIDs
)
==
0
{
return
nil
}
if
s
.
groupRepo
==
nil
{
return
errors
.
New
(
"group repository not configured"
)
}
if
batchReader
,
ok
:=
s
.
groupRepo
.
(
groupExistenceBatchReader
);
ok
{
existsByID
,
err
:=
batchReader
.
ExistsByIDs
(
ctx
,
groupIDs
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"check groups exists: %w"
,
err
)
}
for
_
,
groupID
:=
range
groupIDs
{
if
groupID
<=
0
||
!
existsByID
[
groupID
]
{
return
fmt
.
Errorf
(
"get group: %w"
,
ErrGroupNotFound
)
}
}
return
nil
}
for
_
,
groupID
:=
range
groupIDs
{
if
_
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
groupID
);
err
!=
nil
{
return
fmt
.
Errorf
(
"get group: %w"
,
err
)
}
}
return
nil
}
func
(
s
*
adminServiceImpl
)
checkMixedChannelRiskWithPreloaded
(
currentAccountID
int64
,
currentAccountPlatform
string
,
groupIDs
[]
int64
,
accountsByGroup
map
[
int64
][]
Account
,
groupNameByID
map
[
int64
]
string
)
error
{
currentPlatform
:=
getAccountPlatform
(
currentAccountPlatform
)
if
currentPlatform
==
""
{
return
nil
}
for
_
,
groupID
:=
range
groupIDs
{
accounts
:=
accountsByGroup
[
groupID
]
for
_
,
account
:=
range
accounts
{
if
currentAccountID
>
0
&&
account
.
ID
==
currentAccountID
{
continue
}
otherPlatform
:=
getAccountPlatform
(
account
.
Platform
)
if
otherPlatform
==
""
{
continue
}
if
currentPlatform
!=
otherPlatform
{
groupName
:=
fmt
.
Sprintf
(
"Group %d"
,
groupID
)
if
name
:=
strings
.
TrimSpace
(
groupNameByID
[
groupID
]);
name
!=
""
{
groupName
=
name
}
return
&
MixedChannelError
{
GroupID
:
groupID
,
GroupName
:
groupName
,
CurrentPlatform
:
currentPlatform
,
OtherPlatform
:
otherPlatform
,
}
}
}
}
return
nil
}
func
updateMixedChannelPreloadedAccounts
(
accountsByGroup
map
[
int64
][]
Account
,
groupIDs
[]
int64
,
accountID
int64
,
platform
string
)
{
if
len
(
groupIDs
)
==
0
||
accountID
<=
0
||
platform
==
""
{
return
}
for
_
,
groupID
:=
range
groupIDs
{
if
groupID
<=
0
{
continue
}
accounts
:=
accountsByGroup
[
groupID
]
found
:=
false
for
i
:=
range
accounts
{
if
accounts
[
i
]
.
ID
!=
accountID
{
continue
}
accounts
[
i
]
.
Platform
=
platform
found
=
true
break
}
if
!
found
{
accounts
=
append
(
accounts
,
Account
{
ID
:
accountID
,
Platform
:
platform
,
})
}
accountsByGroup
[
groupID
]
=
accounts
}
}
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
func
(
s
*
adminServiceImpl
)
CheckMixedChannelRisk
(
ctx
context
.
Context
,
currentAccountID
int64
,
currentAccountPlatform
string
,
groupIDs
[]
int64
)
error
{
return
s
.
checkMixedChannelRisk
(
ctx
,
currentAccountID
,
currentAccountPlatform
,
groupIDs
)
...
...
backend/internal/service/admin_service_bulk_update_test.go
View file @
bb664d9b
...
...
@@ -15,6 +15,7 @@ type accountRepoStubForBulkUpdate struct {
bulkUpdateErr
error
bulkUpdateIDs
[]
int64
bindGroupErrByID
map
[
int64
]
error
bindGroupsCalls
[]
int64
getByIDsAccounts
[]
*
Account
getByIDsErr
error
getByIDsCalled
bool
...
...
@@ -22,6 +23,8 @@ type accountRepoStubForBulkUpdate struct {
getByIDAccounts
map
[
int64
]
*
Account
getByIDErrByID
map
[
int64
]
error
getByIDCalled
[]
int64
listByGroupData
map
[
int64
][]
Account
listByGroupErr
map
[
int64
]
error
}
func
(
s
*
accountRepoStubForBulkUpdate
)
BulkUpdate
(
_
context
.
Context
,
ids
[]
int64
,
_
AccountBulkUpdate
)
(
int64
,
error
)
{
...
...
@@ -33,6 +36,7 @@ func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64
}
func
(
s
*
accountRepoStubForBulkUpdate
)
BindGroups
(
_
context
.
Context
,
accountID
int64
,
_
[]
int64
)
error
{
s
.
bindGroupsCalls
=
append
(
s
.
bindGroupsCalls
,
accountID
)
if
err
,
ok
:=
s
.
bindGroupErrByID
[
accountID
];
ok
{
return
err
}
...
...
@@ -59,6 +63,16 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
return
nil
,
errors
.
New
(
"account not found"
)
}
func
(
s
*
accountRepoStubForBulkUpdate
)
ListByGroup
(
_
context
.
Context
,
groupID
int64
)
([]
Account
,
error
)
{
if
err
,
ok
:=
s
.
listByGroupErr
[
groupID
];
ok
{
return
nil
,
err
}
if
rows
,
ok
:=
s
.
listByGroupData
[
groupID
];
ok
{
return
rows
,
nil
}
return
nil
,
nil
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func
TestAdminService_BulkUpdateAccounts_AllSuccessIDs
(
t
*
testing
.
T
)
{
repo
:=
&
accountRepoStubForBulkUpdate
{}
...
...
@@ -86,7 +100,10 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
2
:
errors
.
New
(
"bind failed"
),
},
}
svc
:=
&
adminServiceImpl
{
accountRepo
:
repo
}
svc
:=
&
adminServiceImpl
{
accountRepo
:
repo
,
groupRepo
:
&
groupRepoStubForAdmin
{
getByID
:
&
Group
{
ID
:
10
,
Name
:
"g10"
}},
}
groupIDs
:=
[]
int64
{
10
}
schedulable
:=
false
...
...
@@ -105,3 +122,51 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
require
.
ElementsMatch
(
t
,
[]
int64
{
2
},
result
.
FailedIDs
)
require
.
Len
(
t
,
result
.
Results
,
3
)
}
func
TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError
(
t
*
testing
.
T
)
{
repo
:=
&
accountRepoStubForBulkUpdate
{}
svc
:=
&
adminServiceImpl
{
accountRepo
:
repo
}
groupIDs
:=
[]
int64
{
10
}
input
:=
&
BulkUpdateAccountsInput
{
AccountIDs
:
[]
int64
{
1
},
GroupIDs
:
&
groupIDs
,
}
result
,
err
:=
svc
.
BulkUpdateAccounts
(
context
.
Background
(),
input
)
require
.
Nil
(
t
,
result
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"group repository not configured"
)
}
func
TestAdminService_BulkUpdateAccounts_MixedChannelCheckUsesUpdatedSnapshot
(
t
*
testing
.
T
)
{
repo
:=
&
accountRepoStubForBulkUpdate
{
getByIDsAccounts
:
[]
*
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
},
},
listByGroupData
:
map
[
int64
][]
Account
{
10
:
{},
},
}
svc
:=
&
adminServiceImpl
{
accountRepo
:
repo
,
groupRepo
:
&
groupRepoStubForAdmin
{
getByID
:
&
Group
{
ID
:
10
,
Name
:
"目标分组"
}},
}
groupIDs
:=
[]
int64
{
10
}
input
:=
&
BulkUpdateAccountsInput
{
AccountIDs
:
[]
int64
{
1
,
2
},
GroupIDs
:
&
groupIDs
,
}
result
,
err
:=
svc
.
BulkUpdateAccounts
(
context
.
Background
(),
input
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
result
.
Success
)
require
.
Equal
(
t
,
1
,
result
.
Failed
)
require
.
ElementsMatch
(
t
,
[]
int64
{
1
},
result
.
SuccessIDs
)
require
.
ElementsMatch
(
t
,
[]
int64
{
2
},
result
.
FailedIDs
)
require
.
Len
(
t
,
result
.
Results
,
2
)
require
.
Contains
(
t
,
result
.
Results
[
1
]
.
Error
,
"mixed channel"
)
require
.
Equal
(
t
,
[]
int64
{
1
},
repo
.
bindGroupsCalls
)
}
backend/internal/service/admin_service_list_users_test.go
0 → 100644
View file @
bb664d9b
//go:build unit
package
service
import
(
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type
userRepoStubForListUsers
struct
{
userRepoStub
users
[]
User
err
error
}
func
(
s
*
userRepoStubForListUsers
)
ListWithFilters
(
_
context
.
Context
,
params
pagination
.
PaginationParams
,
_
UserListFilters
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
if
s
.
err
!=
nil
{
return
nil
,
nil
,
s
.
err
}
out
:=
make
([]
User
,
len
(
s
.
users
))
copy
(
out
,
s
.
users
)
return
out
,
&
pagination
.
PaginationResult
{
Total
:
int64
(
len
(
out
)),
Page
:
params
.
Page
,
PageSize
:
params
.
PageSize
,
},
nil
}
type
userGroupRateRepoStubForListUsers
struct
{
batchCalls
int
singleCall
[]
int64
batchErr
error
batchData
map
[
int64
]
map
[
int64
]
float64
singleErr
map
[
int64
]
error
singleData
map
[
int64
]
map
[
int64
]
float64
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
GetByUserIDs
(
_
context
.
Context
,
_
[]
int64
)
(
map
[
int64
]
map
[
int64
]
float64
,
error
)
{
s
.
batchCalls
++
if
s
.
batchErr
!=
nil
{
return
nil
,
s
.
batchErr
}
return
s
.
batchData
,
nil
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
GetByUserID
(
_
context
.
Context
,
userID
int64
)
(
map
[
int64
]
float64
,
error
)
{
s
.
singleCall
=
append
(
s
.
singleCall
,
userID
)
if
err
,
ok
:=
s
.
singleErr
[
userID
];
ok
{
return
nil
,
err
}
if
rates
,
ok
:=
s
.
singleData
[
userID
];
ok
{
return
rates
,
nil
}
return
map
[
int64
]
float64
{},
nil
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
GetByUserAndGroup
(
_
context
.
Context
,
userID
,
groupID
int64
)
(
*
float64
,
error
)
{
panic
(
"unexpected GetByUserAndGroup call"
)
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
SyncUserGroupRates
(
_
context
.
Context
,
userID
int64
,
rates
map
[
int64
]
*
float64
)
error
{
panic
(
"unexpected SyncUserGroupRates call"
)
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
DeleteByGroupID
(
_
context
.
Context
,
groupID
int64
)
error
{
panic
(
"unexpected DeleteByGroupID call"
)
}
func
(
s
*
userGroupRateRepoStubForListUsers
)
DeleteByUserID
(
_
context
.
Context
,
userID
int64
)
error
{
panic
(
"unexpected DeleteByUserID call"
)
}
func
TestAdminService_ListUsers_BatchRateFallbackToSingle
(
t
*
testing
.
T
)
{
userRepo
:=
&
userRepoStubForListUsers
{
users
:
[]
User
{
{
ID
:
101
,
Username
:
"u1"
},
{
ID
:
202
,
Username
:
"u2"
},
},
}
rateRepo
:=
&
userGroupRateRepoStubForListUsers
{
batchErr
:
errors
.
New
(
"batch unavailable"
),
singleData
:
map
[
int64
]
map
[
int64
]
float64
{
101
:
{
11
:
1.1
},
202
:
{
22
:
2.2
},
},
}
svc
:=
&
adminServiceImpl
{
userRepo
:
userRepo
,
userGroupRateRepo
:
rateRepo
,
}
users
,
total
,
err
:=
svc
.
ListUsers
(
context
.
Background
(),
1
,
20
,
UserListFilters
{})
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
2
),
total
)
require
.
Len
(
t
,
users
,
2
)
require
.
Equal
(
t
,
1
,
rateRepo
.
batchCalls
)
require
.
ElementsMatch
(
t
,
[]
int64
{
101
,
202
},
rateRepo
.
singleCall
)
require
.
Equal
(
t
,
1.1
,
users
[
0
]
.
GroupRates
[
11
])
require
.
Equal
(
t
,
2.2
,
users
[
1
]
.
GroupRates
[
22
])
}
backend/internal/service/antigravity_gateway_service.go
View file @
bb664d9b
...
...
@@ -21,7 +21,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
...
...
@@ -2291,7 +2290,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
// isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记
func
isSingleAccountRetry
(
ctx
context
.
Context
)
bool
{
v
,
_
:=
ctx
.
Value
(
ctxkey
.
SingleAccountRetry
)
.
(
bool
)
v
,
_
:=
SingleAccountRetry
FromContext
(
ctx
)
return
v
}
...
...
backend/internal/service/api_key.go
View file @
bb664d9b
package
service
import
"time"
import
(
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
)
// API Key status constants
const
(
...
...
@@ -19,11 +23,14 @@ type APIKey struct {
Status
string
IPWhitelist
[]
string
IPBlacklist
[]
string
LastUsedAt
*
time
.
Time
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
User
*
User
Group
*
Group
// 预编译的 IP 规则,用于认证热路径避免重复 ParseIP/ParseCIDR。
CompiledIPWhitelist
*
ip
.
CompiledIPRules
`json:"-"`
CompiledIPBlacklist
*
ip
.
CompiledIPRules
`json:"-"`
LastUsedAt
*
time
.
Time
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
User
*
User
Group
*
Group
// Quota fields
Quota
float64
// Quota limit in USD (0 = unlimited)
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment