Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
3d79773b
Commit
3d79773b
authored
Mar 04, 2026
by
kyx236
Browse files
Merge branch 'main' of
https://github.com/james-6-23/sub2api
parents
6aa8cbbf
742e73c9
Changes
249
Hide whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
249 of 249+
files are displayed.
Plain diff
Email patch
backend/internal/handler/admin/usage_handler.go
View file @
3d79773b
package
admin
import
(
"
log
"
"
context
"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
...
...
@@ -50,6 +51,7 @@ type CreateUsageCleanupTaskRequest struct {
AccountID
*
int64
`json:"account_id"`
GroupID
*
int64
`json:"group_id"`
Model
*
string
`json:"model"`
RequestType
*
string
`json:"request_type"`
Stream
*
bool
`json:"stream"`
BillingType
*
int8
`json:"billing_type"`
Timezone
string
`json:"timezone"`
...
...
@@ -59,6 +61,15 @@ type CreateUsageCleanupTaskRequest struct {
// GET /api/v1/admin/usage
func
(
h
*
UsageHandler
)
List
(
c
*
gin
.
Context
)
{
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
exactTotal
:=
false
if
exactTotalRaw
:=
strings
.
TrimSpace
(
c
.
Query
(
"exact_total"
));
exactTotalRaw
!=
""
{
parsed
,
err
:=
strconv
.
ParseBool
(
exactTotalRaw
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid exact_total value, use true or false"
)
return
}
exactTotal
=
parsed
}
// Parse filters
var
userID
,
apiKeyID
,
accountID
,
groupID
int64
...
...
@@ -100,8 +111,17 @@ func (h *UsageHandler) List(c *gin.Context) {
model
:=
c
.
Query
(
"model"
)
var
requestType
*
int16
var
stream
*
bool
if
streamStr
:=
c
.
Query
(
"stream"
);
streamStr
!=
""
{
if
requestTypeStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"request_type"
));
requestTypeStr
!=
""
{
parsed
,
err
:=
service
.
ParseUsageRequestType
(
requestTypeStr
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
value
:=
int16
(
parsed
)
requestType
=
&
value
}
else
if
streamStr
:=
c
.
Query
(
"stream"
);
streamStr
!=
""
{
val
,
err
:=
strconv
.
ParseBool
(
streamStr
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid stream value, use true or false"
)
...
...
@@ -151,10 +171,12 @@ func (h *UsageHandler) List(c *gin.Context) {
AccountID
:
accountID
,
GroupID
:
groupID
,
Model
:
model
,
RequestType
:
requestType
,
Stream
:
stream
,
BillingType
:
billingType
,
StartTime
:
startTime
,
EndTime
:
endTime
,
ExactTotal
:
exactTotal
,
}
records
,
result
,
err
:=
h
.
usageService
.
ListWithFilters
(
c
.
Request
.
Context
(),
params
,
filters
)
...
...
@@ -213,8 +235,17 @@ func (h *UsageHandler) Stats(c *gin.Context) {
model
:=
c
.
Query
(
"model"
)
var
requestType
*
int16
var
stream
*
bool
if
streamStr
:=
c
.
Query
(
"stream"
);
streamStr
!=
""
{
if
requestTypeStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"request_type"
));
requestTypeStr
!=
""
{
parsed
,
err
:=
service
.
ParseUsageRequestType
(
requestTypeStr
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
value
:=
int16
(
parsed
)
requestType
=
&
value
}
else
if
streamStr
:=
c
.
Query
(
"stream"
);
streamStr
!=
""
{
val
,
err
:=
strconv
.
ParseBool
(
streamStr
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid stream value, use true or false"
)
...
...
@@ -277,6 +308,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
AccountID
:
accountID
,
GroupID
:
groupID
,
Model
:
model
,
RequestType
:
requestType
,
Stream
:
stream
,
BillingType
:
billingType
,
StartTime
:
&
startTime
,
...
...
@@ -378,11 +410,11 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
operator
=
subject
.
UserID
}
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
log
.
Printf
(
"[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d"
,
operator
,
page
,
pageSize
)
log
ger
.
LegacyPrintf
(
"handler.admin.usage"
,
"[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d"
,
operator
,
page
,
pageSize
)
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
tasks
,
result
,
err
:=
h
.
cleanupService
.
ListTasks
(
c
.
Request
.
Context
(),
params
)
if
err
!=
nil
{
log
.
Printf
(
"[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v"
,
operator
,
page
,
pageSize
,
err
)
log
ger
.
LegacyPrintf
(
"handler.admin.usage"
,
"[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v"
,
operator
,
page
,
pageSize
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
}
...
...
@@ -390,7 +422,7 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
for
i
:=
range
tasks
{
out
=
append
(
out
,
*
dto
.
UsageCleanupTaskFromService
(
&
tasks
[
i
]))
}
log
.
Printf
(
"[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d"
,
operator
,
result
.
Total
,
len
(
out
),
page
,
pageSize
)
log
ger
.
LegacyPrintf
(
"handler.admin.usage"
,
"[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d"
,
operator
,
result
.
Total
,
len
(
out
),
page
,
pageSize
)
response
.
Paginated
(
c
,
out
,
result
.
Total
,
page
,
pageSize
)
}
...
...
@@ -431,6 +463,19 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
}
endTime
=
endTime
.
Add
(
24
*
time
.
Hour
-
time
.
Nanosecond
)
var
requestType
*
int16
stream
:=
req
.
Stream
if
req
.
RequestType
!=
nil
{
parsed
,
err
:=
service
.
ParseUsageRequestType
(
*
req
.
RequestType
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
value
:=
int16
(
parsed
)
requestType
=
&
value
stream
=
nil
}
filters
:=
service
.
UsageCleanupFilters
{
StartTime
:
startTime
,
EndTime
:
endTime
,
...
...
@@ -439,7 +484,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
AccountID
:
req
.
AccountID
,
GroupID
:
req
.
GroupID
,
Model
:
req
.
Model
,
Stream
:
req
.
Stream
,
RequestType
:
requestType
,
Stream
:
stream
,
BillingType
:
req
.
BillingType
,
}
...
...
@@ -463,38 +509,50 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
if
filters
.
Model
!=
nil
{
model
=
*
filters
.
Model
}
var
stream
any
var
stream
Value
any
if
filters
.
Stream
!=
nil
{
stream
=
*
filters
.
Stream
streamValue
=
*
filters
.
Stream
}
var
requestTypeName
any
if
filters
.
RequestType
!=
nil
{
requestTypeName
=
service
.
RequestTypeFromInt16
(
*
filters
.
RequestType
)
.
String
()
}
var
billingType
any
if
filters
.
BillingType
!=
nil
{
billingType
=
*
filters
.
BillingType
}
log
.
Printf
(
"[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q"
,
subject
.
UserID
,
filters
.
StartTime
.
Format
(
time
.
RFC3339
),
filters
.
EndTime
.
Format
(
time
.
RFC3339
),
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
stream
,
billingType
,
req
.
Timezone
,
)
task
,
err
:=
h
.
cleanupService
.
CreateTask
(
c
.
Request
.
Context
(),
filters
,
subject
.
UserID
)
if
err
!=
nil
{
log
.
Printf
(
"[UsageCleanup] 创建清理任务失败: operator=%d err=%v"
,
subject
.
UserID
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
}
log
.
Printf
(
"[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s"
,
task
.
ID
,
subject
.
UserID
,
task
.
Status
)
response
.
Success
(
c
,
dto
.
UsageCleanupTaskFromService
(
task
))
idempotencyPayload
:=
struct
{
OperatorID
int64
`json:"operator_id"`
Body
CreateUsageCleanupTaskRequest
`json:"body"`
}{
OperatorID
:
subject
.
UserID
,
Body
:
req
,
}
executeAdminIdempotentJSON
(
c
,
"admin.usage.cleanup_tasks.create"
,
idempotencyPayload
,
service
.
DefaultWriteIdempotencyTTL
(),
func
(
ctx
context
.
Context
)
(
any
,
error
)
{
logger
.
LegacyPrintf
(
"handler.admin.usage"
,
"[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v request_type=%v stream=%v billing_type=%v tz=%q"
,
subject
.
UserID
,
filters
.
StartTime
.
Format
(
time
.
RFC3339
),
filters
.
EndTime
.
Format
(
time
.
RFC3339
),
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
requestTypeName
,
streamValue
,
billingType
,
req
.
Timezone
,
)
task
,
err
:=
h
.
cleanupService
.
CreateTask
(
ctx
,
filters
,
subject
.
UserID
)
if
err
!=
nil
{
logger
.
LegacyPrintf
(
"handler.admin.usage"
,
"[UsageCleanup] 创建清理任务失败: operator=%d err=%v"
,
subject
.
UserID
,
err
)
return
nil
,
err
}
logger
.
LegacyPrintf
(
"handler.admin.usage"
,
"[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s"
,
task
.
ID
,
subject
.
UserID
,
task
.
Status
)
return
dto
.
UsageCleanupTaskFromService
(
task
),
nil
})
}
// CancelCleanupTask handles canceling a usage cleanup task
...
...
@@ -515,12 +573,12 @@ func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
response
.
BadRequest
(
c
,
"Invalid task id"
)
return
}
log
.
Printf
(
"[UsageCleanup] 请求取消清理任务: task=%d operator=%d"
,
taskID
,
subject
.
UserID
)
log
ger
.
LegacyPrintf
(
"handler.admin.usage"
,
"[UsageCleanup] 请求取消清理任务: task=%d operator=%d"
,
taskID
,
subject
.
UserID
)
if
err
:=
h
.
cleanupService
.
CancelTask
(
c
.
Request
.
Context
(),
taskID
,
subject
.
UserID
);
err
!=
nil
{
log
.
Printf
(
"[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v"
,
taskID
,
subject
.
UserID
,
err
)
log
ger
.
LegacyPrintf
(
"handler.admin.usage"
,
"[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v"
,
taskID
,
subject
.
UserID
,
err
)
response
.
ErrorFrom
(
c
,
err
)
return
}
log
.
Printf
(
"[UsageCleanup] 清理任务已取消: task=%d operator=%d"
,
taskID
,
subject
.
UserID
)
log
ger
.
LegacyPrintf
(
"handler.admin.usage"
,
"[UsageCleanup] 清理任务已取消: task=%d operator=%d"
,
taskID
,
subject
.
UserID
)
response
.
Success
(
c
,
gin
.
H
{
"id"
:
taskID
,
"status"
:
service
.
UsageCleanupStatusCanceled
})
}
backend/internal/handler/admin/usage_handler_request_type_test.go
0 → 100644
View file @
3d79773b
package
admin
import
(
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type
adminUsageRepoCapture
struct
{
service
.
UsageLogRepository
listFilters
usagestats
.
UsageLogFilters
statsFilters
usagestats
.
UsageLogFilters
}
func
(
s
*
adminUsageRepoCapture
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
filters
usagestats
.
UsageLogFilters
)
([]
service
.
UsageLog
,
*
pagination
.
PaginationResult
,
error
)
{
s
.
listFilters
=
filters
return
[]
service
.
UsageLog
{},
&
pagination
.
PaginationResult
{
Total
:
0
,
Page
:
params
.
Page
,
PageSize
:
params
.
PageSize
,
Pages
:
0
,
},
nil
}
func
(
s
*
adminUsageRepoCapture
)
GetStatsWithFilters
(
ctx
context
.
Context
,
filters
usagestats
.
UsageLogFilters
)
(
*
usagestats
.
UsageStats
,
error
)
{
s
.
statsFilters
=
filters
return
&
usagestats
.
UsageStats
{},
nil
}
func
newAdminUsageRequestTypeTestRouter
(
repo
*
adminUsageRepoCapture
)
*
gin
.
Engine
{
gin
.
SetMode
(
gin
.
TestMode
)
usageSvc
:=
service
.
NewUsageService
(
repo
,
nil
,
nil
,
nil
)
handler
:=
NewUsageHandler
(
usageSvc
,
nil
,
nil
,
nil
)
router
:=
gin
.
New
()
router
.
GET
(
"/admin/usage"
,
handler
.
List
)
router
.
GET
(
"/admin/usage/stats"
,
handler
.
Stats
)
return
router
}
func
TestAdminUsageListRequestTypePriority
(
t
*
testing
.
T
)
{
repo
:=
&
adminUsageRepoCapture
{}
router
:=
newAdminUsageRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/admin/usage?request_type=ws_v2&stream=false"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
NotNil
(
t
,
repo
.
listFilters
.
RequestType
)
require
.
Equal
(
t
,
int16
(
service
.
RequestTypeWSV2
),
*
repo
.
listFilters
.
RequestType
)
require
.
Nil
(
t
,
repo
.
listFilters
.
Stream
)
}
func
TestAdminUsageListInvalidRequestType
(
t
*
testing
.
T
)
{
repo
:=
&
adminUsageRepoCapture
{}
router
:=
newAdminUsageRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/admin/usage?request_type=bad"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestAdminUsageListInvalidStream
(
t
*
testing
.
T
)
{
repo
:=
&
adminUsageRepoCapture
{}
router
:=
newAdminUsageRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/admin/usage?stream=bad"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestAdminUsageListExactTotalTrue
(
t
*
testing
.
T
)
{
repo
:=
&
adminUsageRepoCapture
{}
router
:=
newAdminUsageRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/admin/usage?exact_total=true"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
True
(
t
,
repo
.
listFilters
.
ExactTotal
)
}
func
TestAdminUsageListInvalidExactTotal
(
t
*
testing
.
T
)
{
repo
:=
&
adminUsageRepoCapture
{}
router
:=
newAdminUsageRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/admin/usage?exact_total=oops"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestAdminUsageStatsRequestTypePriority
(
t
*
testing
.
T
)
{
repo
:=
&
adminUsageRepoCapture
{}
router
:=
newAdminUsageRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/admin/usage/stats?request_type=stream&stream=bad"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
NotNil
(
t
,
repo
.
statsFilters
.
RequestType
)
require
.
Equal
(
t
,
int16
(
service
.
RequestTypeStream
),
*
repo
.
statsFilters
.
RequestType
)
require
.
Nil
(
t
,
repo
.
statsFilters
.
Stream
)
}
func
TestAdminUsageStatsInvalidRequestType
(
t
*
testing
.
T
)
{
repo
:=
&
adminUsageRepoCapture
{}
router
:=
newAdminUsageRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/admin/usage/stats?request_type=oops"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestAdminUsageStatsInvalidStream
(
t
*
testing
.
T
)
{
repo
:=
&
adminUsageRepoCapture
{}
router
:=
newAdminUsageRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/admin/usage/stats?stream=oops"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
backend/internal/handler/admin/user_attribute_handler.go
View file @
3d79773b
package
admin
import
(
"encoding/json"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -67,6 +69,8 @@ type BatchUserAttributesResponse struct {
Attributes
map
[
int64
]
map
[
int64
]
string
`json:"attributes"`
}
var
userAttributesBatchCache
=
newSnapshotCache
(
30
*
time
.
Second
)
// AttributeDefinitionResponse represents attribute definition response
type
AttributeDefinitionResponse
struct
{
ID
int64
`json:"id"`
...
...
@@ -327,16 +331,32 @@ func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
return
}
if
len
(
req
.
UserIDs
)
==
0
{
userIDs
:=
normalizeInt64IDList
(
req
.
UserIDs
)
if
len
(
userIDs
)
==
0
{
response
.
Success
(
c
,
BatchUserAttributesResponse
{
Attributes
:
map
[
int64
]
map
[
int64
]
string
{}})
return
}
attrs
,
err
:=
h
.
attrService
.
GetBatchUserAttributes
(
c
.
Request
.
Context
(),
req
.
UserIDs
)
keyRaw
,
_
:=
json
.
Marshal
(
struct
{
UserIDs
[]
int64
`json:"user_ids"`
}{
UserIDs
:
userIDs
,
})
cacheKey
:=
string
(
keyRaw
)
if
cached
,
ok
:=
userAttributesBatchCache
.
Get
(
cacheKey
);
ok
{
c
.
Header
(
"X-Snapshot-Cache"
,
"hit"
)
response
.
Success
(
c
,
cached
.
Payload
)
return
}
attrs
,
err
:=
h
.
attrService
.
GetBatchUserAttributes
(
c
.
Request
.
Context
(),
userIDs
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
BatchUserAttributesResponse
{
Attributes
:
attrs
})
payload
:=
BatchUserAttributesResponse
{
Attributes
:
attrs
}
userAttributesBatchCache
.
Set
(
cacheKey
,
payload
)
c
.
Header
(
"X-Snapshot-Cache"
,
"miss"
)
response
.
Success
(
c
,
payload
)
}
backend/internal/handler/admin/user_handler.go
View file @
3d79773b
package
admin
import
(
"context"
"strconv"
"strings"
...
...
@@ -33,13 +34,14 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi
// CreateUserRequest represents admin create user request
type
CreateUserRequest
struct
{
Email
string
`json:"email" binding:"required,email"`
Password
string
`json:"password" binding:"required,min=6"`
Username
string
`json:"username"`
Notes
string
`json:"notes"`
Balance
float64
`json:"balance"`
Concurrency
int
`json:"concurrency"`
AllowedGroups
[]
int64
`json:"allowed_groups"`
Email
string
`json:"email" binding:"required,email"`
Password
string
`json:"password" binding:"required,min=6"`
Username
string
`json:"username"`
Notes
string
`json:"notes"`
Balance
float64
`json:"balance"`
Concurrency
int
`json:"concurrency"`
AllowedGroups
[]
int64
`json:"allowed_groups"`
SoraStorageQuotaBytes
int64
`json:"sora_storage_quota_bytes"`
}
// UpdateUserRequest represents admin update user request
...
...
@@ -55,7 +57,8 @@ type UpdateUserRequest struct {
AllowedGroups
*
[]
int64
`json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
GroupRates
map
[
int64
]
*
float64
`json:"group_rates"`
GroupRates
map
[
int64
]
*
float64
`json:"group_rates"`
SoraStorageQuotaBytes
*
int64
`json:"sora_storage_quota_bytes"`
}
// UpdateBalanceRequest represents balance update request
...
...
@@ -78,8 +81,8 @@ func (h *UserHandler) List(c *gin.Context) {
search
:=
c
.
Query
(
"search"
)
// 标准化和验证 search 参数
search
=
strings
.
TrimSpace
(
search
)
if
len
(
search
)
>
100
{
search
=
s
earch
[
:
100
]
if
runes
:=
[]
rune
(
search
);
len
(
runes
)
>
100
{
search
=
s
tring
(
runes
[
:
100
]
)
}
filters
:=
service
.
UserListFilters
{
...
...
@@ -88,6 +91,10 @@ func (h *UserHandler) List(c *gin.Context) {
Search
:
search
,
Attributes
:
parseAttributeFilters
(
c
),
}
if
raw
,
ok
:=
c
.
GetQuery
(
"include_subscriptions"
);
ok
{
includeSubscriptions
:=
parseBoolQueryWithDefault
(
raw
,
true
)
filters
.
IncludeSubscriptions
=
&
includeSubscriptions
}
users
,
total
,
err
:=
h
.
adminService
.
ListUsers
(
c
.
Request
.
Context
(),
page
,
pageSize
,
filters
)
if
err
!=
nil
{
...
...
@@ -173,13 +180,14 @@ func (h *UserHandler) Create(c *gin.Context) {
}
user
,
err
:=
h
.
adminService
.
CreateUser
(
c
.
Request
.
Context
(),
&
service
.
CreateUserInput
{
Email
:
req
.
Email
,
Password
:
req
.
Password
,
Username
:
req
.
Username
,
Notes
:
req
.
Notes
,
Balance
:
req
.
Balance
,
Concurrency
:
req
.
Concurrency
,
AllowedGroups
:
req
.
AllowedGroups
,
Email
:
req
.
Email
,
Password
:
req
.
Password
,
Username
:
req
.
Username
,
Notes
:
req
.
Notes
,
Balance
:
req
.
Balance
,
Concurrency
:
req
.
Concurrency
,
AllowedGroups
:
req
.
AllowedGroups
,
SoraStorageQuotaBytes
:
req
.
SoraStorageQuotaBytes
,
})
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
...
...
@@ -206,15 +214,16 @@ func (h *UserHandler) Update(c *gin.Context) {
// 使用指针类型直接传递,nil 表示未提供该字段
user
,
err
:=
h
.
adminService
.
UpdateUser
(
c
.
Request
.
Context
(),
userID
,
&
service
.
UpdateUserInput
{
Email
:
req
.
Email
,
Password
:
req
.
Password
,
Username
:
req
.
Username
,
Notes
:
req
.
Notes
,
Balance
:
req
.
Balance
,
Concurrency
:
req
.
Concurrency
,
Status
:
req
.
Status
,
AllowedGroups
:
req
.
AllowedGroups
,
GroupRates
:
req
.
GroupRates
,
Email
:
req
.
Email
,
Password
:
req
.
Password
,
Username
:
req
.
Username
,
Notes
:
req
.
Notes
,
Balance
:
req
.
Balance
,
Concurrency
:
req
.
Concurrency
,
Status
:
req
.
Status
,
AllowedGroups
:
req
.
AllowedGroups
,
GroupRates
:
req
.
GroupRates
,
SoraStorageQuotaBytes
:
req
.
SoraStorageQuotaBytes
,
})
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
...
...
@@ -257,13 +266,20 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return
}
user
,
err
:=
h
.
adminService
.
UpdateUserBalance
(
c
.
Request
.
Context
(),
userID
,
req
.
Balance
,
req
.
Operation
,
req
.
Notes
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
idempotencyPayload
:=
struct
{
UserID
int64
`json:"user_id"`
Body
UpdateBalanceRequest
`json:"body"`
}{
UserID
:
userID
,
Body
:
req
,
}
response
.
Success
(
c
,
dto
.
UserFromServiceAdmin
(
user
))
executeAdminIdempotentJSON
(
c
,
"admin.users.balance.update"
,
idempotencyPayload
,
service
.
DefaultWriteIdempotencyTTL
(),
func
(
ctx
context
.
Context
)
(
any
,
error
)
{
user
,
execErr
:=
h
.
adminService
.
UpdateUserBalance
(
ctx
,
userID
,
req
.
Balance
,
req
.
Operation
,
req
.
Notes
)
if
execErr
!=
nil
{
return
nil
,
execErr
}
return
dto
.
UserFromServiceAdmin
(
user
),
nil
})
}
// GetUserAPIKeys handles getting user's API keys
...
...
backend/internal/handler/api_key_handler.go
View file @
3d79773b
...
...
@@ -2,7 +2,9 @@
package
handler
import
(
"context"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
...
...
@@ -35,6 +37,11 @@ type CreateAPIKeyRequest struct {
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
Quota
*
float64
`json:"quota"`
// 配额限制 (USD)
ExpiresInDays
*
int
`json:"expires_in_days"`
// 过期天数
// Rate limit fields (0 = unlimited)
RateLimit5h
*
float64
`json:"rate_limit_5h"`
RateLimit1d
*
float64
`json:"rate_limit_1d"`
RateLimit7d
*
float64
`json:"rate_limit_7d"`
}
// UpdateAPIKeyRequest represents the update API key request payload
...
...
@@ -47,6 +54,12 @@ type UpdateAPIKeyRequest struct {
Quota
*
float64
`json:"quota"`
// 配额限制 (USD), 0=无限制
ExpiresAt
*
string
`json:"expires_at"`
// 过期时间 (ISO 8601)
ResetQuota
*
bool
`json:"reset_quota"`
// 重置已用配额
// Rate limit fields (nil = no change, 0 = unlimited)
RateLimit5h
*
float64
`json:"rate_limit_5h"`
RateLimit1d
*
float64
`json:"rate_limit_1d"`
RateLimit7d
*
float64
`json:"rate_limit_7d"`
ResetRateLimitUsage
*
bool
`json:"reset_rate_limit_usage"`
// 重置限速用量
}
// List handles listing user's API keys with pagination
...
...
@@ -61,7 +74,23 @@ func (h *APIKeyHandler) List(c *gin.Context) {
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
}
keys
,
result
,
err
:=
h
.
apiKeyService
.
List
(
c
.
Request
.
Context
(),
subject
.
UserID
,
params
)
// Parse filter parameters
var
filters
service
.
APIKeyListFilters
if
search
:=
strings
.
TrimSpace
(
c
.
Query
(
"search"
));
search
!=
""
{
if
len
(
search
)
>
100
{
search
=
search
[
:
100
]
}
filters
.
Search
=
search
}
filters
.
Status
=
c
.
Query
(
"status"
)
if
groupIDStr
:=
c
.
Query
(
"group_id"
);
groupIDStr
!=
""
{
gid
,
err
:=
strconv
.
ParseInt
(
groupIDStr
,
10
,
64
)
if
err
==
nil
{
filters
.
GroupID
=
&
gid
}
}
keys
,
result
,
err
:=
h
.
apiKeyService
.
List
(
c
.
Request
.
Context
(),
subject
.
UserID
,
params
,
filters
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
@@ -130,13 +159,23 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
if
req
.
Quota
!=
nil
{
svcReq
.
Quota
=
*
req
.
Quota
}
key
,
err
:=
h
.
apiKeyService
.
Create
(
c
.
Request
.
Context
(),
subject
.
UserID
,
svcReq
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
if
req
.
RateLimit5h
!=
nil
{
svcReq
.
RateLimit5h
=
*
req
.
RateLimit5h
}
if
req
.
RateLimit1d
!=
nil
{
svcReq
.
RateLimit1d
=
*
req
.
RateLimit1d
}
if
req
.
RateLimit7d
!=
nil
{
svcReq
.
RateLimit7d
=
*
req
.
RateLimit7d
}
response
.
Success
(
c
,
dto
.
APIKeyFromService
(
key
))
executeUserIdempotentJSON
(
c
,
"user.api_keys.create"
,
req
,
service
.
DefaultWriteIdempotencyTTL
(),
func
(
ctx
context
.
Context
)
(
any
,
error
)
{
key
,
err
:=
h
.
apiKeyService
.
Create
(
ctx
,
subject
.
UserID
,
svcReq
)
if
err
!=
nil
{
return
nil
,
err
}
return
dto
.
APIKeyFromService
(
key
),
nil
})
}
// Update handles updating an API key
...
...
@@ -161,10 +200,14 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
}
svcReq
:=
service
.
UpdateAPIKeyRequest
{
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
Quota
:
req
.
Quota
,
ResetQuota
:
req
.
ResetQuota
,
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
Quota
:
req
.
Quota
,
ResetQuota
:
req
.
ResetQuota
,
RateLimit5h
:
req
.
RateLimit5h
,
RateLimit1d
:
req
.
RateLimit1d
,
RateLimit7d
:
req
.
RateLimit7d
,
ResetRateLimitUsage
:
req
.
ResetRateLimitUsage
,
}
if
req
.
Name
!=
""
{
svcReq
.
Name
=
&
req
.
Name
...
...
backend/internal/handler/auth_handler.go
View file @
3d79773b
...
...
@@ -2,6 +2,7 @@ package handler
import
(
"log/slog"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
...
...
@@ -112,12 +113,10 @@ func (h *AuthHandler) Register(c *gin.Context) {
return
}
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
if
req
.
VerifyCode
==
""
{
if
err
:=
h
.
authService
.
VerifyTurnstile
(
c
.
Request
.
Context
(),
req
.
TurnstileToken
,
ip
.
GetClientIP
(
c
));
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
// Turnstile 验证(邮箱验证码注册场景避免重复校验一次性 token)
if
err
:=
h
.
authService
.
VerifyTurnstileForRegister
(
c
.
Request
.
Context
(),
req
.
TurnstileToken
,
ip
.
GetClientIP
(
c
),
req
.
VerifyCode
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
_
,
user
,
err
:=
h
.
authService
.
RegisterWithVerification
(
c
.
Request
.
Context
(),
req
.
Email
,
req
.
Password
,
req
.
VerifyCode
,
req
.
PromoCode
,
req
.
InvitationCode
)
...
...
@@ -448,17 +447,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
return
}
// Build frontend base URL from request
scheme
:=
"https"
if
c
.
Request
.
TLS
==
nil
{
// Check X-Forwarded-Proto header (common in reverse proxy setups)
if
proto
:=
c
.
GetHeader
(
"X-Forwarded-Proto"
);
proto
!=
""
{
scheme
=
proto
}
else
{
scheme
=
"http"
}
frontendBaseURL
:=
strings
.
TrimSpace
(
h
.
cfg
.
Server
.
FrontendURL
)
if
frontendBaseURL
==
""
{
slog
.
Error
(
"server.frontend_url not configured; cannot build password reset link"
)
response
.
InternalError
(
c
,
"Password reset is not configured"
)
return
}
frontendBaseURL
:=
scheme
+
"://"
+
c
.
Request
.
Host
// Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration)
...
...
backend/internal/handler/dto/api_key_mapper_last_used_test.go
0 → 100644
View file @
3d79773b
package
dto
import
(
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func
TestAPIKeyFromService_MapsLastUsedAt
(
t
*
testing
.
T
)
{
lastUsed
:=
time
.
Now
()
.
UTC
()
.
Truncate
(
time
.
Second
)
src
:=
&
service
.
APIKey
{
ID
:
1
,
UserID
:
2
,
Key
:
"sk-map-last-used"
,
Name
:
"Mapper"
,
Status
:
service
.
StatusActive
,
LastUsedAt
:
&
lastUsed
,
}
out
:=
APIKeyFromService
(
src
)
require
.
NotNil
(
t
,
out
)
require
.
NotNil
(
t
,
out
.
LastUsedAt
)
require
.
WithinDuration
(
t
,
lastUsed
,
*
out
.
LastUsedAt
,
time
.
Second
)
}
func
TestAPIKeyFromService_MapsNilLastUsedAt
(
t
*
testing
.
T
)
{
src
:=
&
service
.
APIKey
{
ID
:
1
,
UserID
:
2
,
Key
:
"sk-map-last-used-nil"
,
Name
:
"MapperNil"
,
Status
:
service
.
StatusActive
,
}
out
:=
APIKeyFromService
(
src
)
require
.
NotNil
(
t
,
out
)
require
.
Nil
(
t
,
out
.
LastUsedAt
)
}
backend/internal/handler/dto/mappers.go
View file @
3d79773b
...
...
@@ -2,6 +2,7 @@
package
dto
import
(
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -58,9 +59,11 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
return
nil
}
return
&
AdminUser
{
User
:
*
base
,
Notes
:
u
.
Notes
,
GroupRates
:
u
.
GroupRates
,
User
:
*
base
,
Notes
:
u
.
Notes
,
GroupRates
:
u
.
GroupRates
,
SoraStorageQuotaBytes
:
u
.
SoraStorageQuotaBytes
,
SoraStorageUsedBytes
:
u
.
SoraStorageUsedBytes
,
}
}
...
...
@@ -69,21 +72,31 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
return
nil
}
return
&
APIKey
{
ID
:
k
.
ID
,
UserID
:
k
.
UserID
,
Key
:
k
.
Key
,
Name
:
k
.
Name
,
GroupID
:
k
.
GroupID
,
Status
:
k
.
Status
,
IPWhitelist
:
k
.
IPWhitelist
,
IPBlacklist
:
k
.
IPBlacklist
,
Quota
:
k
.
Quota
,
QuotaUsed
:
k
.
QuotaUsed
,
ExpiresAt
:
k
.
ExpiresAt
,
CreatedAt
:
k
.
CreatedAt
,
UpdatedAt
:
k
.
UpdatedAt
,
User
:
UserFromServiceShallow
(
k
.
User
),
Group
:
GroupFromServiceShallow
(
k
.
Group
),
ID
:
k
.
ID
,
UserID
:
k
.
UserID
,
Key
:
k
.
Key
,
Name
:
k
.
Name
,
GroupID
:
k
.
GroupID
,
Status
:
k
.
Status
,
IPWhitelist
:
k
.
IPWhitelist
,
IPBlacklist
:
k
.
IPBlacklist
,
LastUsedAt
:
k
.
LastUsedAt
,
Quota
:
k
.
Quota
,
QuotaUsed
:
k
.
QuotaUsed
,
ExpiresAt
:
k
.
ExpiresAt
,
CreatedAt
:
k
.
CreatedAt
,
UpdatedAt
:
k
.
UpdatedAt
,
RateLimit5h
:
k
.
RateLimit5h
,
RateLimit1d
:
k
.
RateLimit1d
,
RateLimit7d
:
k
.
RateLimit7d
,
Usage5h
:
k
.
Usage5h
,
Usage1d
:
k
.
Usage1d
,
Usage7d
:
k
.
Usage7d
,
Window5hStart
:
k
.
Window5hStart
,
Window1dStart
:
k
.
Window1dStart
,
Window7dStart
:
k
.
Window7dStart
,
User
:
UserFromServiceShallow
(
k
.
User
),
Group
:
GroupFromServiceShallow
(
k
.
Group
),
}
}
...
...
@@ -129,24 +142,28 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
func
groupFromServiceBase
(
g
*
service
.
Group
)
Group
{
return
Group
{
ID
:
g
.
ID
,
Name
:
g
.
Name
,
Description
:
g
.
Description
,
Platform
:
g
.
Platform
,
RateMultiplier
:
g
.
RateMultiplier
,
IsExclusive
:
g
.
IsExclusive
,
Status
:
g
.
Status
,
SubscriptionType
:
g
.
SubscriptionType
,
DailyLimitUSD
:
g
.
DailyLimitUSD
,
WeeklyLimitUSD
:
g
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
g
.
MonthlyLimitUSD
,
ImagePrice1K
:
g
.
ImagePrice1K
,
ImagePrice2K
:
g
.
ImagePrice2K
,
ImagePrice4K
:
g
.
ImagePrice4K
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
FallbackGroupID
:
g
.
FallbackGroupID
,
// 无效请求兜底分组
ID
:
g
.
ID
,
Name
:
g
.
Name
,
Description
:
g
.
Description
,
Platform
:
g
.
Platform
,
RateMultiplier
:
g
.
RateMultiplier
,
IsExclusive
:
g
.
IsExclusive
,
Status
:
g
.
Status
,
SubscriptionType
:
g
.
SubscriptionType
,
DailyLimitUSD
:
g
.
DailyLimitUSD
,
WeeklyLimitUSD
:
g
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
g
.
MonthlyLimitUSD
,
ImagePrice1K
:
g
.
ImagePrice1K
,
ImagePrice2K
:
g
.
ImagePrice2K
,
ImagePrice4K
:
g
.
ImagePrice4K
,
SoraImagePrice360
:
g
.
SoraImagePrice360
,
SoraImagePrice540
:
g
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
g
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
g
.
SoraVideoPricePerRequestHD
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
FallbackGroupID
:
g
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
g
.
FallbackGroupIDOnInvalidRequest
,
SoraStorageQuotaBytes
:
g
.
SoraStorageQuotaBytes
,
CreatedAt
:
g
.
CreatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
}
...
...
@@ -201,6 +218,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if
idleTimeout
:=
a
.
GetSessionIdleTimeoutMinutes
();
idleTimeout
>
0
{
out
.
SessionIdleTimeoutMin
=
&
idleTimeout
}
if
rpm
:=
a
.
GetBaseRPM
();
rpm
>
0
{
out
.
BaseRPM
=
&
rpm
strategy
:=
a
.
GetRPMStrategy
()
out
.
RPMStrategy
=
&
strategy
buffer
:=
a
.
GetRPMStickyBuffer
()
out
.
RPMStickyBuffer
=
&
buffer
}
// 用户消息队列模式
if
mode
:=
a
.
GetUserMsgQueueMode
();
mode
!=
""
{
out
.
UserMsgQueueMode
=
&
mode
}
// TLS指纹伪装开关
if
a
.
IsTLSFingerprintEnabled
()
{
enabled
:=
true
...
...
@@ -211,6 +239,13 @@ func AccountFromServiceShallow(a *service.Account) *Account {
enabled
:=
true
out
.
EnableSessionIDMasking
=
&
enabled
}
// 缓存 TTL 强制替换
if
a
.
IsCacheTTLOverrideEnabled
()
{
enabled
:=
true
out
.
CacheTTLOverrideEnabled
=
&
enabled
target
:=
a
.
GetCacheTTLOverrideTarget
()
out
.
CacheTTLOverrideTarget
=
&
target
}
}
return
out
...
...
@@ -271,7 +306,6 @@ func ProxyFromService(p *service.Proxy) *Proxy {
Host
:
p
.
Host
,
Port
:
p
.
Port
,
Username
:
p
.
Username
,
Password
:
p
.
Password
,
Status
:
p
.
Status
,
CreatedAt
:
p
.
CreatedAt
,
UpdatedAt
:
p
.
UpdatedAt
,
...
...
@@ -293,6 +327,56 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
CountryCode
:
p
.
CountryCode
,
Region
:
p
.
Region
,
City
:
p
.
City
,
QualityStatus
:
p
.
QualityStatus
,
QualityScore
:
p
.
QualityScore
,
QualityGrade
:
p
.
QualityGrade
,
QualitySummary
:
p
.
QualitySummary
,
QualityChecked
:
p
.
QualityChecked
,
}
}
// ProxyFromServiceAdmin converts a service Proxy to AdminProxy DTO for admin users.
// It includes the password field - user-facing endpoints must not use this.
func
ProxyFromServiceAdmin
(
p
*
service
.
Proxy
)
*
AdminProxy
{
if
p
==
nil
{
return
nil
}
base
:=
ProxyFromService
(
p
)
if
base
==
nil
{
return
nil
}
return
&
AdminProxy
{
Proxy
:
*
base
,
Password
:
p
.
Password
,
}
}
// ProxyWithAccountCountFromServiceAdmin converts a service ProxyWithAccountCount to AdminProxyWithAccountCount DTO.
// It includes the password field - user-facing endpoints must not use this.
func
ProxyWithAccountCountFromServiceAdmin
(
p
*
service
.
ProxyWithAccountCount
)
*
AdminProxyWithAccountCount
{
if
p
==
nil
{
return
nil
}
admin
:=
ProxyFromServiceAdmin
(
&
p
.
Proxy
)
if
admin
==
nil
{
return
nil
}
return
&
AdminProxyWithAccountCount
{
AdminProxy
:
*
admin
,
AccountCount
:
p
.
AccountCount
,
LatencyMs
:
p
.
LatencyMs
,
LatencyStatus
:
p
.
LatencyStatus
,
LatencyMessage
:
p
.
LatencyMessage
,
IPAddress
:
p
.
IPAddress
,
Country
:
p
.
Country
,
CountryCode
:
p
.
CountryCode
,
Region
:
p
.
Region
,
City
:
p
.
City
,
QualityStatus
:
p
.
QualityStatus
,
QualityScore
:
p
.
QualityScore
,
QualityGrade
:
p
.
QualityGrade
,
QualitySummary
:
p
.
QualitySummary
,
QualityChecked
:
p
.
QualityChecked
,
}
}
...
...
@@ -368,6 +452,8 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
func
usageLogFromServiceUser
(
l
*
service
.
UsageLog
)
UsageLog
{
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
requestType
:=
l
.
EffectiveRequestType
()
stream
,
openAIWSMode
:=
service
.
ApplyLegacyRequestFields
(
requestType
,
l
.
Stream
,
l
.
OpenAIWSMode
)
return
UsageLog
{
ID
:
l
.
ID
,
UserID
:
l
.
UserID
,
...
...
@@ -392,12 +478,16 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
ActualCost
:
l
.
ActualCost
,
RateMultiplier
:
l
.
RateMultiplier
,
BillingType
:
l
.
BillingType
,
Stream
:
l
.
Stream
,
RequestType
:
requestType
.
String
(),
Stream
:
stream
,
OpenAIWSMode
:
openAIWSMode
,
DurationMs
:
l
.
DurationMs
,
FirstTokenMs
:
l
.
FirstTokenMs
,
ImageCount
:
l
.
ImageCount
,
ImageSize
:
l
.
ImageSize
,
MediaType
:
l
.
MediaType
,
UserAgent
:
l
.
UserAgent
,
CacheTTLOverridden
:
l
.
CacheTTLOverridden
,
CreatedAt
:
l
.
CreatedAt
,
User
:
UserFromServiceShallow
(
l
.
User
),
APIKey
:
APIKeyFromService
(
l
.
APIKey
),
...
...
@@ -445,6 +535,7 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa
AccountID
:
task
.
Filters
.
AccountID
,
GroupID
:
task
.
Filters
.
GroupID
,
Model
:
task
.
Filters
.
Model
,
RequestType
:
requestTypeStringPtr
(
task
.
Filters
.
RequestType
),
Stream
:
task
.
Filters
.
Stream
,
BillingType
:
task
.
Filters
.
BillingType
,
},
...
...
@@ -460,6 +551,14 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa
}
}
func
requestTypeStringPtr
(
requestType
*
int16
)
*
string
{
if
requestType
==
nil
{
return
nil
}
value
:=
service
.
RequestTypeFromInt16
(
*
requestType
)
.
String
()
return
&
value
}
func
SettingFromService
(
s
*
service
.
Setting
)
*
Setting
{
if
s
==
nil
{
return
nil
...
...
@@ -524,11 +623,18 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
for
i
:=
range
r
.
Subscriptions
{
subs
=
append
(
subs
,
*
UserSubscriptionFromServiceAdmin
(
&
r
.
Subscriptions
[
i
]))
}
statuses
:=
make
(
map
[
string
]
string
,
len
(
r
.
Statuses
))
for
userID
,
status
:=
range
r
.
Statuses
{
statuses
[
strconv
.
FormatInt
(
userID
,
10
)]
=
status
}
return
&
BulkAssignResult
{
SuccessCount
:
r
.
SuccessCount
,
CreatedCount
:
r
.
CreatedCount
,
ReusedCount
:
r
.
ReusedCount
,
FailedCount
:
r
.
FailedCount
,
Subscriptions
:
subs
,
Errors
:
r
.
Errors
,
Statuses
:
statuses
,
}
}
...
...
backend/internal/handler/dto/mappers_usage_test.go
0 → 100644
View file @
3d79773b
package
dto
import
(
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func
TestUsageLogFromService_IncludesOpenAIWSMode
(
t
*
testing
.
T
)
{
t
.
Parallel
()
wsLog
:=
&
service
.
UsageLog
{
RequestID
:
"req_1"
,
Model
:
"gpt-5.3-codex"
,
OpenAIWSMode
:
true
,
}
httpLog
:=
&
service
.
UsageLog
{
RequestID
:
"resp_1"
,
Model
:
"gpt-5.3-codex"
,
OpenAIWSMode
:
false
,
}
require
.
True
(
t
,
UsageLogFromService
(
wsLog
)
.
OpenAIWSMode
)
require
.
False
(
t
,
UsageLogFromService
(
httpLog
)
.
OpenAIWSMode
)
require
.
True
(
t
,
UsageLogFromServiceAdmin
(
wsLog
)
.
OpenAIWSMode
)
require
.
False
(
t
,
UsageLogFromServiceAdmin
(
httpLog
)
.
OpenAIWSMode
)
}
func
TestUsageLogFromService_PrefersRequestTypeForLegacyFields
(
t
*
testing
.
T
)
{
t
.
Parallel
()
log
:=
&
service
.
UsageLog
{
RequestID
:
"req_2"
,
Model
:
"gpt-5.3-codex"
,
RequestType
:
service
.
RequestTypeWSV2
,
Stream
:
false
,
OpenAIWSMode
:
false
,
}
userDTO
:=
UsageLogFromService
(
log
)
adminDTO
:=
UsageLogFromServiceAdmin
(
log
)
require
.
Equal
(
t
,
"ws_v2"
,
userDTO
.
RequestType
)
require
.
True
(
t
,
userDTO
.
Stream
)
require
.
True
(
t
,
userDTO
.
OpenAIWSMode
)
require
.
Equal
(
t
,
"ws_v2"
,
adminDTO
.
RequestType
)
require
.
True
(
t
,
adminDTO
.
Stream
)
require
.
True
(
t
,
adminDTO
.
OpenAIWSMode
)
}
func
TestUsageCleanupTaskFromService_RequestTypeMapping
(
t
*
testing
.
T
)
{
t
.
Parallel
()
requestType
:=
int16
(
service
.
RequestTypeStream
)
task
:=
&
service
.
UsageCleanupTask
{
ID
:
1
,
Status
:
service
.
UsageCleanupStatusPending
,
Filters
:
service
.
UsageCleanupFilters
{
RequestType
:
&
requestType
,
},
}
dtoTask
:=
UsageCleanupTaskFromService
(
task
)
require
.
NotNil
(
t
,
dtoTask
)
require
.
NotNil
(
t
,
dtoTask
.
Filters
.
RequestType
)
require
.
Equal
(
t
,
"stream"
,
*
dtoTask
.
Filters
.
RequestType
)
}
func
TestRequestTypeStringPtrNil
(
t
*
testing
.
T
)
{
t
.
Parallel
()
require
.
Nil
(
t
,
requestTypeStringPtr
(
nil
))
}
backend/internal/handler/dto/settings.go
View file @
3d79773b
package
dto
import
(
"encoding/json"
"strings"
)
// CustomMenuItem represents a user-configured custom menu entry.
type
CustomMenuItem
struct
{
ID
string
`json:"id"`
Label
string
`json:"label"`
IconSVG
string
`json:"icon_svg"`
URL
string
`json:"url"`
Visibility
string
`json:"visibility"`
// "user" or "admin"
SortOrder
int
`json:"sort_order"`
}
// SystemSettings represents the admin settings API response payload.
type
SystemSettings
struct
{
RegistrationEnabled
bool
`json:"registration_enabled"`
EmailVerifyEnabled
bool
`json:"email_verify_enabled"`
PromoCodeEnabled
bool
`json:"promo_code_enabled"`
PasswordResetEnabled
bool
`json:"password_reset_enabled"`
InvitationCodeEnabled
bool
`json:"invitation_code_enabled"`
TotpEnabled
bool
`json:"totp_enabled"`
// TOTP 双因素认证
TotpEncryptionKeyConfigured
bool
`json:"totp_encryption_key_configured"`
// TOTP 加密密钥是否已配置
RegistrationEnabled
bool
`json:"registration_enabled"`
EmailVerifyEnabled
bool
`json:"email_verify_enabled"`
RegistrationEmailSuffixWhitelist
[]
string
`json:"registration_email_suffix_whitelist"`
PromoCodeEnabled
bool
`json:"promo_code_enabled"`
PasswordResetEnabled
bool
`json:"password_reset_enabled"`
InvitationCodeEnabled
bool
`json:"invitation_code_enabled"`
TotpEnabled
bool
`json:"totp_enabled"`
// TOTP 双因素认证
TotpEncryptionKeyConfigured
bool
`json:"totp_encryption_key_configured"`
// TOTP 加密密钥是否已配置
SMTPHost
string
`json:"smtp_host"`
SMTPPort
int
`json:"smtp_port"`
...
...
@@ -27,19 +43,22 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured
bool
`json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL
string
`json:"linuxdo_connect_redirect_url"`
SiteName
string
`json:"site_name"`
SiteLogo
string
`json:"site_logo"`
SiteSubtitle
string
`json:"site_subtitle"`
APIBaseURL
string
`json:"api_base_url"`
ContactInfo
string
`json:"contact_info"`
DocURL
string
`json:"doc_url"`
HomeContent
string
`json:"home_content"`
HideCcsImportButton
bool
`json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled
bool
`json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL
string
`json:"purchase_subscription_url"`
SiteName
string
`json:"site_name"`
SiteLogo
string
`json:"site_logo"`
SiteSubtitle
string
`json:"site_subtitle"`
APIBaseURL
string
`json:"api_base_url"`
ContactInfo
string
`json:"contact_info"`
DocURL
string
`json:"doc_url"`
HomeContent
string
`json:"home_content"`
HideCcsImportButton
bool
`json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled
bool
`json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL
string
`json:"purchase_subscription_url"`
SoraClientEnabled
bool
`json:"sora_client_enabled"`
CustomMenuItems
[]
CustomMenuItem
`json:"custom_menu_items"`
DefaultConcurrency
int
`json:"default_concurrency"`
DefaultBalance
float64
`json:"default_balance"`
DefaultConcurrency
int
`json:"default_concurrency"`
DefaultBalance
float64
`json:"default_balance"`
DefaultSubscriptions
[]
DefaultSubscriptionSetting
`json:"default_subscriptions"`
// Model fallback configuration
EnableModelFallback
bool
`json:"enable_model_fallback"`
...
...
@@ -57,29 +76,80 @@ type SystemSettings struct {
OpsRealtimeMonitoringEnabled
bool
`json:"ops_realtime_monitoring_enabled"`
OpsQueryModeDefault
string
`json:"ops_query_mode_default"`
OpsMetricsIntervalSeconds
int
`json:"ops_metrics_interval_seconds"`
MinClaudeCodeVersion
string
`json:"min_claude_code_version"`
// 分组隔离
AllowUngroupedKeyScheduling
bool
`json:"allow_ungrouped_key_scheduling"`
}
type
DefaultSubscriptionSetting
struct
{
GroupID
int64
`json:"group_id"`
ValidityDays
int
`json:"validity_days"`
}
type
PublicSettings
struct
{
RegistrationEnabled
bool
`json:"registration_enabled"`
EmailVerifyEnabled
bool
`json:"email_verify_enabled"`
PromoCodeEnabled
bool
`json:"promo_code_enabled"`
PasswordResetEnabled
bool
`json:"password_reset_enabled"`
InvitationCodeEnabled
bool
`json:"invitation_code_enabled"`
TotpEnabled
bool
`json:"totp_enabled"`
// TOTP 双因素认证
TurnstileEnabled
bool
`json:"turnstile_enabled"`
TurnstileSiteKey
string
`json:"turnstile_site_key"`
SiteName
string
`json:"site_name"`
SiteLogo
string
`json:"site_logo"`
SiteSubtitle
string
`json:"site_subtitle"`
APIBaseURL
string
`json:"api_base_url"`
ContactInfo
string
`json:"contact_info"`
DocURL
string
`json:"doc_url"`
HomeContent
string
`json:"home_content"`
HideCcsImportButton
bool
`json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled
bool
`json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL
string
`json:"purchase_subscription_url"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
Version
string
`json:"version"`
RegistrationEnabled
bool
`json:"registration_enabled"`
EmailVerifyEnabled
bool
`json:"email_verify_enabled"`
RegistrationEmailSuffixWhitelist
[]
string
`json:"registration_email_suffix_whitelist"`
PromoCodeEnabled
bool
`json:"promo_code_enabled"`
PasswordResetEnabled
bool
`json:"password_reset_enabled"`
InvitationCodeEnabled
bool
`json:"invitation_code_enabled"`
TotpEnabled
bool
`json:"totp_enabled"`
// TOTP 双因素认证
TurnstileEnabled
bool
`json:"turnstile_enabled"`
TurnstileSiteKey
string
`json:"turnstile_site_key"`
SiteName
string
`json:"site_name"`
SiteLogo
string
`json:"site_logo"`
SiteSubtitle
string
`json:"site_subtitle"`
APIBaseURL
string
`json:"api_base_url"`
ContactInfo
string
`json:"contact_info"`
DocURL
string
`json:"doc_url"`
HomeContent
string
`json:"home_content"`
HideCcsImportButton
bool
`json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled
bool
`json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL
string
`json:"purchase_subscription_url"`
CustomMenuItems
[]
CustomMenuItem
`json:"custom_menu_items"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
SoraClientEnabled
bool
`json:"sora_client_enabled"`
Version
string
`json:"version"`
}
// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
type
SoraS3Settings
struct
{
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKeyConfigured
bool
`json:"secret_access_key_configured"`
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
}
// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段)
type
SoraS3Profile
struct
{
ProfileID
string
`json:"profile_id"`
Name
string
`json:"name"`
IsActive
bool
`json:"is_active"`
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKeyConfigured
bool
`json:"secret_access_key_configured"`
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
CDNURL
string
`json:"cdn_url"`
DefaultStorageQuotaBytes
int64
`json:"default_storage_quota_bytes"`
UpdatedAt
string
`json:"updated_at"`
}
// ListSoraS3ProfilesResponse Sora S3 配置列表响应
type
ListSoraS3ProfilesResponse
struct
{
ActiveProfileID
string
`json:"active_profile_id"`
Items
[]
SoraS3Profile
`json:"items"`
}
// StreamTimeoutSettings 流超时处理配置 DTO
...
...
@@ -90,3 +160,29 @@ type StreamTimeoutSettings struct {
ThresholdCount
int
`json:"threshold_count"`
ThresholdWindowMinutes
int
`json:"threshold_window_minutes"`
}
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
// Returns empty slice on empty/invalid input.
func
ParseCustomMenuItems
(
raw
string
)
[]
CustomMenuItem
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
||
raw
==
"[]"
{
return
[]
CustomMenuItem
{}
}
var
items
[]
CustomMenuItem
if
err
:=
json
.
Unmarshal
([]
byte
(
raw
),
&
items
);
err
!=
nil
{
return
[]
CustomMenuItem
{}
}
return
items
}
// ParseUserVisibleMenuItems parses custom menu items and filters out admin-only entries.
func
ParseUserVisibleMenuItems
(
raw
string
)
[]
CustomMenuItem
{
items
:=
ParseCustomMenuItems
(
raw
)
filtered
:=
make
([]
CustomMenuItem
,
0
,
len
(
items
))
for
_
,
item
:=
range
items
{
if
item
.
Visibility
!=
"admin"
{
filtered
=
append
(
filtered
,
item
)
}
}
return
filtered
}
backend/internal/handler/dto/types.go
View file @
3d79773b
...
...
@@ -26,7 +26,9 @@ type AdminUser struct {
Notes
string
`json:"notes"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates
map
[
int64
]
float64
`json:"group_rates,omitempty"`
GroupRates
map
[
int64
]
float64
`json:"group_rates,omitempty"`
SoraStorageQuotaBytes
int64
`json:"sora_storage_quota_bytes"`
SoraStorageUsedBytes
int64
`json:"sora_storage_used_bytes"`
}
type
APIKey
struct
{
...
...
@@ -38,12 +40,24 @@ type APIKey struct {
Status
string
`json:"status"`
IPWhitelist
[]
string
`json:"ip_whitelist"`
IPBlacklist
[]
string
`json:"ip_blacklist"`
LastUsedAt
*
time
.
Time
`json:"last_used_at"`
Quota
float64
`json:"quota"`
// Quota limit in USD (0 = unlimited)
QuotaUsed
float64
`json:"quota_used"`
// Used quota amount in USD
ExpiresAt
*
time
.
Time
`json:"expires_at"`
// Expiration time (nil = never expires)
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
// Rate limit fields
RateLimit5h
float64
`json:"rate_limit_5h"`
RateLimit1d
float64
`json:"rate_limit_1d"`
RateLimit7d
float64
`json:"rate_limit_7d"`
Usage5h
float64
`json:"usage_5h"`
Usage1d
float64
`json:"usage_1d"`
Usage7d
float64
`json:"usage_7d"`
Window5hStart
*
time
.
Time
`json:"window_5h_start"`
Window1dStart
*
time
.
Time
`json:"window_1d_start"`
Window7dStart
*
time
.
Time
`json:"window_7d_start"`
User
*
User
`json:"user,omitempty"`
Group
*
Group
`json:"group,omitempty"`
}
...
...
@@ -67,12 +81,21 @@ type Group struct {
ImagePrice2K
*
float64
`json:"image_price_2k"`
ImagePrice4K
*
float64
`json:"image_price_4k"`
// Sora 按次计费配置
SoraImagePrice360
*
float64
`json:"sora_image_price_360"`
SoraImagePrice540
*
float64
`json:"sora_image_price_540"`
SoraVideoPricePerRequest
*
float64
`json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD
*
float64
`json:"sora_video_price_per_request_hd"`
// Claude Code 客户端限制
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id"`
// 无效请求兜底分组
FallbackGroupIDOnInvalidRequest
*
int64
`json:"fallback_group_id_on_invalid_request"`
// Sora 存储配额
SoraStorageQuotaBytes
int64
`json:"sora_storage_quota_bytes"`
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
}
...
...
@@ -141,6 +164,13 @@ type Account struct {
MaxSessions
*
int
`json:"max_sessions,omitempty"`
SessionIdleTimeoutMin
*
int
`json:"session_idle_timeout_minutes,omitempty"`
// RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
BaseRPM
*
int
`json:"base_rpm,omitempty"`
RPMStrategy
*
string
`json:"rpm_strategy,omitempty"`
RPMStickyBuffer
*
int
`json:"rpm_sticky_buffer,omitempty"`
UserMsgQueueMode
*
string
`json:"user_msg_queue_mode,omitempty"`
// TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
EnableTLSFingerprint
*
bool
`json:"enable_tls_fingerprint,omitempty"`
...
...
@@ -150,6 +180,11 @@ type Account struct {
// 从 extra 字段提取,方便前端显示和编辑
EnableSessionIDMasking
*
bool
`json:"session_id_masking_enabled,omitempty"`
// 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效)
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型计费
CacheTTLOverrideEnabled
*
bool
`json:"cache_ttl_override_enabled,omitempty"`
CacheTTLOverrideTarget
*
string
`json:"cache_ttl_override_target,omitempty"`
Proxy
*
Proxy
`json:"proxy,omitempty"`
AccountGroups
[]
AccountGroup
`json:"account_groups,omitempty"`
...
...
@@ -191,6 +226,37 @@ type ProxyWithAccountCount struct {
CountryCode
string
`json:"country_code,omitempty"`
Region
string
`json:"region,omitempty"`
City
string
`json:"city,omitempty"`
QualityStatus
string
`json:"quality_status,omitempty"`
QualityScore
*
int
`json:"quality_score,omitempty"`
QualityGrade
string
`json:"quality_grade,omitempty"`
QualitySummary
string
`json:"quality_summary,omitempty"`
QualityChecked
*
int64
`json:"quality_checked,omitempty"`
}
// AdminProxy 是管理员接口使用的 proxy DTO(包含密码等敏感字段)。
// 注意:普通接口不得使用此 DTO。
type
AdminProxy
struct
{
Proxy
Password
string
`json:"password,omitempty"`
}
// AdminProxyWithAccountCount 是管理员接口使用的带账号统计的 proxy DTO。
type
AdminProxyWithAccountCount
struct
{
AdminProxy
AccountCount
int64
`json:"account_count"`
LatencyMs
*
int64
`json:"latency_ms,omitempty"`
LatencyStatus
string
`json:"latency_status,omitempty"`
LatencyMessage
string
`json:"latency_message,omitempty"`
IPAddress
string
`json:"ip_address,omitempty"`
Country
string
`json:"country,omitempty"`
CountryCode
string
`json:"country_code,omitempty"`
Region
string
`json:"region,omitempty"`
City
string
`json:"city,omitempty"`
QualityStatus
string
`json:"quality_status,omitempty"`
QualityScore
*
int
`json:"quality_score,omitempty"`
QualityGrade
string
`json:"quality_grade,omitempty"`
QualitySummary
string
`json:"quality_summary,omitempty"`
QualityChecked
*
int64
`json:"quality_checked,omitempty"`
}
type
ProxyAccountSummary
struct
{
...
...
@@ -261,18 +327,24 @@ type UsageLog struct {
ActualCost
float64
`json:"actual_cost"`
RateMultiplier
float64
`json:"rate_multiplier"`
BillingType
int8
`json:"billing_type"`
Stream
bool
`json:"stream"`
DurationMs
*
int
`json:"duration_ms"`
FirstTokenMs
*
int
`json:"first_token_ms"`
BillingType
int8
`json:"billing_type"`
RequestType
string
`json:"request_type"`
Stream
bool
`json:"stream"`
OpenAIWSMode
bool
`json:"openai_ws_mode"`
DurationMs
*
int
`json:"duration_ms"`
FirstTokenMs
*
int
`json:"first_token_ms"`
// 图片生成字段
ImageCount
int
`json:"image_count"`
ImageSize
*
string
`json:"image_size"`
MediaType
*
string
`json:"media_type"`
// User-Agent
UserAgent
*
string
`json:"user_agent"`
// Cache TTL Override 标记
CacheTTLOverridden
bool
`json:"cache_ttl_overridden"`
CreatedAt
time
.
Time
`json:"created_at"`
User
*
User
`json:"user,omitempty"`
...
...
@@ -303,6 +375,7 @@ type UsageCleanupFilters struct {
AccountID
*
int64
`json:"account_id,omitempty"`
GroupID
*
int64
`json:"group_id,omitempty"`
Model
*
string
`json:"model,omitempty"`
RequestType
*
string
`json:"request_type,omitempty"`
Stream
*
bool
`json:"stream,omitempty"`
BillingType
*
int8
`json:"billing_type,omitempty"`
}
...
...
@@ -374,9 +447,12 @@ type AdminUserSubscription struct {
type
BulkAssignResult
struct
{
SuccessCount
int
`json:"success_count"`
CreatedCount
int
`json:"created_count"`
ReusedCount
int
`json:"reused_count"`
FailedCount
int
`json:"failed_count"`
Subscriptions
[]
AdminUserSubscription
`json:"subscriptions"`
Errors
[]
string
`json:"errors"`
Statuses
map
[
string
]
string
`json:"statuses,omitempty"`
}
// PromoCode 注册优惠码
...
...
backend/internal/handler/failover_loop.go
0 → 100644
View file @
3d79773b
package
handler
import
(
"context"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
"go.uber.org/zap"
)
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
// GatewayService 隐式实现此接口。
type
TempUnscheduler
interface
{
TempUnscheduleRetryableError
(
ctx
context
.
Context
,
accountID
int64
,
failoverErr
*
service
.
UpstreamFailoverError
)
}
// FailoverAction 表示 failover 错误处理后的下一步动作
type
FailoverAction
int
const
(
// FailoverContinue 继续循环(同账号重试或切换账号,调用方统一 continue)
FailoverContinue
FailoverAction
=
iota
// FailoverExhausted 切换次数耗尽(调用方应返回错误响应)
FailoverExhausted
// FailoverCanceled context 已取消(调用方应直接 return)
FailoverCanceled
)
const
(
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries
=
2
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay
=
500
*
time
.
Millisecond
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
// Service 层在 SingleAccountRetry 模式下已做充分原地重试(最多 3 次、总等待 30s),
// Handler 层只需短暂间隔后重新进入 Service 层即可。
singleAccountBackoffDelay
=
2
*
time
.
Second
)
// FailoverState 跨循环迭代共享的 failover 状态
type
FailoverState
struct
{
SwitchCount
int
MaxSwitches
int
FailedAccountIDs
map
[
int64
]
struct
{}
SameAccountRetryCount
map
[
int64
]
int
LastFailoverErr
*
service
.
UpstreamFailoverError
ForceCacheBilling
bool
hasBoundSession
bool
}
// NewFailoverState 创建 failover 状态
func
NewFailoverState
(
maxSwitches
int
,
hasBoundSession
bool
)
*
FailoverState
{
return
&
FailoverState
{
MaxSwitches
:
maxSwitches
,
FailedAccountIDs
:
make
(
map
[
int64
]
struct
{}),
SameAccountRetryCount
:
make
(
map
[
int64
]
int
),
hasBoundSession
:
hasBoundSession
,
}
}
// HandleFailoverError 处理 UpstreamFailoverError,返回下一步动作。
// 包含:缓存计费判断、同账号重试、临时封禁、切换计数、Antigravity 延时。
func
(
s
*
FailoverState
)
HandleFailoverError
(
ctx
context
.
Context
,
gatewayService
TempUnscheduler
,
accountID
int64
,
platform
string
,
failoverErr
*
service
.
UpstreamFailoverError
,
)
FailoverAction
{
s
.
LastFailoverErr
=
failoverErr
// 缓存计费判断
if
needForceCacheBilling
(
s
.
hasBoundSession
,
failoverErr
)
{
s
.
ForceCacheBilling
=
true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if
failoverErr
.
RetryableOnSameAccount
&&
s
.
SameAccountRetryCount
[
accountID
]
<
maxSameAccountRetries
{
s
.
SameAccountRetryCount
[
accountID
]
++
logger
.
FromContext
(
ctx
)
.
Warn
(
"gateway.failover_same_account_retry"
,
zap
.
Int64
(
"account_id"
,
accountID
),
zap
.
Int
(
"upstream_status"
,
failoverErr
.
StatusCode
),
zap
.
Int
(
"same_account_retry_count"
,
s
.
SameAccountRetryCount
[
accountID
]),
zap
.
Int
(
"same_account_retry_max"
,
maxSameAccountRetries
),
)
if
!
sleepWithContext
(
ctx
,
sameAccountRetryDelay
)
{
return
FailoverCanceled
}
return
FailoverContinue
}
// 同账号重试用尽,执行临时封禁
if
failoverErr
.
RetryableOnSameAccount
{
gatewayService
.
TempUnscheduleRetryableError
(
ctx
,
accountID
,
failoverErr
)
}
// 加入失败列表
s
.
FailedAccountIDs
[
accountID
]
=
struct
{}{}
// 检查是否耗尽
if
s
.
SwitchCount
>=
s
.
MaxSwitches
{
return
FailoverExhausted
}
// 递增切换计数
s
.
SwitchCount
++
logger
.
FromContext
(
ctx
)
.
Warn
(
"gateway.failover_switch_account"
,
zap
.
Int64
(
"account_id"
,
accountID
),
zap
.
Int
(
"upstream_status"
,
failoverErr
.
StatusCode
),
zap
.
Int
(
"switch_count"
,
s
.
SwitchCount
),
zap
.
Int
(
"max_switches"
,
s
.
MaxSwitches
),
)
// Antigravity 平台换号线性递增延时
if
platform
==
service
.
PlatformAntigravity
{
delay
:=
time
.
Duration
(
s
.
SwitchCount
-
1
)
*
time
.
Second
if
!
sleepWithContext
(
ctx
,
delay
)
{
return
FailoverCanceled
}
}
return
FailoverContinue
}
// HandleSelectionExhausted 处理选号失败(所有候选账号都在排除列表中)时的退避重试决策。
// 针对 Antigravity 单账号分组的 503 (MODEL_CAPACITY_EXHAUSTED) 场景:
// 清除排除列表、等待退避后重新选号。
//
// 返回 FailoverContinue 时,调用方应设置 SingleAccountRetry context 并 continue。
// 返回 FailoverExhausted 时,调用方应返回错误响应。
// 返回 FailoverCanceled 时,调用方应直接 return。
func
(
s
*
FailoverState
)
HandleSelectionExhausted
(
ctx
context
.
Context
)
FailoverAction
{
if
s
.
LastFailoverErr
!=
nil
&&
s
.
LastFailoverErr
.
StatusCode
==
http
.
StatusServiceUnavailable
&&
s
.
SwitchCount
<=
s
.
MaxSwitches
{
logger
.
FromContext
(
ctx
)
.
Warn
(
"gateway.failover_single_account_backoff"
,
zap
.
Duration
(
"backoff_delay"
,
singleAccountBackoffDelay
),
zap
.
Int
(
"switch_count"
,
s
.
SwitchCount
),
zap
.
Int
(
"max_switches"
,
s
.
MaxSwitches
),
)
if
!
sleepWithContext
(
ctx
,
singleAccountBackoffDelay
)
{
return
FailoverCanceled
}
logger
.
FromContext
(
ctx
)
.
Warn
(
"gateway.failover_single_account_retry"
,
zap
.
Int
(
"switch_count"
,
s
.
SwitchCount
),
zap
.
Int
(
"max_switches"
,
s
.
MaxSwitches
),
)
s
.
FailedAccountIDs
=
make
(
map
[
int64
]
struct
{})
return
FailoverContinue
}
return
FailoverExhausted
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费。
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费。
func
needForceCacheBilling
(
hasBoundSession
bool
,
failoverErr
*
service
.
UpstreamFailoverError
)
bool
{
return
hasBoundSession
||
(
failoverErr
!=
nil
&&
failoverErr
.
ForceCacheBilling
)
}
// sleepWithContext 等待指定时长,返回 false 表示 context 已取消。
func
sleepWithContext
(
ctx
context
.
Context
,
d
time
.
Duration
)
bool
{
if
d
<=
0
{
return
true
}
select
{
case
<-
ctx
.
Done
()
:
return
false
case
<-
time
.
After
(
d
)
:
return
true
}
}
backend/internal/handler/failover_loop_test.go
0 → 100644
View file @
3d79773b
package
handler
import
(
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Mock
// ---------------------------------------------------------------------------
// mockTempUnscheduler 记录 TempUnscheduleRetryableError 的调用信息。
type
mockTempUnscheduler
struct
{
calls
[]
tempUnscheduleCall
}
type
tempUnscheduleCall
struct
{
accountID
int64
failoverErr
*
service
.
UpstreamFailoverError
}
func
(
m
*
mockTempUnscheduler
)
TempUnscheduleRetryableError
(
_
context
.
Context
,
accountID
int64
,
failoverErr
*
service
.
UpstreamFailoverError
)
{
m
.
calls
=
append
(
m
.
calls
,
tempUnscheduleCall
{
accountID
:
accountID
,
failoverErr
:
failoverErr
})
}
// ---------------------------------------------------------------------------
// Helper
// ---------------------------------------------------------------------------
func
newTestFailoverErr
(
statusCode
int
,
retryable
,
forceBilling
bool
)
*
service
.
UpstreamFailoverError
{
return
&
service
.
UpstreamFailoverError
{
StatusCode
:
statusCode
,
RetryableOnSameAccount
:
retryable
,
ForceCacheBilling
:
forceBilling
,
}
}
// ---------------------------------------------------------------------------
// NewFailoverState 测试
// ---------------------------------------------------------------------------
func
TestNewFailoverState
(
t
*
testing
.
T
)
{
t
.
Run
(
"初始化字段正确"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
5
,
true
)
require
.
Equal
(
t
,
5
,
fs
.
MaxSwitches
)
require
.
Equal
(
t
,
0
,
fs
.
SwitchCount
)
require
.
NotNil
(
t
,
fs
.
FailedAccountIDs
)
require
.
Empty
(
t
,
fs
.
FailedAccountIDs
)
require
.
NotNil
(
t
,
fs
.
SameAccountRetryCount
)
require
.
Empty
(
t
,
fs
.
SameAccountRetryCount
)
require
.
Nil
(
t
,
fs
.
LastFailoverErr
)
require
.
False
(
t
,
fs
.
ForceCacheBilling
)
require
.
True
(
t
,
fs
.
hasBoundSession
)
})
t
.
Run
(
"无绑定会话"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
3
,
false
)
require
.
Equal
(
t
,
3
,
fs
.
MaxSwitches
)
require
.
False
(
t
,
fs
.
hasBoundSession
)
})
t
.
Run
(
"零最大切换次数"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
0
,
false
)
require
.
Equal
(
t
,
0
,
fs
.
MaxSwitches
)
})
}
// ---------------------------------------------------------------------------
// sleepWithContext 测试
// ---------------------------------------------------------------------------
func
TestSleepWithContext
(
t
*
testing
.
T
)
{
t
.
Run
(
"零时长立即返回true"
,
func
(
t
*
testing
.
T
)
{
start
:=
time
.
Now
()
ok
:=
sleepWithContext
(
context
.
Background
(),
0
)
require
.
True
(
t
,
ok
)
require
.
Less
(
t
,
time
.
Since
(
start
),
50
*
time
.
Millisecond
)
})
t
.
Run
(
"负时长立即返回true"
,
func
(
t
*
testing
.
T
)
{
start
:=
time
.
Now
()
ok
:=
sleepWithContext
(
context
.
Background
(),
-
1
*
time
.
Second
)
require
.
True
(
t
,
ok
)
require
.
Less
(
t
,
time
.
Since
(
start
),
50
*
time
.
Millisecond
)
})
t
.
Run
(
"正常等待后返回true"
,
func
(
t
*
testing
.
T
)
{
start
:=
time
.
Now
()
ok
:=
sleepWithContext
(
context
.
Background
(),
50
*
time
.
Millisecond
)
elapsed
:=
time
.
Since
(
start
)
require
.
True
(
t
,
ok
)
require
.
GreaterOrEqual
(
t
,
elapsed
,
40
*
time
.
Millisecond
)
require
.
Less
(
t
,
elapsed
,
500
*
time
.
Millisecond
)
})
t
.
Run
(
"已取消context立即返回false"
,
func
(
t
*
testing
.
T
)
{
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
start
:=
time
.
Now
()
ok
:=
sleepWithContext
(
ctx
,
5
*
time
.
Second
)
require
.
False
(
t
,
ok
)
require
.
Less
(
t
,
time
.
Since
(
start
),
50
*
time
.
Millisecond
)
})
t
.
Run
(
"等待期间context取消返回false"
,
func
(
t
*
testing
.
T
)
{
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
go
func
()
{
time
.
Sleep
(
30
*
time
.
Millisecond
)
cancel
()
}()
start
:=
time
.
Now
()
ok
:=
sleepWithContext
(
ctx
,
5
*
time
.
Second
)
elapsed
:=
time
.
Since
(
start
)
require
.
False
(
t
,
ok
)
require
.
Less
(
t
,
elapsed
,
500
*
time
.
Millisecond
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 基本切换流程
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_BasicSwitch
(
t
*
testing
.
T
)
{
t
.
Run
(
"非重试错误_非Antigravity_直接切换"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SwitchCount
)
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
))
require
.
Equal
(
t
,
err
,
fs
.
LastFailoverErr
)
require
.
False
(
t
,
fs
.
ForceCacheBilling
)
require
.
Empty
(
t
,
mock
.
calls
,
"不应调用 TempUnschedule"
)
})
t
.
Run
(
"非重试错误_Antigravity_第一次切换无延迟"
,
func
(
t
*
testing
.
T
)
{
// switchCount 从 0→1 时,sleepFailoverDelay(ctx, 1) 的延时 = (1-1)*1s = 0
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
start
:=
time
.
Now
()
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
service
.
PlatformAntigravity
,
err
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SwitchCount
)
require
.
Less
(
t
,
elapsed
,
200
*
time
.
Millisecond
,
"第一次切换延迟应为 0"
)
})
t
.
Run
(
"非重试错误_Antigravity_第二次切换有1秒延迟"
,
func
(
t
*
testing
.
T
)
{
// switchCount 从 1→2 时,sleepFailoverDelay(ctx, 2) 的延时 = (2-1)*1s = 1s
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
fs
.
SwitchCount
=
1
// 模拟已切换一次
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
start
:=
time
.
Now
()
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
service
.
PlatformAntigravity
,
err
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
2
,
fs
.
SwitchCount
)
require
.
GreaterOrEqual
(
t
,
elapsed
,
800
*
time
.
Millisecond
,
"第二次切换延迟应约 1s"
)
require
.
Less
(
t
,
elapsed
,
3
*
time
.
Second
)
})
t
.
Run
(
"连续切换直到耗尽"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
2
,
false
)
// 第一次切换:0→1
err1
:=
newTestFailoverErr
(
500
,
false
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err1
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SwitchCount
)
// 第二次切换:1→2
err2
:=
newTestFailoverErr
(
502
,
false
,
false
)
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
"openai"
,
err2
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
2
,
fs
.
SwitchCount
)
// 第三次已耗尽:SwitchCount(2) >= MaxSwitches(2)
err3
:=
newTestFailoverErr
(
503
,
false
,
false
)
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
300
,
"openai"
,
err3
)
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
require
.
Equal
(
t
,
2
,
fs
.
SwitchCount
,
"耗尽时不应继续递增"
)
// 验证失败账号列表
require
.
Len
(
t
,
fs
.
FailedAccountIDs
,
3
)
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
))
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
200
))
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
300
))
// LastFailoverErr 应为最后一次的错误
require
.
Equal
(
t
,
err3
,
fs
.
LastFailoverErr
)
})
t
.
Run
(
"MaxSwitches为0时首次即耗尽"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
0
,
false
)
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
require
.
Equal
(
t
,
0
,
fs
.
SwitchCount
)
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
))
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 缓存计费 (ForceCacheBilling)
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_CacheBilling
(
t
*
testing
.
T
)
{
t
.
Run
(
"hasBoundSession为true时设置ForceCacheBilling"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
true
)
// hasBoundSession=true
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
)
})
t
.
Run
(
"failoverErr.ForceCacheBilling为true时设置"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
500
,
false
,
true
)
// ForceCacheBilling=true
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
)
})
t
.
Run
(
"两者均为false时不设置"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
False
(
t
,
fs
.
ForceCacheBilling
)
})
t
.
Run
(
"一旦设置不会被后续错误重置"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
// 第一次:ForceCacheBilling=true → 设置
err1
:=
newTestFailoverErr
(
500
,
false
,
true
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err1
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
)
// 第二次:ForceCacheBilling=false → 仍然保持 true
err2
:=
newTestFailoverErr
(
502
,
false
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
"openai"
,
err2
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
,
"ForceCacheBilling 一旦设置不应被重置"
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 同账号重试 (RetryableOnSameAccount)
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_SameAccountRetry
(
t
*
testing
.
T
)
{
t
.
Run
(
"第一次重试返回FailoverContinue"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
400
,
true
,
false
)
start
:=
time
.
Now
()
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
100
])
require
.
Equal
(
t
,
0
,
fs
.
SwitchCount
,
"同账号重试不应增加切换计数"
)
require
.
NotContains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
),
"同账号重试不应加入失败列表"
)
require
.
Empty
(
t
,
mock
.
calls
,
"同账号重试期间不应调用 TempUnschedule"
)
// 验证等待了 sameAccountRetryDelay (500ms)
require
.
GreaterOrEqual
(
t
,
elapsed
,
400
*
time
.
Millisecond
)
require
.
Less
(
t
,
elapsed
,
2
*
time
.
Second
)
})
t
.
Run
(
"第二次重试仍返回FailoverContinue"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
400
,
true
,
false
)
// 第一次
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
100
])
// 第二次
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
2
,
fs
.
SameAccountRetryCount
[
100
])
require
.
Empty
(
t
,
mock
.
calls
,
"两次重试期间均不应调用 TempUnschedule"
)
})
t
.
Run
(
"第三次重试耗尽_触发TempUnschedule并切换"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
400
,
true
,
false
)
// 第一次、第二次重试
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
2
,
fs
.
SameAccountRetryCount
[
100
])
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SwitchCount
)
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
))
// 验证 TempUnschedule 被调用
require
.
Len
(
t
,
mock
.
calls
,
1
)
require
.
Equal
(
t
,
int64
(
100
),
mock
.
calls
[
0
]
.
accountID
)
require
.
Equal
(
t
,
err
,
mock
.
calls
[
0
]
.
failoverErr
)
})
t
.
Run
(
"不同账号独立跟踪重试次数"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
5
,
false
)
err
:=
newTestFailoverErr
(
400
,
true
,
false
)
// 账号 100 第一次重试
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
100
])
// 账号 200 第一次重试(独立计数)
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
200
])
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
100
],
"账号 100 的计数不应受影响"
)
})
t
.
Run
(
"重试耗尽后再次遇到同账号_直接切换"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
5
,
false
)
err
:=
newTestFailoverErr
(
400
,
true
,
false
)
// 耗尽账号 100 的重试
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
// 第三次: 重试耗尽 → 切换
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
// 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Len
(
t
,
mock
.
calls
,
2
,
"第二次耗尽也应调用 TempUnschedule"
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — TempUnschedule 调用验证
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_TempUnschedule
(
t
*
testing
.
T
)
{
t
.
Run
(
"非重试错误不调用TempUnschedule"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
// RetryableOnSameAccount=false
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Empty
(
t
,
mock
.
calls
)
})
t
.
Run
(
"重试错误耗尽后调用TempUnschedule_传入正确参数"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
502
,
true
,
false
)
// 耗尽重试
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
42
,
"openai"
,
err
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
42
,
"openai"
,
err
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
42
,
"openai"
,
err
)
require
.
Len
(
t
,
mock
.
calls
,
1
)
require
.
Equal
(
t
,
int64
(
42
),
mock
.
calls
[
0
]
.
accountID
)
require
.
Equal
(
t
,
502
,
mock
.
calls
[
0
]
.
failoverErr
.
StatusCode
)
require
.
True
(
t
,
mock
.
calls
[
0
]
.
failoverErr
.
RetryableOnSameAccount
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — Context 取消
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_ContextCanceled
(
t
*
testing
.
T
)
{
t
.
Run
(
"同账号重试sleep期间context取消"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
400
,
true
,
false
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
// 立即取消
start
:=
time
.
Now
()
action
:=
fs
.
HandleFailoverError
(
ctx
,
mock
,
100
,
"openai"
,
err
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverCanceled
,
action
)
require
.
Less
(
t
,
elapsed
,
100
*
time
.
Millisecond
,
"应立即返回"
)
// 重试计数仍应递增
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
100
])
})
t
.
Run
(
"Antigravity延迟期间context取消"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
fs
.
SwitchCount
=
1
// 下一次 switchCount=2 → delay = 1s
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
// 立即取消
start
:=
time
.
Now
()
action
:=
fs
.
HandleFailoverError
(
ctx
,
mock
,
100
,
service
.
PlatformAntigravity
,
err
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverCanceled
,
action
)
require
.
Less
(
t
,
elapsed
,
100
*
time
.
Millisecond
,
"应立即返回而非等待 1s"
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — FailedAccountIDs 跟踪
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_FailedAccountIDs
(
t
*
testing
.
T
)
{
t
.
Run
(
"切换时添加到失败列表"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
newTestFailoverErr
(
500
,
false
,
false
))
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
))
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
"openai"
,
newTestFailoverErr
(
502
,
false
,
false
))
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
200
))
require
.
Len
(
t
,
fs
.
FailedAccountIDs
,
2
)
})
t
.
Run
(
"耗尽时也添加到失败列表"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
0
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
newTestFailoverErr
(
500
,
false
,
false
))
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
require
.
Contains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
))
})
t
.
Run
(
"同账号重试期间不添加到失败列表"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
newTestFailoverErr
(
400
,
true
,
false
))
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
NotContains
(
t
,
fs
.
FailedAccountIDs
,
int64
(
100
))
})
t
.
Run
(
"同一账号多次切换不重复添加"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
5
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
newTestFailoverErr
(
500
,
false
,
false
))
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
newTestFailoverErr
(
500
,
false
,
false
))
require
.
Len
(
t
,
fs
.
FailedAccountIDs
,
1
,
"map 天然去重"
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — LastFailoverErr 更新
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_LastFailoverErr
(
t
*
testing
.
T
)
{
t
.
Run
(
"每次调用都更新LastFailoverErr"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err1
:=
newTestFailoverErr
(
500
,
false
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err1
)
require
.
Equal
(
t
,
err1
,
fs
.
LastFailoverErr
)
err2
:=
newTestFailoverErr
(
502
,
false
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
"openai"
,
err2
)
require
.
Equal
(
t
,
err2
,
fs
.
LastFailoverErr
)
})
t
.
Run
(
"同账号重试时也更新LastFailoverErr"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
400
,
true
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
err
,
fs
.
LastFailoverErr
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 综合集成场景
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_IntegrationScenario
(
t
*
testing
.
T
)
{
t
.
Run
(
"模拟完整failover流程_多账号混合重试与切换"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
true
)
// hasBoundSession=true
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
retryErr
:=
newTestFailoverErr
(
400
,
true
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
retryErr
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
,
"hasBoundSession=true 应设置 ForceCacheBilling"
)
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
retryErr
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
retryErr
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SwitchCount
)
require
.
Len
(
t
,
mock
.
calls
,
1
)
// 3. 账号 200 遇到不可重试错误 → 直接切换
switchErr
:=
newTestFailoverErr
(
500
,
false
,
false
)
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
"openai"
,
switchErr
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
2
,
fs
.
SwitchCount
)
// 4. 账号 300 遇到不可重试错误 → 再切换
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
300
,
"openai"
,
switchErr
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
3
,
fs
.
SwitchCount
)
// 5. 账号 400 → 已耗尽 (SwitchCount=3 >= MaxSwitches=3)
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
400
,
"openai"
,
switchErr
)
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
// 最终状态验证
require
.
Equal
(
t
,
3
,
fs
.
SwitchCount
,
"耗尽时不再递增"
)
require
.
Len
(
t
,
fs
.
FailedAccountIDs
,
4
,
"4个不同账号都在失败列表中"
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
)
require
.
Len
(
t
,
mock
.
calls
,
1
,
"只有账号 100 触发了 TempUnschedule"
)
})
t
.
Run
(
"模拟Antigravity平台完整流程"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
2
,
false
)
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
// 第一次切换:delay = 0s
start
:=
time
.
Now
()
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
service
.
PlatformAntigravity
,
err
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Less
(
t
,
elapsed
,
200
*
time
.
Millisecond
,
"第一次切换延迟为 0"
)
// 第二次切换:delay = 1s
start
=
time
.
Now
()
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
service
.
PlatformAntigravity
,
err
)
elapsed
=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
GreaterOrEqual
(
t
,
elapsed
,
800
*
time
.
Millisecond
,
"第二次切换延迟约 1s"
)
// 第三次:耗尽(无延迟,因为在检查延迟之前就返回了)
start
=
time
.
Now
()
action
=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
300
,
service
.
PlatformAntigravity
,
err
)
elapsed
=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
require
.
Less
(
t
,
elapsed
,
200
*
time
.
Millisecond
,
"耗尽时不应有延迟"
)
})
t
.
Run
(
"ForceCacheBilling通过错误标志设置"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
// hasBoundSession=false
// 第一次:ForceCacheBilling=false
err1
:=
newTestFailoverErr
(
500
,
false
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err1
)
require
.
False
(
t
,
fs
.
ForceCacheBilling
)
// 第二次:ForceCacheBilling=true(Antigravity 粘性会话切换)
err2
:=
newTestFailoverErr
(
500
,
false
,
true
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
200
,
"openai"
,
err2
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
,
"错误标志应触发 ForceCacheBilling"
)
// 第三次:ForceCacheBilling=false,但状态仍保持 true
err3
:=
newTestFailoverErr
(
500
,
false
,
false
)
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
300
,
"openai"
,
err3
)
require
.
True
(
t
,
fs
.
ForceCacheBilling
,
"不应重置"
)
})
}
// ---------------------------------------------------------------------------
// HandleFailoverError — 边界条件
// ---------------------------------------------------------------------------
func
TestHandleFailoverError_EdgeCases
(
t
*
testing
.
T
)
{
t
.
Run
(
"StatusCode为0的错误也能正常处理"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
0
,
false
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
})
t
.
Run
(
"AccountID为0也能正常跟踪"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
500
,
true
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
0
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
0
])
})
t
.
Run
(
"负AccountID也能正常跟踪"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
err
:=
newTestFailoverErr
(
500
,
true
,
false
)
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
-
1
,
"openai"
,
err
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Equal
(
t
,
1
,
fs
.
SameAccountRetryCount
[
-
1
])
})
t
.
Run
(
"空平台名称不触发Antigravity延迟"
,
func
(
t
*
testing
.
T
)
{
mock
:=
&
mockTempUnscheduler
{}
fs
:=
NewFailoverState
(
3
,
false
)
fs
.
SwitchCount
=
1
err
:=
newTestFailoverErr
(
500
,
false
,
false
)
start
:=
time
.
Now
()
action
:=
fs
.
HandleFailoverError
(
context
.
Background
(),
mock
,
100
,
""
,
err
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Less
(
t
,
elapsed
,
200
*
time
.
Millisecond
,
"空平台不应触发 Antigravity 延迟"
)
})
}
// ---------------------------------------------------------------------------
// HandleSelectionExhausted 测试
// ---------------------------------------------------------------------------
func
TestHandleSelectionExhausted
(
t
*
testing
.
T
)
{
t
.
Run
(
"无LastFailoverErr时返回Exhausted"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
3
,
false
)
// LastFailoverErr 为 nil
action
:=
fs
.
HandleSelectionExhausted
(
context
.
Background
())
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
})
t
.
Run
(
"非503错误返回Exhausted"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
3
,
false
)
fs
.
LastFailoverErr
=
newTestFailoverErr
(
500
,
false
,
false
)
action
:=
fs
.
HandleSelectionExhausted
(
context
.
Background
())
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
})
t
.
Run
(
"503且未耗尽_等待后返回Continue并清除失败列表"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
3
,
false
)
fs
.
LastFailoverErr
=
newTestFailoverErr
(
503
,
false
,
false
)
fs
.
FailedAccountIDs
[
100
]
=
struct
{}{}
fs
.
SwitchCount
=
1
start
:=
time
.
Now
()
action
:=
fs
.
HandleSelectionExhausted
(
context
.
Background
())
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverContinue
,
action
)
require
.
Empty
(
t
,
fs
.
FailedAccountIDs
,
"应清除失败账号列表"
)
require
.
GreaterOrEqual
(
t
,
elapsed
,
1500
*
time
.
Millisecond
,
"应等待约 2s"
)
require
.
Less
(
t
,
elapsed
,
5
*
time
.
Second
)
})
t
.
Run
(
"503但SwitchCount已超过MaxSwitches_返回Exhausted"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
2
,
false
)
fs
.
LastFailoverErr
=
newTestFailoverErr
(
503
,
false
,
false
)
fs
.
SwitchCount
=
3
// > MaxSwitches(2)
start
:=
time
.
Now
()
action
:=
fs
.
HandleSelectionExhausted
(
context
.
Background
())
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverExhausted
,
action
)
require
.
Less
(
t
,
elapsed
,
100
*
time
.
Millisecond
,
"不应等待"
)
})
t
.
Run
(
"503但context已取消_返回Canceled"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
3
,
false
)
fs
.
LastFailoverErr
=
newTestFailoverErr
(
503
,
false
,
false
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
start
:=
time
.
Now
()
action
:=
fs
.
HandleSelectionExhausted
(
ctx
)
elapsed
:=
time
.
Since
(
start
)
require
.
Equal
(
t
,
FailoverCanceled
,
action
)
require
.
Less
(
t
,
elapsed
,
100
*
time
.
Millisecond
,
"应立即返回"
)
})
t
.
Run
(
"503且SwitchCount等于MaxSwitches_仍可重试"
,
func
(
t
*
testing
.
T
)
{
fs
:=
NewFailoverState
(
2
,
false
)
fs
.
LastFailoverErr
=
newTestFailoverErr
(
503
,
false
,
false
)
fs
.
SwitchCount
=
2
// == MaxSwitches,条件是 <=,仍可重试
action
:=
fs
.
HandleSelectionExhausted
(
context
.
Background
())
require
.
Equal
(
t
,
FailoverContinue
,
action
)
})
}
backend/internal/handler/gateway_handler.go
View file @
3d79773b
...
...
@@ -6,10 +6,10 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
...
...
@@ -18,14 +18,22 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
pkghttputil
"github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
const
gatewayCompatibilityMetricsLogInterval
=
1024
var
gatewayCompatibilityMetricsLogCounter
atomic
.
Uint64
// GatewayHandler handles API gateway requests
type
GatewayHandler
struct
{
gatewayService
*
service
.
GatewayService
...
...
@@ -35,10 +43,14 @@ type GatewayHandler struct {
billingCacheService
*
service
.
BillingCacheService
usageService
*
service
.
UsageService
apiKeyService
*
service
.
APIKeyService
usageRecordWorkerPool
*
service
.
UsageRecordWorkerPool
errorPassthroughService
*
service
.
ErrorPassthroughService
concurrencyHelper
*
ConcurrencyHelper
userMsgQueueHelper
*
UserMsgQueueHelper
maxAccountSwitches
int
maxAccountSwitchesGemini
int
cfg
*
config
.
Config
settingService
*
service
.
SettingService
}
// NewGatewayHandler creates a new GatewayHandler
...
...
@@ -51,8 +63,11 @@ func NewGatewayHandler(
billingCacheService
*
service
.
BillingCacheService
,
usageService
*
service
.
UsageService
,
apiKeyService
*
service
.
APIKeyService
,
usageRecordWorkerPool
*
service
.
UsageRecordWorkerPool
,
errorPassthroughService
*
service
.
ErrorPassthroughService
,
userMsgQueueService
*
service
.
UserMessageQueueService
,
cfg
*
config
.
Config
,
settingService
*
service
.
SettingService
,
)
*
GatewayHandler
{
pingInterval
:=
time
.
Duration
(
0
)
maxAccountSwitches
:=
10
...
...
@@ -66,6 +81,13 @@ func NewGatewayHandler(
maxAccountSwitchesGemini
=
cfg
.
Gateway
.
MaxAccountSwitchesGemini
}
}
// 初始化用户消息串行队列 helper
var
umqHelper
*
UserMsgQueueHelper
if
userMsgQueueService
!=
nil
&&
cfg
!=
nil
{
umqHelper
=
NewUserMsgQueueHelper
(
userMsgQueueService
,
SSEPingFormatClaude
,
pingInterval
)
}
return
&
GatewayHandler
{
gatewayService
:
gatewayService
,
geminiCompatService
:
geminiCompatService
,
...
...
@@ -74,10 +96,14 @@ func NewGatewayHandler(
billingCacheService
:
billingCacheService
,
usageService
:
usageService
,
apiKeyService
:
apiKeyService
,
usageRecordWorkerPool
:
usageRecordWorkerPool
,
errorPassthroughService
:
errorPassthroughService
,
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatClaude
,
pingInterval
),
userMsgQueueHelper
:
umqHelper
,
maxAccountSwitches
:
maxAccountSwitches
,
maxAccountSwitchesGemini
:
maxAccountSwitchesGemini
,
cfg
:
cfg
,
settingService
:
settingService
,
}
}
...
...
@@ -96,9 +122,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
errorResponse
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"User context not found"
)
return
}
reqLog
:=
requestLogger
(
c
,
"handler.gateway.messages"
,
zap
.
Int64
(
"user_id"
,
subject
.
UserID
),
zap
.
Int64
(
"api_key_id"
,
apiKey
.
ID
),
zap
.
Any
(
"group_id"
,
apiKey
.
GroupID
),
)
defer
h
.
maybeLogCompatibilityFallbackMetrics
(
reqLog
)
// 读取请求体
body
,
err
:=
io
.
ReadAll
(
c
.
Request
.
Body
)
body
,
err
:=
pkghttputil
.
ReadRequestBodyWithPrealloc
(
c
.
Request
)
if
err
!=
nil
{
if
maxErr
,
ok
:=
extractMaxBytesError
(
err
);
ok
{
h
.
errorResponse
(
c
,
http
.
StatusRequestEntityTooLarge
,
"invalid_request_error"
,
buildBodyTooLargeMessage
(
maxErr
.
Limit
))
...
...
@@ -122,20 +156,26 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
reqModel
:=
parsedReq
.
Model
reqStream
:=
parsedReq
.
Stream
reqLog
=
reqLog
.
With
(
zap
.
String
(
"model"
,
reqModel
),
zap
.
Bool
(
"stream"
,
reqStream
))
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
if
isMaxTokensOneHaikuRequest
(
reqModel
,
parsedReq
.
MaxTokens
,
reqStream
)
{
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
IsMaxTokensOneHaikuRequest
,
true
)
ctx
:=
service
.
WithIsMaxTokensOneHaikuRequest
(
c
.
Request
.
Context
(),
true
,
h
.
metadataBridgeEnabled
()
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext
(
c
,
body
)
// 检查是否为 Claude Code 客户端,设置到 context 中
(复用已解析请求,避免二次反序列化)。
SetClaudeCodeClientContext
(
c
,
body
,
parsedReq
)
isClaudeCodeClient
:=
service
.
IsClaudeCodeClient
(
c
.
Request
.
Context
())
// 版本检查:仅对 Claude Code 客户端,拒绝低于最低版本的请求
if
!
h
.
checkClaudeCodeVersion
(
c
)
{
return
}
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c
.
Request
=
c
.
Request
.
WithContext
(
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
ThinkingEnabled
,
parsedReq
.
Thinking
Enabled
))
c
.
Request
=
c
.
Request
.
WithContext
(
service
.
WithThinkingEnabled
(
c
.
Request
.
Context
(),
parsedReq
.
ThinkingEnabled
,
h
.
metadataBridge
Enabled
()
))
setOpsRequestContext
(
c
,
reqModel
,
reqStream
,
body
)
...
...
@@ -161,9 +201,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementWaitCount
(
c
.
Request
.
Context
(),
subject
.
UserID
,
maxWait
)
waitCounted
:=
false
if
err
!=
nil
{
log
.
Printf
(
"Increment
wait
count
failed: %v"
,
err
)
reqLog
.
Warn
(
"gateway.user_
wait
_
count
er_increment_failed"
,
zap
.
Error
(
err
)
)
// On error, allow request to proceed
}
else
if
!
canWait
{
reqLog
.
Info
(
"gateway.user_wait_queue_full"
,
zap
.
Int
(
"max_wait"
,
maxWait
))
h
.
errorResponse
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
)
return
}
...
...
@@ -180,7 +221,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 1. 首先获取用户并发槽位
userReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireUserSlotWithWait
(
c
,
subject
.
UserID
,
subject
.
Concurrency
,
reqStream
,
&
streamStarted
)
if
err
!=
nil
{
log
.
Printf
(
"User concurrency
acquire
failed
: %v"
,
err
)
reqLog
.
Warn
(
"gateway.user_slot_
acquire
_
failed
"
,
zap
.
Error
(
err
)
)
h
.
handleConcurrencyError
(
c
,
err
,
"user"
,
streamStarted
)
return
}
...
...
@@ -197,7 +238,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if
err
:=
h
.
billingCacheService
.
CheckBillingEligibility
(
c
.
Request
.
Context
(),
apiKey
.
User
,
apiKey
,
apiKey
.
Group
,
subscription
);
err
!=
nil
{
log
.
Printf
(
"B
illing
eligibility
check
failed
after wait: %v"
,
err
)
reqLog
.
Info
(
"gateway.b
illing
_
eligibility
_
check
_
failed
"
,
zap
.
Error
(
err
)
)
status
,
code
,
message
:=
billingErrorDetails
(
err
)
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
...
...
@@ -227,54 +268,54 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var
sessionBoundAccountID
int64
if
sessionKey
!=
""
{
sessionBoundAccountID
,
_
=
h
.
gatewayService
.
GetCachedSessionAccountID
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
)
if
sessionBoundAccountID
>
0
{
prefetchedGroupID
:=
int64
(
0
)
if
apiKey
.
GroupID
!=
nil
{
prefetchedGroupID
=
*
apiKey
.
GroupID
}
ctx
:=
service
.
WithPrefetchedStickySession
(
c
.
Request
.
Context
(),
sessionBoundAccountID
,
prefetchedGroupID
,
h
.
metadataBridgeEnabled
())
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
}
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession
:=
sessionKey
!=
""
&&
sessionBoundAccountID
>
0
if
platform
==
service
.
PlatformGemini
{
maxAccountSwitches
:=
h
.
maxAccountSwitchesGemini
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
sameAccountRetryCount
:=
make
(
map
[
int64
]
int
)
// 同账号重试计数
var
lastFailoverErr
*
service
.
UpstreamFailoverError
var
forceCacheBilling
bool
// 粘性会话切换时的缓存计费标记
fs
:=
NewFailoverState
(
h
.
maxAccountSwitchesGemini
,
hasBoundSession
)
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if
h
.
gatewayService
.
IsSingleAntigravityAccountGroup
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
)
{
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
SingleAccountRetry
,
true
)
ctx
:=
service
.
WithSingleAccountRetry
(
c
.
Request
.
Context
(),
true
,
h
.
metadataBridgeEnabled
()
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
""
)
// Gemini 不使用会话限制
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
f
s
.
F
ailedAccountIDs
,
""
)
// Gemini 不使用会话限制
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
if
len
(
f
s
.
F
ailedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if
lastFailoverErr
!=
nil
&&
lastFailoverErr
.
StatusCode
==
http
.
StatusServiceUnavailable
&&
switchCount
<=
maxAccountSwitches
{
if
sleepAntigravitySingleAccountBackoff
(
c
.
Request
.
Context
(),
switchCount
)
{
log
.
Printf
(
"Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d"
,
switchCount
,
maxAccountSwitches
)
failedAccountIDs
=
make
(
map
[
int64
]
struct
{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
SingleAccountRetry
,
true
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
continue
action
:=
fs
.
HandleSelectionExhausted
(
c
.
Request
.
Context
())
switch
action
{
case
FailoverContinue
:
ctx
:=
service
.
WithSingleAccountRetry
(
c
.
Request
.
Context
(),
true
,
h
.
metadataBridgeEnabled
())
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
continue
case
FailoverCanceled
:
return
default
:
// FailoverExhausted
if
fs
.
LastFailoverErr
!=
nil
{
h
.
handleFailoverExhausted
(
c
,
fs
.
LastFailoverErr
,
service
.
PlatformGemini
,
streamStarted
)
}
else
{
h
.
handleFailoverExhaustedSimple
(
c
,
502
,
streamStarted
)
}
return
}
if
lastFailoverErr
!=
nil
{
h
.
handleFailoverExhausted
(
c
,
lastFailoverErr
,
service
.
PlatformGemini
,
streamStarted
)
}
else
{
h
.
handleFailoverExhaustedSimple
(
c
,
502
,
streamStarted
)
}
return
}
account
:=
selection
.
Account
setOpsSelectedAccount
(
c
,
account
.
ID
)
setOpsSelectedAccount
(
c
,
account
.
ID
,
account
.
Platform
)
// 检查请求拦截(预热请求、SUGGESTION MODE等)
if
account
.
IsInterceptWarmupEnabled
()
{
...
...
@@ -302,21 +343,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountWaitCounted
:=
false
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment
account
wait
count
failed: %v"
,
err
)
reqLog
.
Warn
(
"gateway.
account
_
wait
_
count
er_increment_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
)
)
}
else
if
!
canWait
{
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
reqLog
.
Info
(
"gateway.account_wait_queue_full"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int
(
"max_waiting"
,
selection
.
WaitPlan
.
MaxWaiting
),
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
if
err
==
nil
&&
canWait
{
accountWaitCounted
=
true
}
// Ensure the wait counter is decremented if we exit before acquiring the slot.
defer
func
()
{
releaseWait
:=
func
()
{
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
}
()
}
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
...
...
@@ -327,17 +371,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
&
streamStarted
,
)
if
err
!=
nil
{
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
reqLog
.
Warn
(
"gateway.account_slot_acquire_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
))
releaseWait
()
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
// Slot acquired: no longer waiting in queue.
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
releaseWait
()
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"B
ind
sticky
session
failed
: %v"
,
err
)
reqLog
.
Warn
(
"gateway.b
ind
_
sticky
_
session
_
failed
"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
)
)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
...
...
@@ -346,8 +388,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
requestCtx
:=
c
.
Request
.
Context
()
if
s
witchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
switchCount
)
if
fs
.
S
witchCount
>
0
{
requestCtx
=
service
.
WithAccountSwitchCount
(
requestCtx
,
fs
.
SwitchCount
,
h
.
metadataBridgeEnabled
()
)
}
if
account
.
Platform
==
service
.
PlatformAntigravity
{
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
requestCtx
,
c
,
account
,
reqModel
,
"generateContent"
,
reqStream
,
body
,
hasBoundSession
)
...
...
@@ -360,68 +402,62 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
lastFailoverErr
=
failoverErr
if
needForceCacheBilling
(
hasBoundSession
,
failoverErr
)
{
forceCacheBilling
=
true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if
failoverErr
.
RetryableOnSameAccount
&&
sameAccountRetryCount
[
account
.
ID
]
<
maxSameAccountRetries
{
sameAccountRetryCount
[
account
.
ID
]
++
log
.
Printf
(
"Account %d: retryable error %d, same-account retry %d/%d"
,
account
.
ID
,
failoverErr
.
StatusCode
,
sameAccountRetryCount
[
account
.
ID
],
maxSameAccountRetries
)
if
!
sleepSameAccountRetryDelay
(
c
.
Request
.
Context
())
{
return
}
action
:=
fs
.
HandleFailoverError
(
c
.
Request
.
Context
(),
h
.
gatewayService
,
account
.
ID
,
account
.
Platform
,
failoverErr
)
switch
action
{
case
FailoverContinue
:
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if
failoverErr
.
RetryableOnSameAccount
{
h
.
gatewayService
.
TempUnscheduleRetryableError
(
c
.
Request
.
Context
(),
account
.
ID
,
failoverErr
)
}
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
if
switchCount
>=
maxAccountSwitches
{
h
.
handleFailoverExhausted
(
c
,
failoverErr
,
service
.
PlatformGemini
,
streamStarted
)
case
FailoverExhausted
:
h
.
handleFailoverExhausted
(
c
,
fs
.
LastFailoverErr
,
service
.
PlatformGemini
,
streamStarted
)
return
case
FailoverCanceled
:
return
}
switchCount
++
log
.
Printf
(
"Account %d: upstream error %d, switching account %d/%d"
,
account
.
ID
,
failoverErr
.
StatusCode
,
switchCount
,
maxAccountSwitches
)
if
account
.
Platform
==
service
.
PlatformAntigravity
{
if
!
sleepFailoverDelay
(
c
.
Request
.
Context
(),
switchCount
)
{
return
}
}
continue
}
// 错误响应已在Forward中处理,这里只记录日志
log
.
Printf
(
"Forward request failed: %v"
,
err
)
wroteFallback
:=
h
.
ensureForwardErrorResponse
(
c
,
streamStarted
)
reqLog
.
Error
(
"gateway.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Bool
(
"fallback_error_response_written"
,
wroteFallback
),
zap
.
Error
(
err
),
)
return
}
// RPM 计数递增(Forward 成功后)
// 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
if
account
.
IsAnthropicOAuthOrSetupToken
()
&&
account
.
GetBaseRPM
()
>
0
{
if
err
:=
h
.
gatewayService
.
IncrementAccountRPM
(
c
.
Request
.
Context
(),
account
.
ID
);
err
!=
nil
{
reqLog
.
Warn
(
"gateway.rpm_increment_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
))
}
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent
:=
c
.
GetHeader
(
"User-Agent"
)
clientIP
:=
ip
.
GetClientIP
(
c
)
// 异步记录使用量(subscription已在函数开头获取)
go
func
(
result
*
service
.
ForwardResult
,
usedAccount
*
service
.
Account
,
ua
,
clientIP
string
,
fcb
bool
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
Result
:
result
,
APIKey
:
apiKey
,
User
:
apiKey
.
User
,
Account
:
usedA
ccount
,
Account
:
a
ccount
,
Subscription
:
subscription
,
UserAgent
:
u
a
,
UserAgent
:
u
serAgent
,
IPAddress
:
clientIP
,
ForceCacheBilling
:
f
cb
,
ForceCacheBilling
:
f
s
.
ForceCacheBilling
,
APIKeyService
:
h
.
apiKeyService
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.gateway.messages"
),
zap
.
Int64
(
"user_id"
,
subject
.
UserID
),
zap
.
Int64
(
"api_key_id"
,
apiKey
.
ID
),
zap
.
Any
(
"group_id"
,
apiKey
.
GroupID
),
zap
.
String
(
"model"
,
reqModel
),
zap
.
Int64
(
"account_id"
,
account
.
ID
),
)
.
Error
(
"gateway.record_usage_failed"
,
zap
.
Error
(
err
))
}
}
(
result
,
account
,
userAgent
,
clientIP
,
forceCacheBilling
)
})
return
}
}
...
...
@@ -437,49 +473,41 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if
h
.
gatewayService
.
IsSingleAntigravityAccountGroup
(
c
.
Request
.
Context
(),
currentAPIKey
.
GroupID
)
{
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
SingleAccountRetry
,
true
)
ctx
:=
service
.
WithSingleAccountRetry
(
c
.
Request
.
Context
(),
true
,
h
.
metadataBridgeEnabled
()
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
for
{
maxAccountSwitches
:=
h
.
maxAccountSwitches
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
sameAccountRetryCount
:=
make
(
map
[
int64
]
int
)
// 同账号重试计数
var
lastFailoverErr
*
service
.
UpstreamFailoverError
fs
:=
NewFailoverState
(
h
.
maxAccountSwitches
,
hasBoundSession
)
retryWithFallback
:=
false
var
forceCacheBilling
bool
// 粘性会话切换时的缓存计费标记
for
{
// 选择支持该模型的账号
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
currentAPIKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
parsedReq
.
MetadataUserID
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
currentAPIKey
.
GroupID
,
sessionKey
,
reqModel
,
f
s
.
F
ailedAccountIDs
,
parsedReq
.
MetadataUserID
)
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
if
len
(
f
s
.
F
ailedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if
lastFailoverErr
!=
nil
&&
lastFailoverErr
.
StatusCode
==
http
.
StatusServiceUnavailable
&&
switchCount
<=
maxAccountSwitches
{
if
sleepAntigravitySingleAccountBackoff
(
c
.
Request
.
Context
(),
switchCount
)
{
log
.
Printf
(
"Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d"
,
switchCount
,
maxAccountSwitches
)
failedAccountIDs
=
make
(
map
[
int64
]
struct
{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
SingleAccountRetry
,
true
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
continue
action
:=
fs
.
HandleSelectionExhausted
(
c
.
Request
.
Context
())
switch
action
{
case
FailoverContinue
:
ctx
:=
service
.
WithSingleAccountRetry
(
c
.
Request
.
Context
(),
true
,
h
.
metadataBridgeEnabled
())
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
continue
case
FailoverCanceled
:
return
default
:
// FailoverExhausted
if
fs
.
LastFailoverErr
!=
nil
{
h
.
handleFailoverExhausted
(
c
,
fs
.
LastFailoverErr
,
platform
,
streamStarted
)
}
else
{
h
.
handleFailoverExhaustedSimple
(
c
,
502
,
streamStarted
)
}
return
}
if
lastFailoverErr
!=
nil
{
h
.
handleFailoverExhausted
(
c
,
lastFailoverErr
,
platform
,
streamStarted
)
}
else
{
h
.
handleFailoverExhaustedSimple
(
c
,
502
,
streamStarted
)
}
return
}
account
:=
selection
.
Account
setOpsSelectedAccount
(
c
,
account
.
ID
)
setOpsSelectedAccount
(
c
,
account
.
ID
,
account
.
Platform
)
// 检查请求拦截(预热请求、SUGGESTION MODE等)
if
account
.
IsInterceptWarmupEnabled
()
{
...
...
@@ -507,20 +535,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountWaitCounted
:=
false
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment
account
wait
count
failed: %v"
,
err
)
reqLog
.
Warn
(
"gateway.
account
_
wait
_
count
er_increment_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
)
)
}
else
if
!
canWait
{
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
reqLog
.
Info
(
"gateway.account_wait_queue_full"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Int
(
"max_waiting"
,
selection
.
WaitPlan
.
MaxWaiting
),
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
if
err
==
nil
&&
canWait
{
accountWaitCounted
=
true
}
defer
func
()
{
releaseWait
:=
func
()
{
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
}
()
}
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
...
...
@@ -531,50 +563,117 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
&
streamStarted
,
)
if
err
!=
nil
{
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
reqLog
.
Warn
(
"gateway.account_slot_acquire_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
))
releaseWait
()
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
// Slot acquired: no longer waiting in queue.
releaseWait
()
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
currentAPIKey
.
GroupID
,
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"B
ind
sticky
session
failed
: %v"
,
err
)
reqLog
.
Warn
(
"gateway.b
ind
_
sticky
_
session
_
failed
"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
)
)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountReleaseFunc
)
// ===== 用户消息串行队列 START =====
var
queueRelease
func
()
umqMode
:=
h
.
getUserMsgQueueMode
(
account
,
parsedReq
)
switch
umqMode
{
case
config
.
UMQModeSerialize
:
// 串行模式:获取锁 + RPM 延迟 + 释放(当前行为不变)
baseRPM
:=
account
.
GetBaseRPM
()
release
,
qErr
:=
h
.
userMsgQueueHelper
.
AcquireWithWait
(
c
,
account
.
ID
,
baseRPM
,
reqStream
,
&
streamStarted
,
h
.
cfg
.
Gateway
.
UserMessageQueue
.
WaitTimeout
(),
reqLog
,
)
if
qErr
!=
nil
{
// fail-open: 记录 warn,不阻止请求
reqLog
.
Warn
(
"gateway.umq_acquire_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
qErr
),
)
}
else
{
queueRelease
=
release
}
case
config
.
UMQModeThrottle
:
// 软性限速:仅施加 RPM 自适应延迟,不阻塞并发
baseRPM
:=
account
.
GetBaseRPM
()
if
tErr
:=
h
.
userMsgQueueHelper
.
ThrottleWithPing
(
c
,
account
.
ID
,
baseRPM
,
reqStream
,
&
streamStarted
,
h
.
cfg
.
Gateway
.
UserMessageQueue
.
WaitTimeout
(),
reqLog
,
);
tErr
!=
nil
{
reqLog
.
Warn
(
"gateway.umq_throttle_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
tErr
),
)
}
default
:
if
umqMode
!=
""
{
reqLog
.
Warn
(
"gateway.umq_unknown_mode"
,
zap
.
String
(
"mode"
,
umqMode
),
zap
.
Int64
(
"account_id"
,
account
.
ID
),
)
}
}
// 用 wrapReleaseOnDone 确保 context 取消时自动释放(仅 serialize 模式有 queueRelease)
queueRelease
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
queueRelease
)
// 注入回调到 ParsedRequest:使用外层 wrapper 以便提前清理 AfterFunc
parsedReq
.
OnUpstreamAccepted
=
queueRelease
// ===== 用户消息串行队列 END =====
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
requestCtx
:=
c
.
Request
.
Context
()
if
s
witchCount
>
0
{
requestCtx
=
context
.
WithValue
(
requestCtx
,
ctxkey
.
AccountSwitchCount
,
switchCount
)
if
fs
.
S
witchCount
>
0
{
requestCtx
=
service
.
WithAccountSwitchCount
(
requestCtx
,
fs
.
SwitchCount
,
h
.
metadataBridgeEnabled
()
)
}
if
account
.
Platform
==
service
.
PlatformAntigravity
&&
account
.
Type
!=
service
.
AccountTypeAPIKey
{
result
,
err
=
h
.
antigravityGatewayService
.
Forward
(
requestCtx
,
c
,
account
,
body
,
hasBoundSession
)
}
else
{
result
,
err
=
h
.
gatewayService
.
Forward
(
requestCtx
,
c
,
account
,
parsedReq
)
}
// 兜底释放串行锁(正常情况已通过回调提前释放)
if
queueRelease
!=
nil
{
queueRelease
()
}
// 清理回调引用,防止 failover 重试时旧回调被错误调用
parsedReq
.
OnUpstreamAccepted
=
nil
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
err
!=
nil
{
var
promptTooLongErr
*
service
.
PromptTooLongError
if
errors
.
As
(
err
,
&
promptTooLongErr
)
{
log
.
Printf
(
"Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v"
,
currentAPIKey
.
GroupID
,
fallbackGroupID
,
fallbackUsed
)
reqLog
.
Warn
(
"gateway.prompt_too_long_from_antigravity"
,
zap
.
Any
(
"current_group_id"
,
currentAPIKey
.
GroupID
),
zap
.
Any
(
"fallback_group_id"
,
fallbackGroupID
),
zap
.
Bool
(
"fallback_used"
,
fallbackUsed
),
)
if
!
fallbackUsed
&&
fallbackGroupID
!=
nil
&&
*
fallbackGroupID
>
0
{
fallbackGroup
,
err
:=
h
.
gatewayService
.
ResolveGroupByID
(
c
.
Request
.
Context
(),
*
fallbackGroupID
)
if
err
!=
nil
{
log
.
Printf
(
"R
esolve
fallback
group
failed
: %v"
,
err
)
reqLog
.
Warn
(
"gateway.r
esolve
_
fallback
_
group
_
failed
"
,
zap
.
Int64
(
"fallback_group_id"
,
*
fallbackGroupID
),
zap
.
Error
(
err
)
)
_
=
h
.
antigravityGatewayService
.
WriteMappedClaudeError
(
c
,
account
,
promptTooLongErr
.
StatusCode
,
promptTooLongErr
.
RequestID
,
promptTooLongErr
.
Body
)
return
}
if
fallbackGroup
.
Platform
!=
service
.
PlatformAnthropic
||
fallbackGroup
.
SubscriptionType
==
service
.
SubscriptionTypeSubscription
||
fallbackGroup
.
FallbackGroupIDOnInvalidRequest
!=
nil
{
log
.
Printf
(
"Fallback group invalid: group=%d platform=%s subscription=%s"
,
fallbackGroup
.
ID
,
fallbackGroup
.
Platform
,
fallbackGroup
.
SubscriptionType
)
reqLog
.
Warn
(
"gateway.fallback_group_invalid"
,
zap
.
Int64
(
"fallback_group_id"
,
fallbackGroup
.
ID
),
zap
.
String
(
"fallback_platform"
,
fallbackGroup
.
Platform
),
zap
.
String
(
"fallback_subscription_type"
,
fallbackGroup
.
SubscriptionType
),
)
_
=
h
.
antigravityGatewayService
.
WriteMappedClaudeError
(
c
,
account
,
promptTooLongErr
.
StatusCode
,
promptTooLongErr
.
RequestID
,
promptTooLongErr
.
Body
)
return
}
...
...
@@ -584,7 +683,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
handleStreamingAwareError
(
c
,
status
,
code
,
message
,
streamStarted
)
return
}
// 兜底重试按
“
直接请求兜底分组
”
处理:清除强制平台,允许按分组平台调度
// 兜底重试按
"
直接请求兜底分组
"
处理:清除强制平台,允许按分组平台调度
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
ForcePlatform
,
""
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
currentAPIKey
=
fallbackAPIKey
...
...
@@ -598,68 +697,62 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
lastFailoverErr
=
failoverErr
if
needForceCacheBilling
(
hasBoundSession
,
failoverErr
)
{
forceCacheBilling
=
true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if
failoverErr
.
RetryableOnSameAccount
&&
sameAccountRetryCount
[
account
.
ID
]
<
maxSameAccountRetries
{
sameAccountRetryCount
[
account
.
ID
]
++
log
.
Printf
(
"Account %d: retryable error %d, same-account retry %d/%d"
,
account
.
ID
,
failoverErr
.
StatusCode
,
sameAccountRetryCount
[
account
.
ID
],
maxSameAccountRetries
)
if
!
sleepSameAccountRetryDelay
(
c
.
Request
.
Context
())
{
return
}
action
:=
fs
.
HandleFailoverError
(
c
.
Request
.
Context
(),
h
.
gatewayService
,
account
.
ID
,
account
.
Platform
,
failoverErr
)
switch
action
{
case
FailoverContinue
:
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if
failoverErr
.
RetryableOnSameAccount
{
h
.
gatewayService
.
TempUnscheduleRetryableError
(
c
.
Request
.
Context
(),
account
.
ID
,
failoverErr
)
}
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
if
switchCount
>=
maxAccountSwitches
{
h
.
handleFailoverExhausted
(
c
,
failoverErr
,
account
.
Platform
,
streamStarted
)
case
FailoverExhausted
:
h
.
handleFailoverExhausted
(
c
,
fs
.
LastFailoverErr
,
account
.
Platform
,
streamStarted
)
return
case
FailoverCanceled
:
return
}
switchCount
++
log
.
Printf
(
"Account %d: upstream error %d, switching account %d/%d"
,
account
.
ID
,
failoverErr
.
StatusCode
,
switchCount
,
maxAccountSwitches
)
if
account
.
Platform
==
service
.
PlatformAntigravity
{
if
!
sleepFailoverDelay
(
c
.
Request
.
Context
(),
switchCount
)
{
return
}
}
continue
}
// 错误响应已在Forward中处理,这里只记录日志
log
.
Printf
(
"Account %d: Forward request failed: %v"
,
account
.
ID
,
err
)
wroteFallback
:=
h
.
ensureForwardErrorResponse
(
c
,
streamStarted
)
reqLog
.
Error
(
"gateway.forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Bool
(
"fallback_error_response_written"
,
wroteFallback
),
zap
.
Error
(
err
),
)
return
}
// RPM 计数递增(Forward 成功后)
// 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
if
account
.
IsAnthropicOAuthOrSetupToken
()
&&
account
.
GetBaseRPM
()
>
0
{
if
err
:=
h
.
gatewayService
.
IncrementAccountRPM
(
c
.
Request
.
Context
(),
account
.
ID
);
err
!=
nil
{
reqLog
.
Warn
(
"gateway.rpm_increment_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
))
}
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent
:=
c
.
GetHeader
(
"User-Agent"
)
clientIP
:=
ip
.
GetClientIP
(
c
)
// 异步记录使用量(subscription已在函数开头获取)
go
func
(
result
*
service
.
ForwardResult
,
usedAccount
*
service
.
Account
,
ua
,
clientIP
string
,
fcb
bool
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
Result
:
result
,
APIKey
:
currentAPIKey
,
User
:
currentAPIKey
.
User
,
Account
:
usedA
ccount
,
Account
:
a
ccount
,
Subscription
:
currentSubscription
,
UserAgent
:
u
a
,
UserAgent
:
u
serAgent
,
IPAddress
:
clientIP
,
ForceCacheBilling
:
f
cb
,
ForceCacheBilling
:
f
s
.
ForceCacheBilling
,
APIKeyService
:
h
.
apiKeyService
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.gateway.messages"
),
zap
.
Int64
(
"user_id"
,
subject
.
UserID
),
zap
.
Int64
(
"api_key_id"
,
currentAPIKey
.
ID
),
zap
.
Any
(
"group_id"
,
currentAPIKey
.
GroupID
),
zap
.
String
(
"model"
,
reqModel
),
zap
.
Int64
(
"account_id"
,
account
.
ID
),
)
.
Error
(
"gateway.record_usage_failed"
,
zap
.
Error
(
err
))
}
}
(
result
,
account
,
userAgent
,
clientIP
,
forceCacheBilling
)
})
return
}
if
!
retryWithFallback
{
...
...
@@ -682,6 +775,17 @@ func (h *GatewayHandler) Models(c *gin.Context) {
groupID
=
&
apiKey
.
Group
.
ID
platform
=
apiKey
.
Group
.
Platform
}
if
forcedPlatform
,
ok
:=
middleware2
.
GetForcePlatformFromContext
(
c
);
ok
&&
strings
.
TrimSpace
(
forcedPlatform
)
!=
""
{
platform
=
forcedPlatform
}
if
platform
==
service
.
PlatformSora
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"object"
:
"list"
,
"data"
:
service
.
DefaultSoraModels
(
h
.
cfg
),
})
return
}
// Get available models from account configurations (without platform filter)
availableModels
:=
h
.
gatewayService
.
GetAvailableModels
(
c
.
Request
.
Context
(),
groupID
,
""
)
...
...
@@ -741,6 +845,10 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service
// Usage handles getting account balance and usage statistics for CC Switch integration
// GET /v1/usage
//
// Two modes:
// - quota_limited: API Key has quota or rate limits configured. Returns key-level limits/usage.
// - unrestricted: No key-level limits. Returns subscription or wallet balance info.
func
(
h
*
GatewayHandler
)
Usage
(
c
*
gin
.
Context
)
{
apiKey
,
ok
:=
middleware2
.
GetAPIKeyFromContext
(
c
)
if
!
ok
{
...
...
@@ -754,54 +862,183 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
return
}
ctx
:=
c
.
Request
.
Context
()
// 解析可选的日期范围参数(用于 model_stats 查询)
startTime
,
endTime
:=
h
.
parseUsageDateRange
(
c
)
// Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应
var
usageData
gin
.
H
usageData
:=
h
.
buildUsageData
(
ctx
,
apiKey
.
ID
)
// Best-effort: 获取模型统计
var
modelStats
any
if
h
.
usageService
!=
nil
{
dashStats
,
err
:=
h
.
usageService
.
GetAPIKeyDashboardStats
(
c
.
Request
.
Context
(),
apiKey
.
ID
)
if
err
==
nil
&&
dashStats
!=
nil
{
usageData
=
gin
.
H
{
"today"
:
gin
.
H
{
"requests"
:
dashStats
.
TodayRequests
,
"input_tokens"
:
dashStats
.
TodayInputTokens
,
"output_tokens"
:
dashStats
.
TodayOutputTokens
,
"cache_creation_tokens"
:
dashStats
.
TodayCacheCreationTokens
,
"cache_read_tokens"
:
dashStats
.
TodayCacheReadTokens
,
"total_tokens"
:
dashStats
.
TodayTokens
,
"cost"
:
dashStats
.
TodayCost
,
"actual_cost"
:
dashStats
.
TodayActualCost
,
},
"total"
:
gin
.
H
{
"requests"
:
dashStats
.
TotalRequests
,
"input_tokens"
:
dashStats
.
TotalInputTokens
,
"output_tokens"
:
dashStats
.
TotalOutputTokens
,
"cache_creation_tokens"
:
dashStats
.
TotalCacheCreationTokens
,
"cache_read_tokens"
:
dashStats
.
TotalCacheReadTokens
,
"total_tokens"
:
dashStats
.
TotalTokens
,
"cost"
:
dashStats
.
TotalCost
,
"actual_cost"
:
dashStats
.
TotalActualCost
,
},
"average_duration_ms"
:
dashStats
.
AverageDurationMs
,
"rpm"
:
dashStats
.
Rpm
,
"tpm"
:
dashStats
.
Tpm
,
}
if
stats
,
err
:=
h
.
usageService
.
GetAPIKeyModelStats
(
ctx
,
apiKey
.
ID
,
startTime
,
endTime
);
err
==
nil
&&
len
(
stats
)
>
0
{
modelStats
=
stats
}
}
// 订阅模式:返回订阅限额信息 + 用量统计
if
apiKey
.
Group
!=
nil
&&
apiKey
.
Group
.
IsSubscriptionType
()
{
subscription
,
ok
:=
middleware2
.
GetSubscriptionFromContext
(
c
)
if
!
ok
{
h
.
errorResponse
(
c
,
http
.
StatusForbidden
,
"subscription_error"
,
"No active subscription"
)
return
// 判断模式: key 有总额度或速率限制 → quota_limited,否则 → unrestricted
isQuotaLimited
:=
apiKey
.
Quota
>
0
||
apiKey
.
HasRateLimits
()
if
isQuotaLimited
{
h
.
usageQuotaLimited
(
c
,
ctx
,
apiKey
,
usageData
,
modelStats
)
return
}
h
.
usageUnrestricted
(
c
,
ctx
,
apiKey
,
subject
,
usageData
,
modelStats
)
}
// parseUsageDateRange 解析 start_date / end_date query params,默认返回近 30 天范围
func
(
h
*
GatewayHandler
)
parseUsageDateRange
(
c
*
gin
.
Context
)
(
time
.
Time
,
time
.
Time
)
{
now
:=
timezone
.
Now
()
endTime
:=
now
startTime
:=
now
.
AddDate
(
0
,
0
,
-
30
)
if
s
:=
c
.
Query
(
"start_date"
);
s
!=
""
{
if
t
,
err
:=
timezone
.
ParseInLocation
(
"2006-01-02"
,
s
);
err
==
nil
{
startTime
=
t
}
}
if
s
:=
c
.
Query
(
"end_date"
);
s
!=
""
{
if
t
,
err
:=
timezone
.
ParseInLocation
(
"2006-01-02"
,
s
);
err
==
nil
{
endTime
=
t
.
Add
(
24
*
time
.
Hour
-
time
.
Second
)
// end of day
}
}
return
startTime
,
endTime
}
remaining
:=
h
.
calculateSubscriptionRemaining
(
apiKey
.
Group
,
subscription
)
resp
:=
gin
.
H
{
"isValid"
:
true
,
"planName"
:
apiKey
.
Group
.
Name
,
// buildUsageData 构建 today/total 用量摘要
func
(
h
*
GatewayHandler
)
buildUsageData
(
ctx
context
.
Context
,
apiKeyID
int64
)
gin
.
H
{
if
h
.
usageService
==
nil
{
return
nil
}
dashStats
,
err
:=
h
.
usageService
.
GetAPIKeyDashboardStats
(
ctx
,
apiKeyID
)
if
err
!=
nil
||
dashStats
==
nil
{
return
nil
}
return
gin
.
H
{
"today"
:
gin
.
H
{
"requests"
:
dashStats
.
TodayRequests
,
"input_tokens"
:
dashStats
.
TodayInputTokens
,
"output_tokens"
:
dashStats
.
TodayOutputTokens
,
"cache_creation_tokens"
:
dashStats
.
TodayCacheCreationTokens
,
"cache_read_tokens"
:
dashStats
.
TodayCacheReadTokens
,
"total_tokens"
:
dashStats
.
TodayTokens
,
"cost"
:
dashStats
.
TodayCost
,
"actual_cost"
:
dashStats
.
TodayActualCost
,
},
"total"
:
gin
.
H
{
"requests"
:
dashStats
.
TotalRequests
,
"input_tokens"
:
dashStats
.
TotalInputTokens
,
"output_tokens"
:
dashStats
.
TotalOutputTokens
,
"cache_creation_tokens"
:
dashStats
.
TotalCacheCreationTokens
,
"cache_read_tokens"
:
dashStats
.
TotalCacheReadTokens
,
"total_tokens"
:
dashStats
.
TotalTokens
,
"cost"
:
dashStats
.
TotalCost
,
"actual_cost"
:
dashStats
.
TotalActualCost
,
},
"average_duration_ms"
:
dashStats
.
AverageDurationMs
,
"rpm"
:
dashStats
.
Rpm
,
"tpm"
:
dashStats
.
Tpm
,
}
}
// usageQuotaLimited 处理 quota_limited 模式的响应
func
(
h
*
GatewayHandler
)
usageQuotaLimited
(
c
*
gin
.
Context
,
ctx
context
.
Context
,
apiKey
*
service
.
APIKey
,
usageData
gin
.
H
,
modelStats
any
)
{
resp
:=
gin
.
H
{
"mode"
:
"quota_limited"
,
"isValid"
:
apiKey
.
Status
==
service
.
StatusAPIKeyActive
||
apiKey
.
Status
==
service
.
StatusAPIKeyQuotaExhausted
||
apiKey
.
Status
==
service
.
StatusAPIKeyExpired
,
"status"
:
apiKey
.
Status
,
}
// 总额度信息
if
apiKey
.
Quota
>
0
{
remaining
:=
apiKey
.
GetQuotaRemaining
()
resp
[
"quota"
]
=
gin
.
H
{
"limit"
:
apiKey
.
Quota
,
"used"
:
apiKey
.
QuotaUsed
,
"remaining"
:
remaining
,
"unit"
:
"USD"
,
"subscription"
:
gin
.
H
{
}
resp
[
"remaining"
]
=
remaining
resp
[
"unit"
]
=
"USD"
}
// 速率限制信息(从 DB 获取实时用量)
if
apiKey
.
HasRateLimits
()
&&
h
.
apiKeyService
!=
nil
{
rateLimitData
,
err
:=
h
.
apiKeyService
.
GetRateLimitData
(
ctx
,
apiKey
.
ID
)
if
err
==
nil
&&
rateLimitData
!=
nil
{
var
rateLimits
[]
gin
.
H
if
apiKey
.
RateLimit5h
>
0
{
used
:=
rateLimitData
.
Usage5h
rateLimits
=
append
(
rateLimits
,
gin
.
H
{
"window"
:
"5h"
,
"limit"
:
apiKey
.
RateLimit5h
,
"used"
:
used
,
"remaining"
:
max
(
0
,
apiKey
.
RateLimit5h
-
used
),
"window_start"
:
rateLimitData
.
Window5hStart
,
})
}
if
apiKey
.
RateLimit1d
>
0
{
used
:=
rateLimitData
.
Usage1d
rateLimits
=
append
(
rateLimits
,
gin
.
H
{
"window"
:
"1d"
,
"limit"
:
apiKey
.
RateLimit1d
,
"used"
:
used
,
"remaining"
:
max
(
0
,
apiKey
.
RateLimit1d
-
used
),
"window_start"
:
rateLimitData
.
Window1dStart
,
})
}
if
apiKey
.
RateLimit7d
>
0
{
used
:=
rateLimitData
.
Usage7d
rateLimits
=
append
(
rateLimits
,
gin
.
H
{
"window"
:
"7d"
,
"limit"
:
apiKey
.
RateLimit7d
,
"used"
:
used
,
"remaining"
:
max
(
0
,
apiKey
.
RateLimit7d
-
used
),
"window_start"
:
rateLimitData
.
Window7dStart
,
})
}
if
len
(
rateLimits
)
>
0
{
resp
[
"rate_limits"
]
=
rateLimits
}
}
}
// 过期时间
if
apiKey
.
ExpiresAt
!=
nil
{
resp
[
"expires_at"
]
=
apiKey
.
ExpiresAt
resp
[
"days_until_expiry"
]
=
apiKey
.
GetDaysUntilExpiry
()
}
if
usageData
!=
nil
{
resp
[
"usage"
]
=
usageData
}
if
modelStats
!=
nil
{
resp
[
"model_stats"
]
=
modelStats
}
c
.
JSON
(
http
.
StatusOK
,
resp
)
}
// usageUnrestricted 处理 unrestricted 模式的响应(向后兼容)
func
(
h
*
GatewayHandler
)
usageUnrestricted
(
c
*
gin
.
Context
,
ctx
context
.
Context
,
apiKey
*
service
.
APIKey
,
subject
middleware2
.
AuthSubject
,
usageData
gin
.
H
,
modelStats
any
)
{
// 订阅模式
if
apiKey
.
Group
!=
nil
&&
apiKey
.
Group
.
IsSubscriptionType
()
{
resp
:=
gin
.
H
{
"mode"
:
"unrestricted"
,
"isValid"
:
true
,
"planName"
:
apiKey
.
Group
.
Name
,
"unit"
:
"USD"
,
}
// 订阅信息可能不在 context 中(/v1/usage 路径跳过了中间件的计费检查)
subscription
,
ok
:=
middleware2
.
GetSubscriptionFromContext
(
c
)
if
ok
{
remaining
:=
h
.
calculateSubscriptionRemaining
(
apiKey
.
Group
,
subscription
)
resp
[
"remaining"
]
=
remaining
resp
[
"subscription"
]
=
gin
.
H
{
"daily_usage_usd"
:
subscription
.
DailyUsageUSD
,
"weekly_usage_usd"
:
subscription
.
WeeklyUsageUSD
,
"monthly_usage_usd"
:
subscription
.
MonthlyUsageUSD
,
...
...
@@ -809,23 +1046,28 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
"weekly_limit_usd"
:
apiKey
.
Group
.
WeeklyLimitUSD
,
"monthly_limit_usd"
:
apiKey
.
Group
.
MonthlyLimitUSD
,
"expires_at"
:
subscription
.
ExpiresAt
,
}
,
}
}
if
usageData
!=
nil
{
resp
[
"usage"
]
=
usageData
}
if
modelStats
!=
nil
{
resp
[
"model_stats"
]
=
modelStats
}
c
.
JSON
(
http
.
StatusOK
,
resp
)
return
}
// 余额模式
:返回钱包余额 + 用量统计
latestUser
,
err
:=
h
.
userService
.
GetByID
(
c
.
Request
.
Context
()
,
subject
.
UserID
)
// 余额模式
latestUser
,
err
:=
h
.
userService
.
GetByID
(
c
tx
,
subject
.
UserID
)
if
err
!=
nil
{
h
.
errorResponse
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"Failed to get user info"
)
return
}
resp
:=
gin
.
H
{
"mode"
:
"unrestricted"
,
"isValid"
:
true
,
"planName"
:
"钱包余额"
,
"remaining"
:
latestUser
.
Balance
,
...
...
@@ -835,6 +1077,9 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
if
usageData
!=
nil
{
resp
[
"usage"
]
=
usageData
}
if
modelStats
!=
nil
{
resp
[
"model_stats"
]
=
modelStats
}
c
.
JSON
(
http
.
StatusOK
,
resp
)
}
...
...
@@ -893,65 +1138,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt
.
Sprintf
(
"Concurrency limit exceeded for %s, please retry later"
,
slotType
),
streamStarted
)
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
func
needForceCacheBilling
(
hasBoundSession
bool
,
failoverErr
*
service
.
UpstreamFailoverError
)
bool
{
return
hasBoundSession
||
(
failoverErr
!=
nil
&&
failoverErr
.
ForceCacheBilling
)
}
const
(
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries
=
2
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay
=
500
*
time
.
Millisecond
)
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
func
sleepSameAccountRetryDelay
(
ctx
context
.
Context
)
bool
{
select
{
case
<-
ctx
.
Done
()
:
return
false
case
<-
time
.
After
(
sameAccountRetryDelay
)
:
return
true
}
}
// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s…
// 返回 false 表示 context 已取消。
func
sleepFailoverDelay
(
ctx
context
.
Context
,
switchCount
int
)
bool
{
delay
:=
time
.
Duration
(
switchCount
-
1
)
*
time
.
Second
if
delay
<=
0
{
return
true
}
select
{
case
<-
ctx
.
Done
()
:
return
false
case
<-
time
.
After
(
delay
)
:
return
true
}
}
// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。
// 当分组内只有一个可用账号且上游返回 503(MODEL_CAPACITY_EXHAUSTED)时使用,
// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试
// (最多 3 次、总等待 30s),所以 Handler 层的退避只需短暂等待即可。
// 返回 false 表示 context 已取消。
func
sleepAntigravitySingleAccountBackoff
(
ctx
context
.
Context
,
retryCount
int
)
bool
{
// 固定短延时:2s
// Service 层已经在原地等待了足够长的时间(retryDelay × 重试次数),
// Handler 层只需短暂间隔后重新进入 Service 层即可。
const
delay
=
2
*
time
.
Second
log
.
Printf
(
"Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)"
,
delay
,
retryCount
)
select
{
case
<-
ctx
.
Done
()
:
return
false
case
<-
time
.
After
(
delay
)
:
return
true
}
}
func
(
h
*
GatewayHandler
)
handleFailoverExhausted
(
c
*
gin
.
Context
,
failoverErr
*
service
.
UpstreamFailoverError
,
platform
string
,
streamStarted
bool
)
{
statusCode
:=
failoverErr
.
StatusCode
responseBody
:=
failoverErr
.
ResponseBody
...
...
@@ -1014,20 +1200,8 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
// Stream already started, send error as SSE event then close
flusher
,
ok
:=
c
.
Writer
.
(
http
.
Flusher
)
if
ok
{
// Send error event in SSE format with proper JSON marshaling
errorData
:=
map
[
string
]
any
{
"type"
:
"error"
,
"error"
:
map
[
string
]
string
{
"type"
:
errType
,
"message"
:
message
,
},
}
jsonBytes
,
err
:=
json
.
Marshal
(
errorData
)
if
err
!=
nil
{
_
=
c
.
Error
(
err
)
return
}
errorEvent
:=
fmt
.
Sprintf
(
"data: %s
\n\n
"
,
string
(
jsonBytes
))
// SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。
errorEvent
:=
`data: {"type":"error","error":{"type":`
+
strconv
.
Quote
(
errType
)
+
`,"message":`
+
strconv
.
Quote
(
message
)
+
`}}`
+
"
\n\n
"
if
_
,
err
:=
fmt
.
Fprint
(
c
.
Writer
,
errorEvent
);
err
!=
nil
{
_
=
c
.
Error
(
err
)
}
...
...
@@ -1040,6 +1214,50 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
h
.
errorResponse
(
c
,
status
,
errType
,
message
)
}
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
func
(
h
*
GatewayHandler
)
ensureForwardErrorResponse
(
c
*
gin
.
Context
,
streamStarted
bool
)
bool
{
if
c
==
nil
||
c
.
Writer
==
nil
||
c
.
Writer
.
Written
()
{
return
false
}
h
.
handleStreamingAwareError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed"
,
streamStarted
)
return
true
}
// checkClaudeCodeVersion 检查 Claude Code 客户端版本是否满足最低要求
// 仅对已识别的 Claude Code 客户端执行,count_tokens 路径除外
func
(
h
*
GatewayHandler
)
checkClaudeCodeVersion
(
c
*
gin
.
Context
)
bool
{
ctx
:=
c
.
Request
.
Context
()
if
!
service
.
IsClaudeCodeClient
(
ctx
)
{
return
true
}
// 排除 count_tokens 子路径
if
strings
.
HasSuffix
(
c
.
Request
.
URL
.
Path
,
"/count_tokens"
)
{
return
true
}
minVersion
:=
h
.
settingService
.
GetMinClaudeCodeVersion
(
ctx
)
if
minVersion
==
""
{
return
true
// 未设置,不检查
}
clientVersion
:=
service
.
GetClaudeCodeVersion
(
ctx
)
if
clientVersion
==
""
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Unable to determine Claude Code version. Please update Claude Code: npm update -g @anthropic-ai/claude-code"
)
return
false
}
if
service
.
CompareVersions
(
clientVersion
,
minVersion
)
<
0
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
fmt
.
Sprintf
(
"Your Claude Code version (%s) is below the minimum required version (%s). Please update: npm update -g @anthropic-ai/claude-code"
,
clientVersion
,
minVersion
))
return
false
}
return
true
}
// errorResponse 返回Claude API格式的错误响应
func
(
h
*
GatewayHandler
)
errorResponse
(
c
*
gin
.
Context
,
status
int
,
errType
,
message
string
)
{
c
.
JSON
(
status
,
gin
.
H
{
...
...
@@ -1067,9 +1285,16 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h
.
errorResponse
(
c
,
http
.
StatusInternalServerError
,
"api_error"
,
"User context not found"
)
return
}
reqLog
:=
requestLogger
(
c
,
"handler.gateway.count_tokens"
,
zap
.
Int64
(
"api_key_id"
,
apiKey
.
ID
),
zap
.
Any
(
"group_id"
,
apiKey
.
GroupID
),
)
defer
h
.
maybeLogCompatibilityFallbackMetrics
(
reqLog
)
// 读取请求体
body
,
err
:=
io
.
ReadAll
(
c
.
Request
.
Body
)
body
,
err
:=
pkghttputil
.
ReadRequestBodyWithPrealloc
(
c
.
Request
)
if
err
!=
nil
{
if
maxErr
,
ok
:=
extractMaxBytesError
(
err
);
ok
{
h
.
errorResponse
(
c
,
http
.
StatusRequestEntityTooLarge
,
"invalid_request_error"
,
buildBodyTooLargeMessage
(
maxErr
.
Limit
))
...
...
@@ -1084,9 +1309,6 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext
(
c
,
body
)
setOpsRequestContext
(
c
,
""
,
false
,
body
)
parsedReq
,
err
:=
service
.
ParseGatewayRequest
(
body
,
domain
.
PlatformAnthropic
)
...
...
@@ -1094,8 +1316,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to parse request body"
)
return
}
// count_tokens 走 messages 严格校验时,复用已解析请求,避免二次反序列化。
SetClaudeCodeClientContext
(
c
,
body
,
parsedReq
)
reqLog
=
reqLog
.
With
(
zap
.
String
(
"model"
,
parsedReq
.
Model
),
zap
.
Bool
(
"stream"
,
parsedReq
.
Stream
))
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c
.
Request
=
c
.
Request
.
WithContext
(
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
ThinkingEnabled
,
parsedReq
.
Thinking
Enabled
))
c
.
Request
=
c
.
Request
.
WithContext
(
service
.
WithThinkingEnabled
(
c
.
Request
.
Context
(),
parsedReq
.
ThinkingEnabled
,
h
.
metadataBridge
Enabled
()
))
// 验证 model 必填
if
parsedReq
.
Model
==
""
{
...
...
@@ -1127,14 +1352,15 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 选择支持该模型的账号
account
,
err
:=
h
.
gatewayService
.
SelectAccountForModel
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
parsedReq
.
Model
)
if
err
!=
nil
{
h
.
errorResponse
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
())
reqLog
.
Warn
(
"gateway.count_tokens_select_account_failed"
,
zap
.
Error
(
err
))
h
.
errorResponse
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"Service temporarily unavailable"
)
return
}
setOpsSelectedAccount
(
c
,
account
.
ID
)
setOpsSelectedAccount
(
c
,
account
.
ID
,
account
.
Platform
)
// 转发请求(不记录使用量)
if
err
:=
h
.
gatewayService
.
ForwardCountTokens
(
c
.
Request
.
Context
(),
c
,
account
,
parsedReq
);
err
!=
nil
{
log
.
Printf
(
"Forward count_tokens request failed: %v"
,
err
)
reqLog
.
Error
(
"gateway.count_tokens_forward_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
zap
.
Error
(
err
)
)
// 错误响应已在 ForwardCountTokens 中处理
return
}
...
...
@@ -1258,24 +1484,8 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
textDeltas
=
[]
string
{
"New"
,
" Conversation"
}
}
// Build message_start event with proper JSON marshaling
messageStart
:=
map
[
string
]
any
{
"type"
:
"message_start"
,
"message"
:
map
[
string
]
any
{
"id"
:
msgID
,
"type"
:
"message"
,
"role"
:
"assistant"
,
"model"
:
model
,
"content"
:
[]
any
{},
"stop_reason"
:
nil
,
"stop_sequence"
:
nil
,
"usage"
:
map
[
string
]
int
{
"input_tokens"
:
10
,
"output_tokens"
:
0
,
},
},
}
messageStartJSON
,
_
:=
json
.
Marshal
(
messageStart
)
// Build message_start event with fixed schema.
messageStartJSON
:=
`{"type":"message_start","message":{"id":`
+
strconv
.
Quote
(
msgID
)
+
`,"type":"message","role":"assistant","model":`
+
strconv
.
Quote
(
model
)
+
`,"content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0}}}`
// Build events
events
:=
[]
string
{
...
...
@@ -1285,31 +1495,12 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
// Add text deltas
for
_
,
text
:=
range
textDeltas
{
delta
:=
map
[
string
]
any
{
"type"
:
"content_block_delta"
,
"index"
:
0
,
"delta"
:
map
[
string
]
string
{
"type"
:
"text_delta"
,
"text"
:
text
,
},
}
deltaJSON
,
_
:=
json
.
Marshal
(
delta
)
deltaJSON
:=
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":`
+
strconv
.
Quote
(
text
)
+
`}}`
events
=
append
(
events
,
`event: content_block_delta`
+
"
\n
"
+
`data: `
+
string
(
deltaJSON
))
}
// Add final events
messageDelta
:=
map
[
string
]
any
{
"type"
:
"message_delta"
,
"delta"
:
map
[
string
]
any
{
"stop_reason"
:
"end_turn"
,
"stop_sequence"
:
nil
,
},
"usage"
:
map
[
string
]
int
{
"input_tokens"
:
10
,
"output_tokens"
:
outputTokens
,
},
}
messageDeltaJSON
,
_
:=
json
.
Marshal
(
messageDelta
)
messageDeltaJSON
:=
`{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":10,"output_tokens":`
+
strconv
.
Itoa
(
outputTokens
)
+
`}}`
events
=
append
(
events
,
`event: content_block_stop`
+
"
\n
"
+
`data: {"index":0,"type":"content_block_stop"}`
,
...
...
@@ -1396,9 +1587,92 @@ func billingErrorDetails(err error) (status int, code, message string) {
}
return
http
.
StatusServiceUnavailable
,
"billing_service_error"
,
msg
}
if
errors
.
Is
(
err
,
service
.
ErrAPIKeyRateLimit5hExceeded
)
{
msg
:=
pkgerrors
.
Message
(
err
)
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
}
if
errors
.
Is
(
err
,
service
.
ErrAPIKeyRateLimit1dExceeded
)
{
msg
:=
pkgerrors
.
Message
(
err
)
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
}
if
errors
.
Is
(
err
,
service
.
ErrAPIKeyRateLimit7dExceeded
)
{
msg
:=
pkgerrors
.
Message
(
err
)
return
http
.
StatusTooManyRequests
,
"rate_limit_exceeded"
,
msg
}
msg
:=
pkgerrors
.
Message
(
err
)
if
msg
==
""
{
msg
=
err
.
Error
()
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.gateway.billing"
),
zap
.
Error
(
err
),
)
.
Warn
(
"gateway.billing_error_missing_message"
)
msg
=
"Billing error"
}
return
http
.
StatusForbidden
,
"billing_error"
,
msg
}
func
(
h
*
GatewayHandler
)
metadataBridgeEnabled
()
bool
{
if
h
==
nil
||
h
.
cfg
==
nil
{
return
true
}
return
h
.
cfg
.
Gateway
.
OpenAIWS
.
MetadataBridgeEnabled
}
func
(
h
*
GatewayHandler
)
maybeLogCompatibilityFallbackMetrics
(
reqLog
*
zap
.
Logger
)
{
if
reqLog
==
nil
{
return
}
if
gatewayCompatibilityMetricsLogCounter
.
Add
(
1
)
%
gatewayCompatibilityMetricsLogInterval
!=
0
{
return
}
metrics
:=
service
.
SnapshotOpenAICompatibilityFallbackMetrics
()
reqLog
.
Info
(
"gateway.compatibility_fallback_metrics"
,
zap
.
Int64
(
"session_hash_legacy_read_fallback_total"
,
metrics
.
SessionHashLegacyReadFallbackTotal
),
zap
.
Int64
(
"session_hash_legacy_read_fallback_hit"
,
metrics
.
SessionHashLegacyReadFallbackHit
),
zap
.
Int64
(
"session_hash_legacy_dual_write_total"
,
metrics
.
SessionHashLegacyDualWriteTotal
),
zap
.
Float64
(
"session_hash_legacy_read_hit_rate"
,
metrics
.
SessionHashLegacyReadHitRate
),
zap
.
Int64
(
"metadata_legacy_fallback_total"
,
metrics
.
MetadataLegacyFallbackTotal
),
)
}
func
(
h
*
GatewayHandler
)
submitUsageRecordTask
(
task
service
.
UsageRecordTask
)
{
if
task
==
nil
{
return
}
if
h
.
usageRecordWorkerPool
!=
nil
{
h
.
usageRecordWorkerPool
.
Submit
(
task
)
return
}
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
defer
func
()
{
if
recovered
:=
recover
();
recovered
!=
nil
{
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.gateway.messages"
),
zap
.
Any
(
"panic"
,
recovered
),
)
.
Error
(
"gateway.usage_record_task_panic_recovered"
)
}
}()
task
(
ctx
)
}
// getUserMsgQueueMode 获取当前请求的 UMQ 模式
// 返回 "serialize" | "throttle" | ""
func
(
h
*
GatewayHandler
)
getUserMsgQueueMode
(
account
*
service
.
Account
,
parsed
*
service
.
ParsedRequest
)
string
{
if
h
.
userMsgQueueHelper
==
nil
{
return
""
}
// 仅适用于 Anthropic OAuth/SetupToken 账号
if
!
account
.
IsAnthropicOAuthOrSetupToken
()
{
return
""
}
if
!
service
.
IsRealUserMessage
(
parsed
)
{
return
""
}
// 账号级模式优先,fallback 到全局配置
mode
:=
account
.
GetUserMsgQueueMode
()
if
mode
==
""
{
mode
=
h
.
cfg
.
Gateway
.
UserMessageQueue
.
GetEffectiveMode
()
}
return
mode
}
backend/internal/handler/gateway_handler_error_fallback_test.go
0 → 100644
View file @
3d79773b
package
handler
import
(
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func
TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
h
:=
&
GatewayHandler
{}
wrote
:=
h
.
ensureForwardErrorResponse
(
c
,
false
)
require
.
True
(
t
,
wrote
)
require
.
Equal
(
t
,
http
.
StatusBadGateway
,
w
.
Code
)
var
parsed
map
[
string
]
any
err
:=
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
parsed
)
require
.
NoError
(
t
,
err
)
assert
.
Equal
(
t
,
"error"
,
parsed
[
"type"
])
errorObj
,
ok
:=
parsed
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
assert
.
Equal
(
t
,
"upstream_error"
,
errorObj
[
"type"
])
assert
.
Equal
(
t
,
"Upstream request failed"
,
errorObj
[
"message"
])
}
func
TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
c
.
String
(
http
.
StatusTeapot
,
"already written"
)
h
:=
&
GatewayHandler
{}
wrote
:=
h
.
ensureForwardErrorResponse
(
c
,
false
)
require
.
False
(
t
,
wrote
)
require
.
Equal
(
t
,
http
.
StatusTeapot
,
w
.
Code
)
assert
.
Equal
(
t
,
"already written"
,
w
.
Body
.
String
())
}
backend/internal/handler/gateway_handler_single_account_retry_test.go
deleted
100644 → 0
View file @
6aa8cbbf
package
handler
import
(
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// sleepAntigravitySingleAccountBackoff 测试
// ---------------------------------------------------------------------------
func
TestSleepAntigravitySingleAccountBackoff_ReturnsTrue
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
start
:=
time
.
Now
()
ok
:=
sleepAntigravitySingleAccountBackoff
(
ctx
,
1
)
elapsed
:=
time
.
Since
(
start
)
require
.
True
(
t
,
ok
,
"should return true when context is not canceled"
)
// 固定延迟 2s
require
.
GreaterOrEqual
(
t
,
elapsed
,
1500
*
time
.
Millisecond
,
"should wait approximately 2s"
)
require
.
Less
(
t
,
elapsed
,
5
*
time
.
Second
,
"should not wait too long"
)
}
func
TestSleepAntigravitySingleAccountBackoff_ContextCanceled
(
t
*
testing
.
T
)
{
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
// 立即取消
start
:=
time
.
Now
()
ok
:=
sleepAntigravitySingleAccountBackoff
(
ctx
,
1
)
elapsed
:=
time
.
Since
(
start
)
require
.
False
(
t
,
ok
,
"should return false when context is canceled"
)
require
.
Less
(
t
,
elapsed
,
500
*
time
.
Millisecond
,
"should return immediately on cancel"
)
}
func
TestSleepAntigravitySingleAccountBackoff_FixedDelay
(
t
*
testing
.
T
)
{
// 验证不同 retryCount 都使用固定 2s 延迟
ctx
:=
context
.
Background
()
start
:=
time
.
Now
()
ok
:=
sleepAntigravitySingleAccountBackoff
(
ctx
,
5
)
elapsed
:=
time
.
Since
(
start
)
require
.
True
(
t
,
ok
)
// 即使 retryCount=5,延迟仍然是固定的 2s
require
.
GreaterOrEqual
(
t
,
elapsed
,
1500
*
time
.
Millisecond
)
require
.
Less
(
t
,
elapsed
,
5
*
time
.
Second
)
}
backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
0 → 100644
View file @
3d79773b
//go:build unit
package
handler
import
(
"bytes"
"context"
"encoding/json"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// 目标:严格验证“antigravity 账号通过 /v1/messages 提供 Claude 服务时”,
// 当账号 credentials.intercept_warmup_requests=true 且请求为 Warmup 时,
// 后端会在转发上游前直接拦截并返回 mock 响应(不依赖上游)。
type
fakeSchedulerCache
struct
{
accounts
[]
*
service
.
Account
}
func
(
f
*
fakeSchedulerCache
)
GetSnapshot
(
_
context
.
Context
,
_
service
.
SchedulerBucket
)
([]
*
service
.
Account
,
bool
,
error
)
{
return
f
.
accounts
,
true
,
nil
}
func
(
f
*
fakeSchedulerCache
)
SetSnapshot
(
_
context
.
Context
,
_
service
.
SchedulerBucket
,
_
[]
service
.
Account
)
error
{
return
nil
}
func
(
f
*
fakeSchedulerCache
)
GetAccount
(
_
context
.
Context
,
_
int64
)
(
*
service
.
Account
,
error
)
{
return
nil
,
nil
}
func
(
f
*
fakeSchedulerCache
)
SetAccount
(
_
context
.
Context
,
_
*
service
.
Account
)
error
{
return
nil
}
func
(
f
*
fakeSchedulerCache
)
DeleteAccount
(
_
context
.
Context
,
_
int64
)
error
{
return
nil
}
func
(
f
*
fakeSchedulerCache
)
UpdateLastUsed
(
_
context
.
Context
,
_
map
[
int64
]
time
.
Time
)
error
{
return
nil
}
func
(
f
*
fakeSchedulerCache
)
TryLockBucket
(
_
context
.
Context
,
_
service
.
SchedulerBucket
,
_
time
.
Duration
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
f
*
fakeSchedulerCache
)
ListBuckets
(
_
context
.
Context
)
([]
service
.
SchedulerBucket
,
error
)
{
return
nil
,
nil
}
func
(
f
*
fakeSchedulerCache
)
GetOutboxWatermark
(
_
context
.
Context
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
f
*
fakeSchedulerCache
)
SetOutboxWatermark
(
_
context
.
Context
,
_
int64
)
error
{
return
nil
}
type
fakeGroupRepo
struct
{
group
*
service
.
Group
}
func
(
f
*
fakeGroupRepo
)
Create
(
context
.
Context
,
*
service
.
Group
)
error
{
return
nil
}
func
(
f
*
fakeGroupRepo
)
GetByID
(
context
.
Context
,
int64
)
(
*
service
.
Group
,
error
)
{
return
f
.
group
,
nil
}
func
(
f
*
fakeGroupRepo
)
GetByIDLite
(
context
.
Context
,
int64
)
(
*
service
.
Group
,
error
)
{
return
f
.
group
,
nil
}
func
(
f
*
fakeGroupRepo
)
Update
(
context
.
Context
,
*
service
.
Group
)
error
{
return
nil
}
func
(
f
*
fakeGroupRepo
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
f
*
fakeGroupRepo
)
DeleteCascade
(
context
.
Context
,
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
f
*
fakeGroupRepo
)
List
(
context
.
Context
,
pagination
.
PaginationParams
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
f
*
fakeGroupRepo
)
ListWithFilters
(
context
.
Context
,
pagination
.
PaginationParams
,
string
,
string
,
string
,
*
bool
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
f
*
fakeGroupRepo
)
ListActive
(
context
.
Context
)
([]
service
.
Group
,
error
)
{
return
nil
,
nil
}
func
(
f
*
fakeGroupRepo
)
ListActiveByPlatform
(
context
.
Context
,
string
)
([]
service
.
Group
,
error
)
{
return
nil
,
nil
}
func
(
f
*
fakeGroupRepo
)
ExistsByName
(
context
.
Context
,
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
f
*
fakeGroupRepo
)
GetAccountCount
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
f
*
fakeGroupRepo
)
DeleteAccountGroupsByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
f
*
fakeGroupRepo
)
GetAccountIDsByGroupIDs
(
context
.
Context
,
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
f
*
fakeGroupRepo
)
BindAccountsToGroup
(
context
.
Context
,
int64
,
[]
int64
)
error
{
return
nil
}
func
(
f
*
fakeGroupRepo
)
UpdateSortOrders
(
context
.
Context
,
[]
service
.
GroupSortOrderUpdate
)
error
{
return
nil
}
type
fakeConcurrencyCache
struct
{}
func
(
f
*
fakeConcurrencyCache
)
AcquireAccountSlot
(
context
.
Context
,
int64
,
int
,
string
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
ReleaseAccountSlot
(
context
.
Context
,
int64
,
string
)
error
{
return
nil
}
func
(
f
*
fakeConcurrencyCache
)
GetAccountConcurrency
(
context
.
Context
,
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
IncrementAccountWaitCount
(
context
.
Context
,
int64
,
int
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
DecrementAccountWaitCount
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
f
*
fakeConcurrencyCache
)
GetAccountWaitingCount
(
context
.
Context
,
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
AcquireUserSlot
(
context
.
Context
,
int64
,
int
,
string
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
ReleaseUserSlot
(
context
.
Context
,
int64
,
string
)
error
{
return
nil
}
func
(
f
*
fakeConcurrencyCache
)
GetUserConcurrency
(
context
.
Context
,
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
IncrementWaitCount
(
context
.
Context
,
int64
,
int
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
DecrementWaitCount
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
f
*
fakeConcurrencyCache
)
GetAccountsLoadBatch
(
context
.
Context
,
[]
service
.
AccountWithConcurrency
)
(
map
[
int64
]
*
service
.
AccountLoadInfo
,
error
)
{
return
map
[
int64
]
*
service
.
AccountLoadInfo
{},
nil
}
func
(
f
*
fakeConcurrencyCache
)
GetUsersLoadBatch
(
context
.
Context
,
[]
service
.
UserWithConcurrency
)
(
map
[
int64
]
*
service
.
UserLoadInfo
,
error
)
{
return
map
[
int64
]
*
service
.
UserLoadInfo
{},
nil
}
func
(
f
*
fakeConcurrencyCache
)
GetAccountConcurrencyBatch
(
_
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
result
:=
make
(
map
[
int64
]
int
,
len
(
accountIDs
))
for
_
,
id
:=
range
accountIDs
{
result
[
id
]
=
0
}
return
result
,
nil
}
func
(
f
*
fakeConcurrencyCache
)
CleanupExpiredAccountSlots
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
newTestGatewayHandler
(
t
*
testing
.
T
,
group
*
service
.
Group
,
accounts
[]
*
service
.
Account
)
(
*
GatewayHandler
,
func
())
{
t
.
Helper
()
schedulerCache
:=
&
fakeSchedulerCache
{
accounts
:
accounts
}
schedulerSnapshot
:=
service
.
NewSchedulerSnapshotService
(
schedulerCache
,
nil
,
nil
,
nil
,
nil
)
gwSvc
:=
service
.
NewGatewayService
(
nil
,
// accountRepo (not used: scheduler snapshot hit)
&
fakeGroupRepo
{
group
:
group
},
nil
,
// usageLogRepo
nil
,
// userRepo
nil
,
// userSubRepo
nil
,
// userGroupRateRepo
nil
,
// cache (disable sticky)
nil
,
// cfg
schedulerSnapshot
,
nil
,
// concurrencyService (disable load-aware; tryAcquire always acquired)
nil
,
// billingService
nil
,
// rateLimitService
nil
,
// billingCacheService
nil
,
// identityService
nil
,
// httpUpstream
nil
,
// deferredService
nil
,
// claudeTokenProvider
nil
,
// sessionLimitCache
nil
,
// rpmCache
nil
,
// digestStore
)
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
billingCacheSvc
:=
service
.
NewBillingCacheService
(
nil
,
nil
,
nil
,
nil
,
cfg
)
concurrencySvc
:=
service
.
NewConcurrencyService
(
&
fakeConcurrencyCache
{})
concurrencyHelper
:=
NewConcurrencyHelper
(
concurrencySvc
,
SSEPingFormatClaude
,
0
)
h
:=
&
GatewayHandler
{
gatewayService
:
gwSvc
,
billingCacheService
:
billingCacheSvc
,
concurrencyHelper
:
concurrencyHelper
,
// 这些字段对本测试不敏感,保持较小即可
maxAccountSwitches
:
1
,
maxAccountSwitchesGemini
:
1
,
}
cleanup
:=
func
()
{
billingCacheSvc
.
Stop
()
}
return
h
,
cleanup
}
func
TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
groupID
:=
int64
(
2001
)
accountID
:=
int64
(
1001
)
group
:=
&
service
.
Group
{
ID
:
groupID
,
Hydrated
:
true
,
Platform
:
service
.
PlatformAnthropic
,
// /v1/messages(Claude兼容)入口
Status
:
service
.
StatusActive
,
}
account
:=
&
service
.
Account
{
ID
:
accountID
,
Name
:
"ag-1"
,
Platform
:
service
.
PlatformAntigravity
,
Type
:
service
.
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"tok_xxx"
,
"intercept_warmup_requests"
:
true
,
},
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
,
// 关键:允许被 anthropic 分组混合调度选中
},
Concurrency
:
1
,
Priority
:
1
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
AccountGroups
:
[]
service
.
AccountGroup
{{
AccountID
:
accountID
,
GroupID
:
groupID
}},
}
h
,
cleanup
:=
newTestGatewayHandler
(
t
,
group
,
[]
*
service
.
Account
{
account
})
defer
cleanup
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
body
:=
[]
byte
(
`{
"model": "claude-sonnet-4-5",
"max_tokens": 256,
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
}`
)
req
:=
httptest
.
NewRequest
(
"POST"
,
"/v1/messages"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
=
req
.
WithContext
(
context
.
WithValue
(
req
.
Context
(),
ctxkey
.
Group
,
group
))
c
.
Request
=
req
apiKey
:=
&
service
.
APIKey
{
ID
:
3001
,
UserID
:
4001
,
GroupID
:
&
groupID
,
Status
:
service
.
StatusActive
,
User
:
&
service
.
User
{
ID
:
4001
,
Concurrency
:
10
,
Balance
:
100
,
},
Group
:
group
,
}
c
.
Set
(
string
(
middleware
.
ContextKeyAPIKey
),
apiKey
)
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
UserID
:
apiKey
.
UserID
,
Concurrency
:
10
})
h
.
Messages
(
c
)
require
.
Equal
(
t
,
200
,
rec
.
Code
)
// 断言:确实选中了 antigravity 账号(不是纯函数测试,而是从 Handler 里验证调度结果)
selected
,
ok
:=
c
.
Get
(
opsAccountIDKey
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
accountID
,
selected
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
"msg_mock_warmup"
,
resp
[
"id"
])
require
.
Equal
(
t
,
"claude-sonnet-4-5"
,
resp
[
"model"
])
content
,
ok
:=
resp
[
"content"
]
.
([]
any
)
require
.
True
(
t
,
ok
)
require
.
Len
(
t
,
content
,
1
)
first
,
ok
:=
content
[
0
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"New Conversation"
,
first
[
"text"
])
}
func
TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
groupID
:=
int64
(
2002
)
accountID
:=
int64
(
1002
)
group
:=
&
service
.
Group
{
ID
:
groupID
,
Hydrated
:
true
,
Platform
:
service
.
PlatformAntigravity
,
Status
:
service
.
StatusActive
,
}
account
:=
&
service
.
Account
{
ID
:
accountID
,
Name
:
"ag-2"
,
Platform
:
service
.
PlatformAntigravity
,
Type
:
service
.
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"tok_xxx"
,
"intercept_warmup_requests"
:
true
,
},
Concurrency
:
1
,
Priority
:
1
,
Status
:
service
.
StatusActive
,
Schedulable
:
true
,
AccountGroups
:
[]
service
.
AccountGroup
{{
AccountID
:
accountID
,
GroupID
:
groupID
}},
}
h
,
cleanup
:=
newTestGatewayHandler
(
t
,
group
,
[]
*
service
.
Account
{
account
})
defer
cleanup
()
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
body
:=
[]
byte
(
`{
"model": "claude-sonnet-4-5",
"max_tokens": 256,
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
}`
)
req
:=
httptest
.
NewRequest
(
"POST"
,
"/antigravity/v1/messages"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
// 模拟 routes/gateway.go 里的 ForcePlatform 中间件效果:
// - 写入 request.Context(Service读取)
// - 写入 gin.Context(Handler快速读取)
ctx
:=
context
.
WithValue
(
req
.
Context
(),
ctxkey
.
Group
,
group
)
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
ForcePlatform
,
service
.
PlatformAntigravity
)
req
=
req
.
WithContext
(
ctx
)
c
.
Request
=
req
c
.
Set
(
string
(
middleware
.
ContextKeyForcePlatform
),
service
.
PlatformAntigravity
)
apiKey
:=
&
service
.
APIKey
{
ID
:
3002
,
UserID
:
4002
,
GroupID
:
&
groupID
,
Status
:
service
.
StatusActive
,
User
:
&
service
.
User
{
ID
:
4002
,
Concurrency
:
10
,
Balance
:
100
,
},
Group
:
group
,
}
c
.
Set
(
string
(
middleware
.
ContextKeyAPIKey
),
apiKey
)
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
UserID
:
apiKey
.
UserID
,
Concurrency
:
10
})
h
.
Messages
(
c
)
require
.
Equal
(
t
,
200
,
rec
.
Code
)
selected
,
ok
:=
c
.
Get
(
opsAccountIDKey
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
accountID
,
selected
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
"msg_mock_warmup"
,
resp
[
"id"
])
require
.
Equal
(
t
,
"claude-sonnet-4-5"
,
resp
[
"model"
])
}
backend/internal/handler/gateway_helper.go
View file @
3d79773b
...
...
@@ -4,8 +4,9 @@ import (
"context"
"encoding/json"
"fmt"
"math/rand"
"math/rand
/v2
"
"net/http"
"strings"
"sync"
"time"
...
...
@@ -17,23 +18,91 @@ import (
// claudeCodeValidator is a singleton validator for Claude Code client detection
var
claudeCodeValidator
=
service
.
NewClaudeCodeValidator
()
const
claudeCodeParsedRequestContextKey
=
"claude_code_parsed_request"
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
// 返回更新后的 context
func
SetClaudeCodeClientContext
(
c
*
gin
.
Context
,
body
[]
byte
)
{
// 解析请求体为 map
var
bodyMap
map
[
string
]
any
if
len
(
body
)
>
0
{
_
=
json
.
Unmarshal
(
body
,
&
bodyMap
)
func
SetClaudeCodeClientContext
(
c
*
gin
.
Context
,
body
[]
byte
,
parsedReq
*
service
.
ParsedRequest
)
{
if
c
==
nil
||
c
.
Request
==
nil
{
return
}
if
parsedReq
!=
nil
{
c
.
Set
(
claudeCodeParsedRequestContextKey
,
parsedReq
)
}
ua
:=
c
.
GetHeader
(
"User-Agent"
)
// Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。
if
!
claudeCodeValidator
.
ValidateUserAgent
(
ua
)
{
ctx
:=
service
.
SetClaudeCodeClient
(
c
.
Request
.
Context
(),
false
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
return
}
// 验证是否为 Claude Code 客户端
isClaudeCode
:=
claudeCodeValidator
.
Validate
(
c
.
Request
,
bodyMap
)
isClaudeCode
:=
false
if
!
strings
.
Contains
(
c
.
Request
.
URL
.
Path
,
"messages"
)
{
// 与 Validate 行为一致:非 messages 路径 UA 命中即可视为 Claude Code 客户端。
isClaudeCode
=
true
}
else
{
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
bodyMap
:=
claudeCodeBodyMapFromParsedRequest
(
parsedReq
)
if
bodyMap
==
nil
{
bodyMap
=
claudeCodeBodyMapFromContextCache
(
c
)
}
if
bodyMap
==
nil
&&
len
(
body
)
>
0
{
_
=
json
.
Unmarshal
(
body
,
&
bodyMap
)
}
isClaudeCode
=
claudeCodeValidator
.
Validate
(
c
.
Request
,
bodyMap
)
}
// 更新 request context
ctx
:=
service
.
SetClaudeCodeClient
(
c
.
Request
.
Context
(),
isClaudeCode
)
// 仅在确认为 Claude Code 客户端时提取版本号写入 context
if
isClaudeCode
{
if
version
:=
claudeCodeValidator
.
ExtractVersion
(
ua
);
version
!=
""
{
ctx
=
service
.
SetClaudeCodeVersion
(
ctx
,
version
)
}
}
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
func
claudeCodeBodyMapFromParsedRequest
(
parsedReq
*
service
.
ParsedRequest
)
map
[
string
]
any
{
if
parsedReq
==
nil
{
return
nil
}
bodyMap
:=
map
[
string
]
any
{
"model"
:
parsedReq
.
Model
,
}
if
parsedReq
.
System
!=
nil
||
parsedReq
.
HasSystem
{
bodyMap
[
"system"
]
=
parsedReq
.
System
}
if
parsedReq
.
MetadataUserID
!=
""
{
bodyMap
[
"metadata"
]
=
map
[
string
]
any
{
"user_id"
:
parsedReq
.
MetadataUserID
}
}
return
bodyMap
}
func
claudeCodeBodyMapFromContextCache
(
c
*
gin
.
Context
)
map
[
string
]
any
{
if
c
==
nil
{
return
nil
}
if
cached
,
ok
:=
c
.
Get
(
service
.
OpenAIParsedRequestBodyKey
);
ok
{
if
bodyMap
,
ok
:=
cached
.
(
map
[
string
]
any
);
ok
{
return
bodyMap
}
}
if
cached
,
ok
:=
c
.
Get
(
claudeCodeParsedRequestContextKey
);
ok
{
switch
v
:=
cached
.
(
type
)
{
case
*
service
.
ParsedRequest
:
return
claudeCodeBodyMapFromParsedRequest
(
v
)
case
service
.
ParsedRequest
:
return
claudeCodeBodyMapFromParsedRequest
(
&
v
)
}
}
return
nil
}
// 并发槽位等待相关常量
//
// 性能优化说明:
...
...
@@ -104,31 +173,24 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
//
修复:添加 quit channel 确保 goroutine 及时退出,避免泄露
//
优化:基于 context.AfterFunc 注册回调,避免每请求额外守护 goroutine。
func
wrapReleaseOnDone
(
ctx
context
.
Context
,
releaseFunc
func
())
func
()
{
if
releaseFunc
==
nil
{
return
nil
}
var
once
sync
.
Once
quit
:=
make
(
chan
struct
{})
var
stop
func
()
bool
release
:=
func
()
{
once
.
Do
(
func
()
{
if
stop
!=
nil
{
_
=
stop
()
}
releaseFunc
()
close
(
quit
)
// 通知监听 goroutine 退出
})
}
go
func
()
{
select
{
case
<-
ctx
.
Done
()
:
// Context 取消时释放资源
release
()
case
<-
quit
:
// 正常释放已完成,goroutine 退出
return
}
}()
stop
=
context
.
AfterFunc
(
ctx
,
release
)
return
release
}
...
...
@@ -153,6 +215,32 @@ func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accou
h
.
concurrencyService
.
DecrementAccountWaitCount
(
ctx
,
accountID
)
}
// TryAcquireUserSlot 尝试立即获取用户并发槽位。
// 返回值: (releaseFunc, acquired, error)
func
(
h
*
ConcurrencyHelper
)
TryAcquireUserSlot
(
ctx
context
.
Context
,
userID
int64
,
maxConcurrency
int
)
(
func
(),
bool
,
error
)
{
result
,
err
:=
h
.
concurrencyService
.
AcquireUserSlot
(
ctx
,
userID
,
maxConcurrency
)
if
err
!=
nil
{
return
nil
,
false
,
err
}
if
!
result
.
Acquired
{
return
nil
,
false
,
nil
}
return
result
.
ReleaseFunc
,
true
,
nil
}
// TryAcquireAccountSlot 尝试立即获取账号并发槽位。
// 返回值: (releaseFunc, acquired, error)
func
(
h
*
ConcurrencyHelper
)
TryAcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
)
(
func
(),
bool
,
error
)
{
result
,
err
:=
h
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
accountID
,
maxConcurrency
)
if
err
!=
nil
{
return
nil
,
false
,
err
}
if
!
result
.
Acquired
{
return
nil
,
false
,
nil
}
return
result
.
ReleaseFunc
,
true
,
nil
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
...
...
@@ -160,13 +248,13 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64
ctx
:=
c
.
Request
.
Context
()
// Try to acquire immediately
re
sult
,
err
:=
h
.
concurrencyService
.
AcquireUserSlot
(
ctx
,
userID
,
maxConcurrency
)
re
leaseFunc
,
acquired
,
err
:=
h
.
Try
AcquireUserSlot
(
ctx
,
userID
,
maxConcurrency
)
if
err
!=
nil
{
return
nil
,
err
}
if
result
.
A
cquired
{
return
re
sult
.
Re
leaseFunc
,
nil
if
a
cquired
{
return
releaseFunc
,
nil
}
// Need to wait - handle streaming ping if needed
...
...
@@ -180,13 +268,13 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
ctx
:=
c
.
Request
.
Context
()
// Try to acquire immediately
re
sult
,
err
:=
h
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
accountID
,
maxConcurrency
)
re
leaseFunc
,
acquired
,
err
:=
h
.
Try
AcquireAccountSlot
(
ctx
,
accountID
,
maxConcurrency
)
if
err
!=
nil
{
return
nil
,
err
}
if
result
.
A
cquired
{
return
re
sult
.
Re
leaseFunc
,
nil
if
a
cquired
{
return
releaseFunc
,
nil
}
// Need to wait - handle streaming ping if needed
...
...
@@ -196,27 +284,29 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func
(
h
*
ConcurrencyHelper
)
waitForSlotWithPing
(
c
*
gin
.
Context
,
slotType
string
,
id
int64
,
maxConcurrency
int
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
return
h
.
waitForSlotWithPingTimeout
(
c
,
slotType
,
id
,
maxConcurrency
,
maxConcurrencyWait
,
isStream
,
streamStarted
)
return
h
.
waitForSlotWithPingTimeout
(
c
,
slotType
,
id
,
maxConcurrency
,
maxConcurrencyWait
,
isStream
,
streamStarted
,
false
)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func
(
h
*
ConcurrencyHelper
)
waitForSlotWithPingTimeout
(
c
*
gin
.
Context
,
slotType
string
,
id
int64
,
maxConcurrency
int
,
timeout
time
.
Duration
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
func
(
h
*
ConcurrencyHelper
)
waitForSlotWithPingTimeout
(
c
*
gin
.
Context
,
slotType
string
,
id
int64
,
maxConcurrency
int
,
timeout
time
.
Duration
,
isStream
bool
,
streamStarted
*
bool
,
tryImmediate
bool
)
(
func
(),
error
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
c
.
Request
.
Context
(),
timeout
)
defer
cancel
()
// Try immediate acquire first (avoid unnecessary wait)
var
result
*
service
.
AcquireResult
var
err
error
if
slotType
==
"user"
{
result
,
err
=
h
.
concurrencyService
.
AcquireUserSlot
(
ctx
,
id
,
maxConcurrency
)
}
else
{
result
,
err
=
h
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
id
,
maxConcurrency
)
}
if
err
!=
nil
{
return
nil
,
err
acquireSlot
:=
func
()
(
*
service
.
AcquireResult
,
error
)
{
if
slotType
==
"user"
{
return
h
.
concurrencyService
.
AcquireUserSlot
(
ctx
,
id
,
maxConcurrency
)
}
return
h
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
id
,
maxConcurrency
)
}
if
result
.
Acquired
{
return
result
.
ReleaseFunc
,
nil
if
tryImmediate
{
result
,
err
:=
acquireSlot
()
if
err
!=
nil
{
return
nil
,
err
}
if
result
.
Acquired
{
return
result
.
ReleaseFunc
,
nil
}
}
// Determine if ping is needed (streaming + ping format defined)
...
...
@@ -242,7 +332,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
backoff
:=
initialBackoff
timer
:=
time
.
NewTimer
(
backoff
)
defer
timer
.
Stop
()
rng
:=
rand
.
New
(
rand
.
NewSource
(
time
.
Now
()
.
UnixNano
()))
for
{
select
{
...
...
@@ -268,15 +357,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
case
<-
timer
.
C
:
// Try to acquire slot
var
result
*
service
.
AcquireResult
var
err
error
if
slotType
==
"user"
{
result
,
err
=
h
.
concurrencyService
.
AcquireUserSlot
(
ctx
,
id
,
maxConcurrency
)
}
else
{
result
,
err
=
h
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
id
,
maxConcurrency
)
}
result
,
err
:=
acquireSlot
()
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -284,7 +365,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
if
result
.
Acquired
{
return
result
.
ReleaseFunc
,
nil
}
backoff
=
nextBackoff
(
backoff
,
rng
)
backoff
=
nextBackoff
(
backoff
)
timer
.
Reset
(
backoff
)
}
}
...
...
@@ -292,26 +373,22 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func
(
h
*
ConcurrencyHelper
)
AcquireAccountSlotWithWaitTimeout
(
c
*
gin
.
Context
,
accountID
int64
,
maxConcurrency
int
,
timeout
time
.
Duration
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
return
h
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
accountID
,
maxConcurrency
,
timeout
,
isStream
,
streamStarted
)
return
h
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
accountID
,
maxConcurrency
,
timeout
,
isStream
,
streamStarted
,
true
)
}
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
// rng: 随机数生成器(可为 nil,此时不添加抖动)
// 返回值:下一次退避时间(100ms ~ 2s 之间)
func
nextBackoff
(
current
time
.
Duration
,
rng
*
rand
.
Rand
)
time
.
Duration
{
func
nextBackoff
(
current
time
.
Duration
)
time
.
Duration
{
// 指数退避:当前时间 * 1.5
next
:=
time
.
Duration
(
float64
(
current
)
*
backoffMultiplier
)
if
next
>
maxBackoff
{
next
=
maxBackoff
}
if
rng
==
nil
{
return
next
}
// 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2)
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
jitter
:=
0.8
+
r
ng
.
Float64
()
*
0.4
jitter
:=
0.8
+
r
and
.
Float64
()
*
0.4
jittered
:=
time
.
Duration
(
float64
(
next
)
*
jitter
)
if
jittered
<
initialBackoff
{
return
initialBackoff
...
...
backend/internal/handler/gateway_helper_backoff_test.go
0 → 100644
View file @
3d79773b
package
handler
import
(
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Task 6.2 验证: math/rand/v2 迁移后 nextBackoff 行为正确 ---
func
TestNextBackoff_ExponentialGrowth
(
t
*
testing
.
T
)
{
// 验证退避时间指数增长(乘数 1.5)
// 由于有随机抖动(±20%),需要验证范围
current
:=
initialBackoff
// 100ms
for
i
:=
0
;
i
<
10
;
i
++
{
next
:=
nextBackoff
(
current
)
// 退避结果应在 [initialBackoff, maxBackoff] 范围内
assert
.
GreaterOrEqual
(
t
,
int64
(
next
),
int64
(
initialBackoff
),
"第 %d 次退避不应低于初始值 %v"
,
i
,
initialBackoff
)
assert
.
LessOrEqual
(
t
,
int64
(
next
),
int64
(
maxBackoff
),
"第 %d 次退避不应超过最大值 %v"
,
i
,
maxBackoff
)
// 为下一轮提供当前退避值
current
=
next
}
}
func
TestNextBackoff_BoundedByMaxBackoff
(
t
*
testing
.
T
)
{
// 即使输入非常大,输出也不超过 maxBackoff
for
i
:=
0
;
i
<
100
;
i
++
{
result
:=
nextBackoff
(
10
*
time
.
Second
)
assert
.
LessOrEqual
(
t
,
int64
(
result
),
int64
(
maxBackoff
),
"退避值不应超过 maxBackoff"
)
}
}
func
TestNextBackoff_BoundedByInitialBackoff
(
t
*
testing
.
T
)
{
// 即使输入非常小,输出也不低于 initialBackoff
for
i
:=
0
;
i
<
100
;
i
++
{
result
:=
nextBackoff
(
1
*
time
.
Millisecond
)
assert
.
GreaterOrEqual
(
t
,
int64
(
result
),
int64
(
initialBackoff
),
"退避值不应低于 initialBackoff"
)
}
}
func
TestNextBackoff_HasJitter
(
t
*
testing
.
T
)
{
// 验证多次调用会产生不同的值(随机抖动生效)
// 使用相同的输入调用 50 次,收集结果
results
:=
make
(
map
[
time
.
Duration
]
bool
)
current
:=
500
*
time
.
Millisecond
for
i
:=
0
;
i
<
50
;
i
++
{
result
:=
nextBackoff
(
current
)
results
[
result
]
=
true
}
// 50 次调用应该至少有 2 个不同的值(抖动存在)
require
.
Greater
(
t
,
len
(
results
),
1
,
"nextBackoff 应产生随机抖动,但所有 50 次调用结果相同"
)
}
func
TestNextBackoff_InitialValueGrows
(
t
*
testing
.
T
)
{
// 验证从初始值开始,退避趋势是增长的
current
:=
initialBackoff
var
sum
time
.
Duration
runs
:=
100
for
i
:=
0
;
i
<
runs
;
i
++
{
next
:=
nextBackoff
(
current
)
sum
+=
next
current
=
next
}
avg
:=
sum
/
time
.
Duration
(
runs
)
// 平均退避时间应大于初始值(因为指数增长 + 上限)
assert
.
Greater
(
t
,
int64
(
avg
),
int64
(
initialBackoff
),
"平均退避时间应大于初始退避值"
)
}
func
TestNextBackoff_ConvergesToMaxBackoff
(
t
*
testing
.
T
)
{
// 从初始值开始,经过多次退避后应收敛到 maxBackoff 附近
current
:=
initialBackoff
for
i
:=
0
;
i
<
20
;
i
++
{
current
=
nextBackoff
(
current
)
}
// 经过 20 次迭代后,应该已经到达 maxBackoff 区间
// 由于抖动,允许 ±20% 的范围
lowerBound
:=
time
.
Duration
(
float64
(
maxBackoff
)
*
0.8
)
assert
.
GreaterOrEqual
(
t
,
int64
(
current
),
int64
(
lowerBound
),
"经过多次退避后应收敛到 maxBackoff 附近"
)
}
func
BenchmarkNextBackoff
(
b
*
testing
.
B
)
{
current
:=
initialBackoff
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
current
=
nextBackoff
(
current
)
if
current
>
maxBackoff
{
current
=
initialBackoff
}
}
}
backend/internal/handler/gateway_helper_fastpath_test.go
0 → 100644
View file @
3d79773b
package
handler
import
(
"context"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
type
concurrencyCacheMock
struct
{
acquireUserSlotFn
func
(
ctx
context
.
Context
,
userID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
acquireAccountSlotFn
func
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
releaseUserCalled
int32
releaseAccountCalled
int32
}
func
(
m
*
concurrencyCacheMock
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
if
m
.
acquireAccountSlotFn
!=
nil
{
return
m
.
acquireAccountSlotFn
(
ctx
,
accountID
,
maxConcurrency
,
requestID
)
}
return
false
,
nil
}
func
(
m
*
concurrencyCacheMock
)
ReleaseAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
requestID
string
)
error
{
atomic
.
AddInt32
(
&
m
.
releaseAccountCalled
,
1
)
return
nil
}
func
(
m
*
concurrencyCacheMock
)
GetAccountConcurrency
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
m
*
concurrencyCacheMock
)
GetAccountConcurrencyBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
result
:=
make
(
map
[
int64
]
int
,
len
(
accountIDs
))
for
_
,
accountID
:=
range
accountIDs
{
result
[
accountID
]
=
0
}
return
result
,
nil
}
func
(
m
*
concurrencyCacheMock
)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
m
*
concurrencyCacheMock
)
DecrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
)
error
{
return
nil
}
func
(
m
*
concurrencyCacheMock
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
m
*
concurrencyCacheMock
)
AcquireUserSlot
(
ctx
context
.
Context
,
userID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
if
m
.
acquireUserSlotFn
!=
nil
{
return
m
.
acquireUserSlotFn
(
ctx
,
userID
,
maxConcurrency
,
requestID
)
}
return
false
,
nil
}
func
(
m
*
concurrencyCacheMock
)
ReleaseUserSlot
(
ctx
context
.
Context
,
userID
int64
,
requestID
string
)
error
{
atomic
.
AddInt32
(
&
m
.
releaseUserCalled
,
1
)
return
nil
}
func
(
m
*
concurrencyCacheMock
)
GetUserConcurrency
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
m
*
concurrencyCacheMock
)
IncrementWaitCount
(
ctx
context
.
Context
,
userID
int64
,
maxWait
int
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
m
*
concurrencyCacheMock
)
DecrementWaitCount
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
nil
}
func
(
m
*
concurrencyCacheMock
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
service
.
AccountWithConcurrency
)
(
map
[
int64
]
*
service
.
AccountLoadInfo
,
error
)
{
return
map
[
int64
]
*
service
.
AccountLoadInfo
{},
nil
}
func
(
m
*
concurrencyCacheMock
)
GetUsersLoadBatch
(
ctx
context
.
Context
,
users
[]
service
.
UserWithConcurrency
)
(
map
[
int64
]
*
service
.
UserLoadInfo
,
error
)
{
return
map
[
int64
]
*
service
.
UserLoadInfo
{},
nil
}
func
(
m
*
concurrencyCacheMock
)
CleanupExpiredAccountSlots
(
ctx
context
.
Context
,
accountID
int64
)
error
{
return
nil
}
func
TestConcurrencyHelper_TryAcquireUserSlot
(
t
*
testing
.
T
)
{
cache
:=
&
concurrencyCacheMock
{
acquireUserSlotFn
:
func
(
ctx
context
.
Context
,
userID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
return
true
,
nil
},
}
helper
:=
NewConcurrencyHelper
(
service
.
NewConcurrencyService
(
cache
),
SSEPingFormatNone
,
time
.
Second
)
release
,
acquired
,
err
:=
helper
.
TryAcquireUserSlot
(
context
.
Background
(),
101
,
2
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
acquired
)
require
.
NotNil
(
t
,
release
)
release
()
require
.
Equal
(
t
,
int32
(
1
),
atomic
.
LoadInt32
(
&
cache
.
releaseUserCalled
))
}
func
TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired
(
t
*
testing
.
T
)
{
cache
:=
&
concurrencyCacheMock
{
acquireAccountSlotFn
:
func
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
return
false
,
nil
},
}
helper
:=
NewConcurrencyHelper
(
service
.
NewConcurrencyService
(
cache
),
SSEPingFormatNone
,
time
.
Second
)
release
,
acquired
,
err
:=
helper
.
TryAcquireAccountSlot
(
context
.
Background
(),
201
,
1
)
require
.
NoError
(
t
,
err
)
require
.
False
(
t
,
acquired
)
require
.
Nil
(
t
,
release
)
require
.
Equal
(
t
,
int32
(
0
),
atomic
.
LoadInt32
(
&
cache
.
releaseAccountCalled
))
}
Prev
1
…
3
4
5
6
7
8
9
10
11
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment