Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
6cc7f997
Commit
6cc7f997
authored
Jan 02, 2026
by
song
Browse files
merge: 合并 upstream/main
parents
95d09f60
106e59b7
Changes
135
Show whitespace changes
Inline
Side-by-side
backend/internal/handler/admin/user_handler.go
View file @
6cc7f997
...
...
@@ -27,7 +27,6 @@ type CreateUserRequest struct {
Email
string
`json:"email" binding:"required,email"`
Password
string
`json:"password" binding:"required,min=6"`
Username
string
`json:"username"`
Wechat
string
`json:"wechat"`
Notes
string
`json:"notes"`
Balance
float64
`json:"balance"`
Concurrency
int
`json:"concurrency"`
...
...
@@ -40,7 +39,6 @@ type UpdateUserRequest struct {
Email
string
`json:"email" binding:"omitempty,email"`
Password
string
`json:"password" binding:"omitempty,min=6"`
Username
*
string
`json:"username"`
Wechat
*
string
`json:"wechat"`
Notes
*
string
`json:"notes"`
Balance
*
float64
`json:"balance"`
Concurrency
*
int
`json:"concurrency"`
...
...
@@ -57,13 +55,22 @@ type UpdateBalanceRequest struct {
// List handles listing all users with pagination
// GET /api/v1/admin/users
// Query params:
// - status: filter by user status
// - role: filter by user role
// - search: search in email, username
// - attr[{id}]: filter by custom attribute value, e.g. attr[1]=company
func
(
h
*
UserHandler
)
List
(
c
*
gin
.
Context
)
{
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
status
:=
c
.
Query
(
"status"
)
role
:=
c
.
Query
(
"role"
)
search
:=
c
.
Query
(
"search"
)
users
,
total
,
err
:=
h
.
adminService
.
ListUsers
(
c
.
Request
.
Context
(),
page
,
pageSize
,
status
,
role
,
search
)
filters
:=
service
.
UserListFilters
{
Status
:
c
.
Query
(
"status"
),
Role
:
c
.
Query
(
"role"
),
Search
:
c
.
Query
(
"search"
),
Attributes
:
parseAttributeFilters
(
c
),
}
users
,
total
,
err
:=
h
.
adminService
.
ListUsers
(
c
.
Request
.
Context
(),
page
,
pageSize
,
filters
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
@@ -76,6 +83,29 @@ func (h *UserHandler) List(c *gin.Context) {
response
.
Paginated
(
c
,
out
,
total
,
page
,
pageSize
)
}
// parseAttributeFilters extracts attribute filters from query params
// Format: attr[{attributeID}]=value, e.g. attr[1]=company&attr[2]=developer
func
parseAttributeFilters
(
c
*
gin
.
Context
)
map
[
int64
]
string
{
result
:=
make
(
map
[
int64
]
string
)
// Get all query params and look for attr[*] pattern
for
key
,
values
:=
range
c
.
Request
.
URL
.
Query
()
{
if
len
(
values
)
==
0
||
values
[
0
]
==
""
{
continue
}
// Check if key matches pattern attr[{id}]
if
len
(
key
)
>
5
&&
key
[
:
5
]
==
"attr["
&&
key
[
len
(
key
)
-
1
]
==
']'
{
idStr
:=
key
[
5
:
len
(
key
)
-
1
]
id
,
err
:=
strconv
.
ParseInt
(
idStr
,
10
,
64
)
if
err
==
nil
&&
id
>
0
{
result
[
id
]
=
values
[
0
]
}
}
}
return
result
}
// GetByID handles getting a user by ID
// GET /api/v1/admin/users/:id
func
(
h
*
UserHandler
)
GetByID
(
c
*
gin
.
Context
)
{
...
...
@@ -107,7 +137,6 @@ func (h *UserHandler) Create(c *gin.Context) {
Email
:
req
.
Email
,
Password
:
req
.
Password
,
Username
:
req
.
Username
,
Wechat
:
req
.
Wechat
,
Notes
:
req
.
Notes
,
Balance
:
req
.
Balance
,
Concurrency
:
req
.
Concurrency
,
...
...
@@ -141,7 +170,6 @@ func (h *UserHandler) Update(c *gin.Context) {
Email
:
req
.
Email
,
Password
:
req
.
Password
,
Username
:
req
.
Username
,
Wechat
:
req
.
Wechat
,
Notes
:
req
.
Notes
,
Balance
:
req
.
Balance
,
Concurrency
:
req
.
Concurrency
,
...
...
backend/internal/handler/dto/mappers.go
View file @
6cc7f997
...
...
@@ -10,7 +10,6 @@ func UserFromServiceShallow(u *service.User) *User {
ID
:
u
.
ID
,
Email
:
u
.
Email
,
Username
:
u
.
Username
,
Wechat
:
u
.
Wechat
,
Notes
:
u
.
Notes
,
Role
:
u
.
Role
,
Balance
:
u
.
Balance
,
...
...
backend/internal/handler/dto/types.go
View file @
6cc7f997
...
...
@@ -6,7 +6,6 @@ type User struct {
ID
int64
`json:"id"`
Email
string
`json:"email"`
Username
string
`json:"username"`
Wechat
string
`json:"wechat"`
Notes
string
`json:"notes"`
Role
string
`json:"role"`
Balance
float64
`json:"balance"`
...
...
backend/internal/handler/gateway_handler.go
View file @
6cc7f997
...
...
@@ -142,6 +142,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
else
if
apiKey
.
Group
!=
nil
{
platform
=
apiKey
.
Group
.
Platform
}
sessionKey
:=
sessionHash
if
platform
==
service
.
PlatformGemini
&&
sessionHash
!=
""
{
sessionKey
=
"gemini:"
+
sessionHash
}
if
platform
==
service
.
PlatformGemini
{
const
maxAccountSwitches
=
3
...
...
@@ -150,7 +154,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus
:=
0
for
{
account
,
err
:=
h
.
g
eminiCompat
Service
.
SelectAccount
ForModelWithExclusion
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Hash
,
reqModel
,
failedAccountIDs
)
selection
,
err
:=
h
.
g
ateway
Service
.
SelectAccount
WithLoadAwarenes
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Key
,
reqModel
,
failedAccountIDs
)
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
@@ -159,9 +163,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
return
}
account
:=
selection
.
Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if
account
.
IsInterceptWarmupEnabled
()
&&
isWarmupRequest
(
body
)
{
if
selection
.
Acquired
&&
selection
.
ReleaseFunc
!=
nil
{
selection
.
ReleaseFunc
()
}
if
reqStream
{
sendMockWarmupStream
(
c
,
reqModel
)
}
else
{
...
...
@@ -171,12 +179,47 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 3. 获取账号并发槽位
accountReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWait
(
c
,
account
.
ID
,
account
.
Concurrency
,
reqStream
,
&
streamStarted
)
accountReleaseFunc
:=
selection
.
ReleaseFunc
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts"
,
streamStarted
)
return
}
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
}
else
if
!
canWait
{
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
account
.
ID
,
selection
.
WaitPlan
.
MaxConcurrency
,
selection
.
WaitPlan
.
Timeout
,
reqStream
,
&
streamStarted
,
)
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
...
...
@@ -188,6 +231,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
...
...
@@ -232,7 +278,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for
{
// 选择支持该模型的账号
account
,
err
:=
h
.
gatewayService
.
SelectAccount
ForModelWithExclusion
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Hash
,
reqModel
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccount
WithLoadAwarenes
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Key
,
reqModel
,
failedAccountIDs
)
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
@@ -241,9 +287,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
return
}
account
:=
selection
.
Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if
account
.
IsInterceptWarmupEnabled
()
&&
isWarmupRequest
(
body
)
{
if
selection
.
Acquired
&&
selection
.
ReleaseFunc
!=
nil
{
selection
.
ReleaseFunc
()
}
if
reqStream
{
sendMockWarmupStream
(
c
,
reqModel
)
}
else
{
...
...
@@ -253,12 +303,47 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 3. 获取账号并发槽位
accountReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWait
(
c
,
account
.
ID
,
account
.
Concurrency
,
reqStream
,
&
streamStarted
)
accountReleaseFunc
:=
selection
.
ReleaseFunc
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts"
,
streamStarted
)
return
}
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
}
else
if
!
canWait
{
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
account
.
ID
,
selection
.
WaitPlan
.
MaxConcurrency
,
selection
.
WaitPlan
.
Timeout
,
reqStream
,
&
streamStarted
,
)
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
...
...
@@ -270,6 +355,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
...
...
@@ -309,7 +397,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Models handles listing available models
// GET /v1/models
// Returns different model lists based on the API key's group platform or forced platform
// Returns models based on account configurations (model_mapping whitelist)
// Falls back to default models if no whitelist is configured
func
(
h
*
GatewayHandler
)
Models
(
c
*
gin
.
Context
)
{
apiKey
,
_
:=
middleware2
.
GetApiKeyFromContext
(
c
)
...
...
@@ -324,8 +413,37 @@ func (h *GatewayHandler) Models(c *gin.Context) {
}
}
// Return OpenAI models for OpenAI platform groups
if
apiKey
!=
nil
&&
apiKey
.
Group
!=
nil
&&
apiKey
.
Group
.
Platform
==
"openai"
{
var
groupID
*
int64
var
platform
string
if
apiKey
!=
nil
&&
apiKey
.
Group
!=
nil
{
groupID
=
&
apiKey
.
Group
.
ID
platform
=
apiKey
.
Group
.
Platform
}
// Get available models from account configurations (without platform filter)
availableModels
:=
h
.
gatewayService
.
GetAvailableModels
(
c
.
Request
.
Context
(),
groupID
,
""
)
if
len
(
availableModels
)
>
0
{
// Build model list from whitelist
models
:=
make
([]
claude
.
Model
,
0
,
len
(
availableModels
))
for
_
,
modelID
:=
range
availableModels
{
models
=
append
(
models
,
claude
.
Model
{
ID
:
modelID
,
Type
:
"model"
,
DisplayName
:
modelID
,
CreatedAt
:
"2024-01-01T00:00:00Z"
,
})
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"object"
:
"list"
,
"data"
:
models
,
})
return
}
// Fallback to default models
if
platform
==
"openai"
{
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"object"
:
"list"
,
"data"
:
openai
.
DefaultModels
,
...
...
@@ -333,7 +451,6 @@ func (h *GatewayHandler) Models(c *gin.Context) {
return
}
// Default: Claude models
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"object"
:
"list"
,
"data"
:
claude
.
DefaultModels
,
...
...
backend/internal/handler/gateway_helper.go
View file @
6cc7f997
...
...
@@ -83,6 +83,16 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
h
.
concurrencyService
.
DecrementWaitCount
(
ctx
,
userID
)
}
// IncrementAccountWaitCount increments the wait count for an account
func
(
h
*
ConcurrencyHelper
)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
{
return
h
.
concurrencyService
.
IncrementAccountWaitCount
(
ctx
,
accountID
,
maxWait
)
}
// DecrementAccountWaitCount decrements the wait count for an account
func
(
h
*
ConcurrencyHelper
)
DecrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
)
{
h
.
concurrencyService
.
DecrementAccountWaitCount
(
ctx
,
accountID
)
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
...
...
@@ -126,7 +136,12 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func
(
h
*
ConcurrencyHelper
)
waitForSlotWithPing
(
c
*
gin
.
Context
,
slotType
string
,
id
int64
,
maxConcurrency
int
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
c
.
Request
.
Context
(),
maxConcurrencyWait
)
return
h
.
waitForSlotWithPingTimeout
(
c
,
slotType
,
id
,
maxConcurrency
,
maxConcurrencyWait
,
isStream
,
streamStarted
)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func
(
h
*
ConcurrencyHelper
)
waitForSlotWithPingTimeout
(
c
*
gin
.
Context
,
slotType
string
,
id
int64
,
maxConcurrency
int
,
timeout
time
.
Duration
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
c
.
Request
.
Context
(),
timeout
)
defer
cancel
()
// Determine if ping is needed (streaming + ping format defined)
...
...
@@ -200,6 +215,11 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
}
}
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func
(
h
*
ConcurrencyHelper
)
AcquireAccountSlotWithWaitTimeout
(
c
*
gin
.
Context
,
accountID
int64
,
maxConcurrency
int
,
timeout
time
.
Duration
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
return
h
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
accountID
,
maxConcurrency
,
timeout
,
isStream
,
streamStarted
)
}
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
6cc7f997
...
...
@@ -198,13 +198,17 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body)
parsedReq
,
_
:=
service
.
ParseGatewayRequest
(
body
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
sessionKey
:=
sessionHash
if
sessionHash
!=
""
{
sessionKey
=
"gemini:"
+
sessionHash
}
const
maxAccountSwitches
=
3
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
lastFailoverStatus
:=
0
for
{
account
,
err
:=
h
.
g
eminiCompat
Service
.
SelectAccount
ForModelWithExclusion
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Hash
,
modelName
,
failedAccountIDs
)
selection
,
err
:=
h
.
g
ateway
Service
.
SelectAccount
WithLoadAwarenes
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Key
,
modelName
,
failedAccountIDs
)
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
...
...
@@ -213,13 +217,49 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
handleGeminiFailoverExhausted
(
c
,
lastFailoverStatus
)
return
}
account
:=
selection
.
Account
// 4) account concurrency slot
accountReleaseFunc
,
err
:=
geminiConcurrency
.
AcquireAccountSlotWithWait
(
c
,
account
.
ID
,
account
.
Concurrency
,
stream
,
&
streamStarted
)
accountReleaseFunc
:=
selection
.
ReleaseFunc
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts"
)
return
}
canWait
,
err
:=
geminiConcurrency
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
}
else
if
!
canWait
{
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
googleError
(
c
,
http
.
StatusTooManyRequests
,
"Too many pending requests, please retry later"
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
geminiConcurrency
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}
accountReleaseFunc
,
err
=
geminiConcurrency
.
AcquireAccountSlotWithWaitTimeout
(
c
,
account
.
ID
,
selection
.
WaitPlan
.
MaxConcurrency
,
selection
.
WaitPlan
.
Timeout
,
stream
,
&
streamStarted
,
)
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
googleError
(
c
,
http
.
StatusTooManyRequests
,
err
.
Error
())
return
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 5) forward (根据平台分流)
var
result
*
service
.
ForwardResult
...
...
@@ -231,6 +271,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
...
...
backend/internal/handler/handler.go
View file @
6cc7f997
...
...
@@ -20,6 +20,7 @@ type AdminHandlers struct {
System
*
admin
.
SystemHandler
Subscription
*
admin
.
SubscriptionHandler
Usage
*
admin
.
UsageHandler
UserAttribute
*
admin
.
UserAttributeHandler
}
// Handlers contains all HTTP handlers
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
6cc7f997
...
...
@@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for
{
// Select account supporting the requested model
log
.
Printf
(
"[OpenAI Handler] Selecting account: groupID=%v model=%s"
,
apiKey
.
GroupID
,
reqModel
)
account
,
err
:=
h
.
gatewayService
.
SelectAccount
ForModelWithExclusion
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccount
WithLoadAwarenes
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
failedAccountIDs
)
if
err
!=
nil
{
log
.
Printf
(
"[OpenAI Handler] SelectAccount failed: %v"
,
err
)
if
len
(
failedAccountIDs
)
==
0
{
...
...
@@ -156,21 +156,60 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
return
}
account
:=
selection
.
Account
log
.
Printf
(
"[OpenAI Handler] Selected account: id=%d name=%s"
,
account
.
ID
,
account
.
Name
)
// 3. Acquire account concurrency slot
accountReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWait
(
c
,
account
.
ID
,
account
.
Concurrency
,
reqStream
,
&
streamStarted
)
accountReleaseFunc
:=
selection
.
ReleaseFunc
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts"
,
streamStarted
)
return
}
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
}
else
if
!
canWait
{
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
account
.
ID
,
selection
.
WaitPlan
.
MaxConcurrency
,
selection
.
WaitPlan
.
Timeout
,
reqStream
,
&
streamStarted
,
)
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionHash
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// Forward request
result
,
err
:=
h
.
gatewayService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
)
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
...
...
backend/internal/handler/user_handler.go
View file @
6cc7f997
...
...
@@ -30,7 +30,6 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type
UpdateProfileRequest
struct
{
Username
*
string
`json:"username"`
Wechat
*
string
`json:"wechat"`
}
// GetProfile handles getting user profile
...
...
@@ -99,7 +98,6 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
svcReq
:=
service
.
UpdateProfileRequest
{
Username
:
req
.
Username
,
Wechat
:
req
.
Wechat
,
}
updatedUser
,
err
:=
h
.
userService
.
UpdateProfile
(
c
.
Request
.
Context
(),
subject
.
UserID
,
svcReq
)
if
err
!=
nil
{
...
...
backend/internal/handler/wire.go
View file @
6cc7f997
...
...
@@ -23,6 +23,7 @@ func ProvideAdminHandlers(
systemHandler
*
admin
.
SystemHandler
,
subscriptionHandler
*
admin
.
SubscriptionHandler
,
usageHandler
*
admin
.
UsageHandler
,
userAttributeHandler
*
admin
.
UserAttributeHandler
,
)
*
AdminHandlers
{
return
&
AdminHandlers
{
Dashboard
:
dashboardHandler
,
...
...
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
System
:
systemHandler
,
Subscription
:
subscriptionHandler
,
Usage
:
usageHandler
,
UserAttribute
:
userAttributeHandler
,
}
}
...
...
@@ -107,6 +109,7 @@ var ProviderSet = wire.NewSet(
ProvideSystemHandler
,
admin
.
NewSubscriptionHandler
,
admin
.
NewUsageHandler
,
admin
.
NewUserAttributeHandler
,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers
,
...
...
backend/internal/integration/e2e_gateway_test.go
View file @
6cc7f997
...
...
@@ -57,6 +57,7 @@ var geminiModels = []string{
"gemini-2.5-flash-lite"
,
"gemini-3-flash"
,
"gemini-3-pro-low"
,
"gemini-3-pro-high"
,
}
func
TestMain
(
m
*
testing
.
M
)
{
...
...
@@ -641,6 +642,37 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
t
.
Logf
(
"✅ thinking 模式工具调用测试通过, id=%v"
,
result
[
"id"
])
}
// TestClaudeMessagesWithGeminiModel 测试在 Claude 端点使用 Gemini 模型
// 验证:通过 /v1/messages 端点传入 gemini 模型名的场景(含前缀映射)
// 仅在 Antigravity 模式下运行(ENDPOINT_PREFIX="/antigravity")
func
TestClaudeMessagesWithGeminiModel
(
t
*
testing
.
T
)
{
if
endpointPrefix
!=
"/antigravity"
{
t
.
Skip
(
"仅在 Antigravity 模式下运行"
)
}
// 测试通过 Claude 端点调用 Gemini 模型
geminiViaClaude
:=
[]
string
{
"gemini-3-flash"
,
// 直接支持
"gemini-3-pro-low"
,
// 直接支持
"gemini-3-pro-high"
,
// 直接支持
"gemini-3-pro"
,
// 前缀映射 -> gemini-3-pro-high
"gemini-3-pro-preview"
,
// 前缀映射 -> gemini-3-pro-high
}
for
i
,
model
:=
range
geminiViaClaude
{
if
i
>
0
{
time
.
Sleep
(
testInterval
)
}
t
.
Run
(
model
+
"_通过Claude端点"
,
func
(
t
*
testing
.
T
)
{
testClaudeMessage
(
t
,
model
,
false
)
})
time
.
Sleep
(
testInterval
)
t
.
Run
(
model
+
"_通过Claude端点_流式"
,
func
(
t
*
testing
.
T
)
{
testClaudeMessage
(
t
,
model
,
true
)
})
}
}
// TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景
// 验证:Gemini 模型接受没有 signature 的 thinking block
func
TestClaudeMessagesWithNoSignature
(
t
*
testing
.
T
)
{
...
...
@@ -738,3 +770,30 @@ func testClaudeWithNoSignature(t *testing.T, model string) {
}
t
.
Logf
(
"✅ 无 signature thinking 处理测试通过, id=%v"
,
result
[
"id"
])
}
// TestGeminiEndpointWithClaudeModel 测试通过 Gemini 端点调用 Claude 模型
// 仅在 Antigravity 模式下运行(ENDPOINT_PREFIX="/antigravity")
func
TestGeminiEndpointWithClaudeModel
(
t
*
testing
.
T
)
{
if
endpointPrefix
!=
"/antigravity"
{
t
.
Skip
(
"仅在 Antigravity 模式下运行"
)
}
// 测试通过 Gemini 端点调用 Claude 模型
claudeViaGemini
:=
[]
string
{
"claude-sonnet-4-5"
,
"claude-opus-4-5-thinking"
,
}
for
i
,
model
:=
range
claudeViaGemini
{
if
i
>
0
{
time
.
Sleep
(
testInterval
)
}
t
.
Run
(
model
+
"_通过Gemini端点"
,
func
(
t
*
testing
.
T
)
{
testGeminiGenerate
(
t
,
model
,
false
)
})
time
.
Sleep
(
testInterval
)
t
.
Run
(
model
+
"_通过Gemini端点_流式"
,
func
(
t
*
testing
.
T
)
{
testGeminiGenerate
(
t
,
model
,
true
)
})
}
}
backend/internal/pkg/antigravity/claude_types.go
View file @
6cc7f997
...
...
@@ -54,6 +54,9 @@ type CustomToolSpec struct {
InputSchema
map
[
string
]
any
`json:"input_schema"`
}
// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格)
type
ClaudeCustomToolSpec
=
CustomToolSpec
// SystemBlock system prompt 数组形式的元素
type
SystemBlock
struct
{
Type
string
`json:"type"`
...
...
backend/internal/pkg/antigravity/request_transformer.go
View file @
6cc7f997
...
...
@@ -14,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
// 用于存储 tool_use id -> name 映射
toolIDToName
:=
make
(
map
[
string
]
string
)
// 检测是否启用 thinking
isThinkingEnabled
:=
claudeReq
.
Thinking
!=
nil
&&
claudeReq
.
Thinking
.
Type
==
"enabled"
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought
:=
strings
.
HasPrefix
(
mappedModel
,
"gemini-"
)
// 检测是否启用 thinking
requestedThinkingEnabled
:=
claudeReq
.
Thinking
!=
nil
&&
claudeReq
.
Thinking
.
Type
==
"enabled"
// 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等),
// 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。
isThinkingEnabled
:=
requestedThinkingEnabled
&&
allowDummyThought
// 1. 构建 contents
contents
,
err
:=
buildContents
(
claudeReq
.
Messages
,
toolIDToName
,
isThinkingEnabled
,
allowDummyThought
)
if
err
!=
nil
{
...
...
@@ -31,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
systemInstruction
:=
buildSystemInstruction
(
claudeReq
.
System
,
claudeReq
.
Model
)
// 3. 构建 generationConfig
generationConfig
:=
buildGenerationConfig
(
claudeReq
)
reqForGen
:=
claudeReq
if
requestedThinkingEnabled
&&
!
allowDummyThought
{
log
.
Printf
(
"[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s"
,
mappedModel
)
// shallow copy to avoid mutating caller's request
clone
:=
*
claudeReq
clone
.
Thinking
=
nil
reqForGen
=
&
clone
}
generationConfig
:=
buildGenerationConfig
(
reqForGen
)
// 4. 构建 tools
tools
:=
buildTools
(
claudeReq
.
Tools
)
...
...
@@ -150,6 +161,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
parts
=
append
([]
GeminiPart
{{
Text
:
"Thinking..."
,
Thought
:
true
,
ThoughtSignature
:
dummyThoughtSignature
,
}},
parts
...
)
}
}
...
...
@@ -171,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const
dummyThoughtSignature
=
"skip_thought_signature_validator"
// isValidThoughtSignature 验证 thought signature 是否有效
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
func
isValidThoughtSignature
(
signature
string
)
bool
{
// 空字符串无效
if
signature
==
""
{
return
false
}
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
// 参考 Claude API 文档和实际观察到的有效 signature
if
len
(
signature
)
<
40
{
log
.
Printf
(
"[Debug] Signature too short: len=%d"
,
len
(
signature
))
return
false
}
// 检查是否是有效的 base64 字符
// base64 字符集: A-Z, a-z, 0-9, +, /, =
for
i
,
c
:=
range
signature
{
if
(
c
<
'A'
||
c
>
'Z'
)
&&
(
c
<
'a'
||
c
>
'z'
)
&&
(
c
<
'0'
||
c
>
'9'
)
&&
c
!=
'+'
&&
c
!=
'/'
&&
c
!=
'='
{
log
.
Printf
(
"[Debug] Invalid base64 character at position %d: %c (code=%d)"
,
i
,
c
,
c
)
return
false
}
}
return
true
}
// buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func
buildParts
(
content
json
.
RawMessage
,
toolIDToName
map
[
string
]
string
,
allowDummyThought
bool
)
([]
GeminiPart
,
error
)
{
...
...
@@ -199,22 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
}
case
"thinking"
:
part
:=
GeminiPart
{
if
allowDummyThought
{
// Gemini 模型可以使用 dummy signature
parts
=
append
(
parts
,
GeminiPart
{
Text
:
block
.
Thinking
,
Thought
:
true
,
ThoughtSignature
:
dummyThoughtSignature
,
})
continue
}
// 保留原有 signature(Claude 模型需要有效的 signature)
if
block
.
Signature
!=
""
{
part
.
ThoughtSignature
=
block
.
Signature
}
else
if
!
allowDummyThought
{
// Claude 模型需要有效 signature,跳过无 signature 的 thinking block
log
.
Printf
(
"Warning: skipping thinking block without signature for Claude model"
)
// Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。
signature
:=
strings
.
TrimSpace
(
block
.
Signature
)
if
signature
==
""
||
signature
==
dummyThoughtSignature
{
log
.
Printf
(
"[Warning] Skipping thinking block for Claude model (missing or dummy signature)"
)
continue
}
else
{
// Gemini 模型使用 dummy signature
part
.
ThoughtSignature
=
dummyThoughtSignature
}
parts
=
append
(
parts
,
part
)
if
!
isValidThoughtSignature
(
signature
)
{
log
.
Printf
(
"[Debug] Thinking signature may be invalid (passing through anyway): len=%d"
,
len
(
signature
))
}
parts
=
append
(
parts
,
GeminiPart
{
Text
:
block
.
Thinking
,
Thought
:
true
,
ThoughtSignature
:
signature
,
})
case
"image"
:
if
block
.
Source
!=
nil
&&
block
.
Source
.
Type
==
"base64"
{
...
...
@@ -239,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
ID
:
block
.
ID
,
},
}
// 保留原有 signature,或对 Gemini 模型使用 dummy signature
if
block
.
Signature
!=
""
{
part
.
ThoughtSignature
=
block
.
Signature
}
else
if
allowDummyThought
{
// 只有 Gemini 模型使用 dummy signature
// Claude 模型不设置 signature(避免验证问题)
if
allowDummyThought
{
part
.
ThoughtSignature
=
dummyThoughtSignature
}
parts
=
append
(
parts
,
part
)
...
...
@@ -386,9 +433,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 普通工具
var
funcDecls
[]
GeminiFunctionDecl
for
_
,
tool
:=
range
tools
{
for
i
,
tool
:=
range
tools
{
// 跳过无效工具名称
if
tool
.
Name
==
""
{
if
strings
.
TrimSpace
(
tool
.
Name
)
==
""
{
log
.
Printf
(
"Warning: skipping tool with empty name"
)
continue
}
...
...
@@ -397,10 +444,18 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
var
inputSchema
map
[
string
]
any
// 检查是否为 custom 类型工具 (MCP)
if
tool
.
Type
==
"custom"
&&
tool
.
Custom
!=
nil
{
// Custom 格式: 从 custom 字段获取 description 和 input_schema
if
tool
.
Type
==
"custom"
{
if
tool
.
Custom
==
nil
||
tool
.
Custom
.
InputSchema
==
nil
{
log
.
Printf
(
"[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema"
,
tool
.
Name
)
continue
}
description
=
tool
.
Custom
.
Description
inputSchema
=
tool
.
Custom
.
InputSchema
// 调试日志:记录 custom 工具的 schema
if
schemaJSON
,
err
:=
json
.
Marshal
(
inputSchema
);
err
==
nil
{
log
.
Printf
(
"[Debug] Tool[%d] '%s' (custom) original schema: %s"
,
i
,
tool
.
Name
,
string
(
schemaJSON
))
}
}
else
{
// 标准格式: 从顶层字段获取
description
=
tool
.
Description
...
...
@@ -409,7 +464,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 清理 JSON Schema
params
:=
cleanJSONSchema
(
inputSchema
)
// 为 nil schema 提供默认值
if
params
==
nil
{
params
=
map
[
string
]
any
{
...
...
@@ -418,6 +472,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
}
// 调试日志:记录清理后的 schema
if
paramsJSON
,
err
:=
json
.
Marshal
(
params
);
err
==
nil
{
log
.
Printf
(
"[Debug] Tool[%d] '%s' cleaned schema: %s"
,
i
,
tool
.
Name
,
string
(
paramsJSON
))
}
funcDecls
=
append
(
funcDecls
,
GeminiFunctionDecl
{
Name
:
tool
.
Name
,
Description
:
description
,
...
...
@@ -479,24 +538,54 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
}
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var
excludedSchemaKeys
=
map
[
string
]
bool
{
// 元 schema 字段
"$schema"
:
true
,
"$id"
:
true
,
"$ref"
:
true
,
"additionalProperties"
:
true
,
// 字符串验证(Gemini 不支持)
"minLength"
:
true
,
"maxLength"
:
true
,
"
minItems"
:
true
,
"maxItems"
:
true
,
"uniqueItems"
:
true
,
"
pattern"
:
true
,
// 数字验证(Claude API 通过 Vertex AI 不支持这些字段)
"minimum"
:
true
,
"maximum"
:
true
,
"exclusiveMinimum"
:
true
,
"exclusiveMaximum"
:
true
,
"pattern"
:
true
,
"format"
:
true
,
"multipleOf"
:
true
,
// 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems"
:
true
,
"minItems"
:
true
,
"maxItems"
:
true
,
// 组合 schema(Gemini 不支持)
"oneOf"
:
true
,
"anyOf"
:
true
,
"allOf"
:
true
,
"not"
:
true
,
"if"
:
true
,
"then"
:
true
,
"else"
:
true
,
"$defs"
:
true
,
"definitions"
:
true
,
// 对象验证(仅保留 properties/required/additionalProperties)
"minProperties"
:
true
,
"maxProperties"
:
true
,
"patternProperties"
:
true
,
"propertyNames"
:
true
,
"dependencies"
:
true
,
"dependentSchemas"
:
true
,
"dependentRequired"
:
true
,
// 其他不支持的字段
"default"
:
true
,
"strict"
:
true
,
"const"
:
true
,
"examples"
:
true
,
"deprecated"
:
true
,
...
...
@@ -504,6 +593,9 @@ var excludedSchemaKeys = map[string]bool{
"writeOnly"
:
true
,
"contentMediaType"
:
true
,
"contentEncoding"
:
true
,
// Claude 特有字段
"strict"
:
true
,
}
// cleanSchemaValue 递归清理 schema 值
...
...
@@ -523,6 +615,31 @@ func cleanSchemaValue(value any) any {
continue
}
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if
k
==
"format"
{
if
formatStr
,
ok
:=
val
.
(
string
);
ok
{
// Gemini 只支持 date-time, date, time
if
formatStr
==
"date-time"
||
formatStr
==
"date"
||
formatStr
==
"time"
{
result
[
k
]
=
val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
if
k
==
"additionalProperties"
{
if
boolVal
,
ok
:=
val
.
(
bool
);
ok
{
result
[
k
]
=
boolVal
log
.
Printf
(
"[Debug] additionalProperties is bool: %v"
,
boolVal
)
}
else
{
// 如果是 schema 对象,转换为 false(更安全的默认值)
result
[
k
]
=
false
log
.
Printf
(
"[Debug] additionalProperties is not bool (type: %T), converting to false"
,
val
)
}
continue
}
// 递归清理所有值
result
[
k
]
=
cleanSchemaValue
(
val
)
}
...
...
backend/internal/pkg/antigravity/request_transformer_test.go
0 → 100644
View file @
6cc7f997
package
antigravity
import
(
"encoding/json"
"testing"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func
TestBuildParts_ThinkingBlockWithoutSignature
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
content
string
allowDummyThought
bool
expectedParts
int
description
string
}{
{
name
:
"Claude model - skip thinking block without signature"
,
content
:
`[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`
,
allowDummyThought
:
false
,
expectedParts
:
2
,
// 只有两个text block
description
:
"Claude模型应该跳过无signature的thinking block"
,
},
{
name
:
"Claude model - keep thinking block with signature"
,
content
:
`[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
{"type": "text", "text": "World"}
]`
,
allowDummyThought
:
false
,
expectedParts
:
3
,
// 三个block都保留
description
:
"Claude模型应该保留有signature的thinking block"
,
},
{
name
:
"Gemini model - use dummy signature"
,
content
:
`[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`
,
allowDummyThought
:
true
,
expectedParts
:
3
,
// 三个block都保留,thinking使用dummy signature
description
:
"Gemini模型应该为无signature的thinking block使用dummy signature"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
toolIDToName
:=
make
(
map
[
string
]
string
)
parts
,
err
:=
buildParts
(
json
.
RawMessage
(
tt
.
content
),
toolIDToName
,
tt
.
allowDummyThought
)
if
err
!=
nil
{
t
.
Fatalf
(
"buildParts() error = %v"
,
err
)
}
if
len
(
parts
)
!=
tt
.
expectedParts
{
t
.
Errorf
(
"%s: got %d parts, want %d parts"
,
tt
.
description
,
len
(
parts
),
tt
.
expectedParts
)
}
})
}
}
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
func
TestBuildTools_CustomTypeTools
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
tools
[]
ClaudeTool
expectedLen
int
description
string
}{
{
name
:
"Standard tool format"
,
tools
:
[]
ClaudeTool
{
{
Name
:
"get_weather"
,
Description
:
"Get weather information"
,
InputSchema
:
map
[
string
]
any
{
"type"
:
"object"
,
"properties"
:
map
[
string
]
any
{
"location"
:
map
[
string
]
any
{
"type"
:
"string"
},
},
},
},
},
expectedLen
:
1
,
description
:
"标准工具格式应该正常转换"
,
},
{
name
:
"Custom type tool (MCP format)"
,
tools
:
[]
ClaudeTool
{
{
Type
:
"custom"
,
Name
:
"mcp_tool"
,
Custom
:
&
CustomToolSpec
{
Description
:
"MCP tool description"
,
InputSchema
:
map
[
string
]
any
{
"type"
:
"object"
,
"properties"
:
map
[
string
]
any
{
"param"
:
map
[
string
]
any
{
"type"
:
"string"
},
},
},
},
},
},
expectedLen
:
1
,
description
:
"Custom类型工具应该从Custom字段读取description和input_schema"
,
},
{
name
:
"Mixed standard and custom tools"
,
tools
:
[]
ClaudeTool
{
{
Name
:
"standard_tool"
,
Description
:
"Standard tool"
,
InputSchema
:
map
[
string
]
any
{
"type"
:
"object"
},
},
{
Type
:
"custom"
,
Name
:
"custom_tool"
,
Custom
:
&
CustomToolSpec
{
Description
:
"Custom tool"
,
InputSchema
:
map
[
string
]
any
{
"type"
:
"object"
},
},
},
},
expectedLen
:
1
,
// 返回一个GeminiToolDeclaration,包含2个function declarations
description
:
"混合标准和custom工具应该都能正确转换"
,
},
{
name
:
"Invalid custom tool - nil Custom field"
,
tools
:
[]
ClaudeTool
{
{
Type
:
"custom"
,
Name
:
"invalid_custom"
,
// Custom 为 nil
},
},
expectedLen
:
0
,
// 应该被跳过
description
:
"Custom字段为nil的custom工具应该被跳过"
,
},
{
name
:
"Invalid custom tool - nil InputSchema"
,
tools
:
[]
ClaudeTool
{
{
Type
:
"custom"
,
Name
:
"invalid_custom"
,
Custom
:
&
CustomToolSpec
{
Description
:
"Invalid"
,
// InputSchema 为 nil
},
},
},
expectedLen
:
0
,
// 应该被跳过
description
:
"InputSchema为nil的custom工具应该被跳过"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
buildTools
(
tt
.
tools
)
if
len
(
result
)
!=
tt
.
expectedLen
{
t
.
Errorf
(
"%s: got %d tool declarations, want %d"
,
tt
.
description
,
len
(
result
),
tt
.
expectedLen
)
}
// 验证function declarations存在
if
len
(
result
)
>
0
&&
result
[
0
]
.
FunctionDeclarations
!=
nil
{
if
len
(
result
[
0
]
.
FunctionDeclarations
)
!=
len
(
tt
.
tools
)
{
t
.
Errorf
(
"%s: got %d function declarations, want %d"
,
tt
.
description
,
len
(
result
[
0
]
.
FunctionDeclarations
),
len
(
tt
.
tools
))
}
}
})
}
}
backend/internal/pkg/claude/constants.go
View file @
6cc7f997
...
...
@@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
const
HaikuBetaHeader
=
BetaOAuth
+
","
+
BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
const
ApiKeyBetaHeader
=
BetaClaudeCode
+
","
+
BetaInterleavedThinking
+
","
+
BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
const
ApiKeyHaikuBetaHeader
=
BetaInterleavedThinking
// Claude Code 客户端默认请求头
var
DefaultHeaders
=
map
[
string
]
string
{
"User-Agent"
:
"claude-cli/2.0.62 (external, cli)"
,
...
...
backend/internal/pkg/geminicli/drive_client.go
0 → 100644
View file @
6cc7f997
package
geminicli
import
(
"context"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
// DriveStorageInfo represents Google Drive storage quota information
type
DriveStorageInfo
struct
{
Limit
int64
`json:"limit"`
// Storage limit in bytes
Usage
int64
`json:"usage"`
// Current usage in bytes
}
// DriveClient interface for Google Drive API operations
type
DriveClient
interface
{
GetStorageQuota
(
ctx
context
.
Context
,
accessToken
,
proxyURL
string
)
(
*
DriveStorageInfo
,
error
)
}
type
driveClient
struct
{}
// NewDriveClient creates a new Drive API client
func
NewDriveClient
()
DriveClient
{
return
&
driveClient
{}
}
// GetStorageQuota fetches storage quota from Google Drive API
func
(
c
*
driveClient
)
GetStorageQuota
(
ctx
context
.
Context
,
accessToken
,
proxyURL
string
)
(
*
DriveStorageInfo
,
error
)
{
const
driveAPIURL
=
"https://www.googleapis.com/drive/v3/about?fields=storageQuota"
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"GET"
,
driveAPIURL
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to create request: %w"
,
err
)
}
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
accessToken
)
// Get HTTP client with proxy support
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
ProxyURL
:
proxyURL
,
Timeout
:
10
*
time
.
Second
,
})
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to create HTTP client: %w"
,
err
)
}
sleepWithContext
:=
func
(
d
time
.
Duration
)
error
{
timer
:=
time
.
NewTimer
(
d
)
defer
timer
.
Stop
()
select
{
case
<-
ctx
.
Done
()
:
return
ctx
.
Err
()
case
<-
timer
.
C
:
return
nil
}
}
// Retry logic with exponential backoff (+ jitter) for rate limits and transient failures
var
resp
*
http
.
Response
maxRetries
:=
3
rng
:=
rand
.
New
(
rand
.
NewSource
(
time
.
Now
()
.
UnixNano
()))
for
attempt
:=
0
;
attempt
<
maxRetries
;
attempt
++
{
if
ctx
.
Err
()
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"request cancelled: %w"
,
ctx
.
Err
())
}
resp
,
err
=
client
.
Do
(
req
)
if
err
!=
nil
{
// Network error retry
if
attempt
<
maxRetries
-
1
{
backoff
:=
time
.
Duration
(
1
<<
uint
(
attempt
))
*
time
.
Second
jitter
:=
time
.
Duration
(
rng
.
Intn
(
1000
))
*
time
.
Millisecond
if
err
:=
sleepWithContext
(
backoff
+
jitter
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"request cancelled: %w"
,
err
)
}
continue
}
return
nil
,
fmt
.
Errorf
(
"network error after %d attempts: %w"
,
maxRetries
,
err
)
}
// Success
if
resp
.
StatusCode
==
http
.
StatusOK
{
break
}
// Retry 429, 500, 502, 503 with exponential backoff + jitter
if
(
resp
.
StatusCode
==
http
.
StatusTooManyRequests
||
resp
.
StatusCode
==
http
.
StatusInternalServerError
||
resp
.
StatusCode
==
http
.
StatusBadGateway
||
resp
.
StatusCode
==
http
.
StatusServiceUnavailable
)
&&
attempt
<
maxRetries
-
1
{
if
err
:=
func
()
error
{
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
backoff
:=
time
.
Duration
(
1
<<
uint
(
attempt
))
*
time
.
Second
jitter
:=
time
.
Duration
(
rng
.
Intn
(
1000
))
*
time
.
Millisecond
return
sleepWithContext
(
backoff
+
jitter
)
}();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"request cancelled: %w"
,
err
)
}
continue
}
break
}
if
resp
==
nil
{
return
nil
,
fmt
.
Errorf
(
"request failed: no response received"
)
}
if
resp
.
StatusCode
!=
http
.
StatusOK
{
_
=
resp
.
Body
.
Close
()
statusText
:=
http
.
StatusText
(
resp
.
StatusCode
)
if
statusText
==
""
{
statusText
=
resp
.
Status
}
fmt
.
Printf
(
"[DriveClient] Drive API error: status=%d, msg=%s
\n
"
,
resp
.
StatusCode
,
statusText
)
// 只返回通用错误
return
nil
,
fmt
.
Errorf
(
"drive API error: status %d"
,
resp
.
StatusCode
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
// Parse response
var
result
struct
{
StorageQuota
struct
{
Limit
string
`json:"limit"`
// Can be string or number
Usage
string
`json:"usage"`
}
`json:"storageQuota"`
}
if
err
:=
json
.
NewDecoder
(
resp
.
Body
)
.
Decode
(
&
result
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to decode response: %w"
,
err
)
}
// Parse limit and usage (handle both string and number formats)
var
limit
,
usage
int64
if
result
.
StorageQuota
.
Limit
!=
""
{
if
val
,
err
:=
strconv
.
ParseInt
(
result
.
StorageQuota
.
Limit
,
10
,
64
);
err
==
nil
{
limit
=
val
}
}
if
result
.
StorageQuota
.
Usage
!=
""
{
if
val
,
err
:=
strconv
.
ParseInt
(
result
.
StorageQuota
.
Usage
,
10
,
64
);
err
==
nil
{
usage
=
val
}
}
return
&
DriveStorageInfo
{
Limit
:
limit
,
Usage
:
usage
,
},
nil
}
backend/internal/pkg/geminicli/drive_client_test.go
0 → 100644
View file @
6cc7f997
package
geminicli
import
"testing"
func
TestDriveStorageInfo
(
t
*
testing
.
T
)
{
// 测试 DriveStorageInfo 结构体
info
:=
&
DriveStorageInfo
{
Limit
:
100
*
1024
*
1024
*
1024
,
// 100GB
Usage
:
50
*
1024
*
1024
*
1024
,
// 50GB
}
if
info
.
Limit
!=
100
*
1024
*
1024
*
1024
{
t
.
Errorf
(
"Expected limit 100GB, got %d"
,
info
.
Limit
)
}
if
info
.
Usage
!=
50
*
1024
*
1024
*
1024
{
t
.
Errorf
(
"Expected usage 50GB, got %d"
,
info
.
Usage
)
}
}
backend/internal/repository/account_repo.go
View file @
6cc7f997
...
...
@@ -124,6 +124,90 @@ func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Acc
return
&
accounts
[
0
],
nil
}
func
(
r
*
accountRepository
)
GetByIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
([]
*
service
.
Account
,
error
)
{
if
len
(
ids
)
==
0
{
return
[]
*
service
.
Account
{},
nil
}
// De-duplicate while preserving order of first occurrence.
uniqueIDs
:=
make
([]
int64
,
0
,
len
(
ids
))
seen
:=
make
(
map
[
int64
]
struct
{},
len
(
ids
))
for
_
,
id
:=
range
ids
{
if
id
<=
0
{
continue
}
if
_
,
ok
:=
seen
[
id
];
ok
{
continue
}
seen
[
id
]
=
struct
{}{}
uniqueIDs
=
append
(
uniqueIDs
,
id
)
}
if
len
(
uniqueIDs
)
==
0
{
return
[]
*
service
.
Account
{},
nil
}
entAccounts
,
err
:=
r
.
client
.
Account
.
Query
()
.
Where
(
dbaccount
.
IDIn
(
uniqueIDs
...
))
.
WithProxy
()
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
if
len
(
entAccounts
)
==
0
{
return
[]
*
service
.
Account
{},
nil
}
accountIDs
:=
make
([]
int64
,
0
,
len
(
entAccounts
))
entByID
:=
make
(
map
[
int64
]
*
dbent
.
Account
,
len
(
entAccounts
))
for
_
,
acc
:=
range
entAccounts
{
entByID
[
acc
.
ID
]
=
acc
accountIDs
=
append
(
accountIDs
,
acc
.
ID
)
}
groupsByAccount
,
groupIDsByAccount
,
accountGroupsByAccount
,
err
:=
r
.
loadAccountGroups
(
ctx
,
accountIDs
)
if
err
!=
nil
{
return
nil
,
err
}
outByID
:=
make
(
map
[
int64
]
*
service
.
Account
,
len
(
entAccounts
))
for
_
,
entAcc
:=
range
entAccounts
{
out
:=
accountEntityToService
(
entAcc
)
if
out
==
nil
{
continue
}
// Prefer the preloaded proxy edge when available.
if
entAcc
.
Edges
.
Proxy
!=
nil
{
out
.
Proxy
=
proxyEntityToService
(
entAcc
.
Edges
.
Proxy
)
}
if
groups
,
ok
:=
groupsByAccount
[
entAcc
.
ID
];
ok
{
out
.
Groups
=
groups
}
if
groupIDs
,
ok
:=
groupIDsByAccount
[
entAcc
.
ID
];
ok
{
out
.
GroupIDs
=
groupIDs
}
if
ags
,
ok
:=
accountGroupsByAccount
[
entAcc
.
ID
];
ok
{
out
.
AccountGroups
=
ags
}
outByID
[
entAcc
.
ID
]
=
out
}
// Preserve input order (first occurrence), and ignore missing IDs.
out
:=
make
([]
*
service
.
Account
,
0
,
len
(
uniqueIDs
))
for
_
,
id
:=
range
uniqueIDs
{
if
_
,
ok
:=
entByID
[
id
];
!
ok
{
continue
}
if
acc
,
ok
:=
outByID
[
id
];
ok
&&
acc
!=
nil
{
out
=
append
(
out
,
acc
)
}
}
return
out
,
nil
}
// ExistsByID 检查指定 ID 的账号是否存在。
// 相比 GetByID,此方法性能更优,因为:
// - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值
...
...
backend/internal/repository/api_key_repo.go
View file @
6cc7f997
...
...
@@ -294,7 +294,6 @@ func userEntityToService(u *dbent.User) *service.User {
ID
:
u
.
ID
,
Email
:
u
.
Email
,
Username
:
u
.
Username
,
Wechat
:
u
.
Wechat
,
Notes
:
u
.
Notes
,
PasswordHash
:
u
.
PasswordHash
,
Role
:
u
.
Role
,
...
...
backend/internal/repository/concurrency_cache.go
View file @
6cc7f997
...
...
@@ -2,7 +2,9 @@ package repository
import
(
"context"
"errors"
"fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
...
...
@@ -27,6 +29,8 @@ const (
userSlotKeyPrefix
=
"concurrency:user:"
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix
=
"concurrency:wait:"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix
=
"wait:account:"
// 默认槽位过期时间(分钟),可通过配置覆盖
defaultSlotTTLMinutes
=
15
...
...
@@ -115,6 +119,29 @@ var (
return 1
`
)
// incrementAccountWaitScript - account-level wait queue count
incrementAccountWaitScript
=
redis
.
NewScript
(
`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`
)
// decrementWaitScript - same as before
decrementWaitScript
=
redis
.
NewScript
(
`
local current = redis.call('GET', KEYS[1])
...
...
@@ -123,22 +150,78 @@ var (
end
return 1
`
)
// getAccountsLoadBatchScript - batch load query (read-only)
// ARGV[1] = slot TTL (seconds, retained for compatibility)
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
getAccountsLoadBatchScript
=
redis
.
NewScript
(
`
local result = {}
local i = 2
while i <= #ARGV do
local accountID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:account:' .. accountID
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'wait:account:' .. accountID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, accountID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`
)
// cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds)
cleanupExpiredSlotsScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
`
)
)
type
concurrencyCache
struct
{
rdb
*
redis
.
Client
slotTTLSeconds
int
// 槽位过期时间(秒)
waitQueueTTLSeconds
int
// 等待队列过期时间(秒)
}
// NewConcurrencyCache 创建并发控制缓存
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
func
NewConcurrencyCache
(
rdb
*
redis
.
Client
,
slotTTLMinutes
int
)
service
.
ConcurrencyCache
{
// waitQueueTTLSeconds: 等待队列过期时间(秒),0 或负数使用 slot TTL
func
NewConcurrencyCache
(
rdb
*
redis
.
Client
,
slotTTLMinutes
int
,
waitQueueTTLSeconds
int
)
service
.
ConcurrencyCache
{
if
slotTTLMinutes
<=
0
{
slotTTLMinutes
=
defaultSlotTTLMinutes
}
if
waitQueueTTLSeconds
<=
0
{
waitQueueTTLSeconds
=
slotTTLMinutes
*
60
}
return
&
concurrencyCache
{
rdb
:
rdb
,
slotTTLSeconds
:
slotTTLMinutes
*
60
,
waitQueueTTLSeconds
:
waitQueueTTLSeconds
,
}
}
...
...
@@ -155,6 +238,10 @@ func waitQueueKey(userID int64) string {
return
fmt
.
Sprintf
(
"%s%d"
,
waitQueueKeyPrefix
,
userID
)
}
func
accountWaitKey
(
accountID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d"
,
accountWaitKeyPrefix
,
accountID
)
}
// Account slot operations
func
(
c
*
concurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
...
...
@@ -225,3 +312,75 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64)
_
,
err
:=
decrementWaitScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
})
.
Result
()
return
err
}
// Account wait queue operations
func
(
c
*
concurrencyCache
)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
{
key
:=
accountWaitKey
(
accountID
)
result
,
err
:=
incrementAccountWaitScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
maxWait
,
c
.
waitQueueTTLSeconds
)
.
Int
()
if
err
!=
nil
{
return
false
,
err
}
return
result
==
1
,
nil
}
func
(
c
*
concurrencyCache
)
DecrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
)
error
{
key
:=
accountWaitKey
(
accountID
)
_
,
err
:=
decrementWaitScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
})
.
Result
()
return
err
}
func
(
c
*
concurrencyCache
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
key
:=
accountWaitKey
(
accountID
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Int
()
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
return
0
,
err
}
if
errors
.
Is
(
err
,
redis
.
Nil
)
{
return
0
,
nil
}
return
val
,
nil
}
func
(
c
*
concurrencyCache
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
service
.
AccountWithConcurrency
)
(
map
[
int64
]
*
service
.
AccountLoadInfo
,
error
)
{
if
len
(
accounts
)
==
0
{
return
map
[
int64
]
*
service
.
AccountLoadInfo
{},
nil
}
args
:=
[]
any
{
c
.
slotTTLSeconds
}
for
_
,
acc
:=
range
accounts
{
args
=
append
(
args
,
acc
.
ID
,
acc
.
MaxConcurrency
)
}
result
,
err
:=
getAccountsLoadBatchScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{},
args
...
)
.
Slice
()
if
err
!=
nil
{
return
nil
,
err
}
loadMap
:=
make
(
map
[
int64
]
*
service
.
AccountLoadInfo
)
for
i
:=
0
;
i
<
len
(
result
);
i
+=
4
{
if
i
+
3
>=
len
(
result
)
{
break
}
accountID
,
_
:=
strconv
.
ParseInt
(
fmt
.
Sprintf
(
"%v"
,
result
[
i
]),
10
,
64
)
currentConcurrency
,
_
:=
strconv
.
Atoi
(
fmt
.
Sprintf
(
"%v"
,
result
[
i
+
1
]))
waitingCount
,
_
:=
strconv
.
Atoi
(
fmt
.
Sprintf
(
"%v"
,
result
[
i
+
2
]))
loadRate
,
_
:=
strconv
.
Atoi
(
fmt
.
Sprintf
(
"%v"
,
result
[
i
+
3
]))
loadMap
[
accountID
]
=
&
service
.
AccountLoadInfo
{
AccountID
:
accountID
,
CurrentConcurrency
:
currentConcurrency
,
WaitingCount
:
waitingCount
,
LoadRate
:
loadRate
,
}
}
return
loadMap
,
nil
}
func
(
c
*
concurrencyCache
)
CleanupExpiredAccountSlots
(
ctx
context
.
Context
,
accountID
int64
)
error
{
key
:=
accountSlotKey
(
accountID
)
_
,
err
:=
cleanupExpiredSlotsScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
c
.
slotTTLSeconds
)
.
Result
()
return
err
}
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment