Commit 6cc7f997 authored by song's avatar song
Browse files

merge: 合并 upstream/main

parents 95d09f60 106e59b7
......@@ -27,7 +27,6 @@ type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
......@@ -40,7 +39,6 @@ type UpdateUserRequest struct {
Email string `json:"email" binding:"omitempty,email"`
Password string `json:"password" binding:"omitempty,min=6"`
Username *string `json:"username"`
Wechat *string `json:"wechat"`
Notes *string `json:"notes"`
Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"`
......@@ -57,13 +55,22 @@ type UpdateBalanceRequest struct {
// List handles listing all users with pagination
// GET /api/v1/admin/users
// Query params:
// - status: filter by user status
// - role: filter by user role
// - search: search in email, username
// - attr[{id}]: filter by custom attribute value, e.g. attr[1]=company
func (h *UserHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := c.Query("status")
role := c.Query("role")
search := c.Query("search")
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
filters := service.UserListFilters{
Status: c.Query("status"),
Role: c.Query("role"),
Search: c.Query("search"),
Attributes: parseAttributeFilters(c),
}
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
if err != nil {
response.ErrorFrom(c, err)
return
......@@ -76,6 +83,29 @@ func (h *UserHandler) List(c *gin.Context) {
response.Paginated(c, out, total, page, pageSize)
}
// parseAttributeFilters extracts attribute filters from query params
// Format: attr[{attributeID}]=value, e.g. attr[1]=company&attr[2]=developer
func parseAttributeFilters(c *gin.Context) map[int64]string {
result := make(map[int64]string)
// Get all query params and look for attr[*] pattern
for key, values := range c.Request.URL.Query() {
if len(values) == 0 || values[0] == "" {
continue
}
// Check if key matches pattern attr[{id}]
if len(key) > 5 && key[:5] == "attr[" && key[len(key)-1] == ']' {
idStr := key[5 : len(key)-1]
id, err := strconv.ParseInt(idStr, 10, 64)
if err == nil && id > 0 {
result[id] = values[0]
}
}
}
return result
}
// GetByID handles getting a user by ID
// GET /api/v1/admin/users/:id
func (h *UserHandler) GetByID(c *gin.Context) {
......@@ -107,7 +137,6 @@ func (h *UserHandler) Create(c *gin.Context) {
Email: req.Email,
Password: req.Password,
Username: req.Username,
Wechat: req.Wechat,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
......@@ -141,7 +170,6 @@ func (h *UserHandler) Update(c *gin.Context) {
Email: req.Email,
Password: req.Password,
Username: req.Username,
Wechat: req.Wechat,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
......
......@@ -10,7 +10,6 @@ func UserFromServiceShallow(u *service.User) *User {
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
Role: u.Role,
Balance: u.Balance,
......
......@@ -6,7 +6,6 @@ type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"`
Role string `json:"role"`
Balance float64 `json:"balance"`
......
......@@ -142,6 +142,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} else if apiKey.Group != nil {
platform = apiKey.Group.Platform
}
sessionKey := sessionHash
if platform == service.PlatformGemini && sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
if platform == service.PlatformGemini {
const maxAccountSwitches = 3
......@@ -150,7 +154,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus := 0
for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
......@@ -159,9 +163,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
......@@ -171,11 +179,46 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 转发请求 - 根据账号平台分流
......@@ -188,6 +231,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
......@@ -232,7 +278,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
......@@ -241,9 +287,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
......@@ -253,11 +303,46 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 转发请求 - 根据账号平台分流
......@@ -270,6 +355,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
......@@ -309,7 +397,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Models handles listing available models
// GET /v1/models
// Returns different model lists based on the API key's group platform or forced platform
// Returns models based on account configurations (model_mapping whitelist)
// Falls back to default models if no whitelist is configured
func (h *GatewayHandler) Models(c *gin.Context) {
apiKey, _ := middleware2.GetApiKeyFromContext(c)
......@@ -324,8 +413,37 @@ func (h *GatewayHandler) Models(c *gin.Context) {
}
}
// Return OpenAI models for OpenAI platform groups
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
var groupID *int64
var platform string
if apiKey != nil && apiKey.Group != nil {
groupID = &apiKey.Group.ID
platform = apiKey.Group.Platform
}
// Get available models from account configurations (without platform filter)
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
if len(availableModels) > 0 {
// Build model list from whitelist
models := make([]claude.Model, 0, len(availableModels))
for _, modelID := range availableModels {
models = append(models, claude.Model{
ID: modelID,
Type: "model",
DisplayName: modelID,
CreatedAt: "2024-01-01T00:00:00Z",
})
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": models,
})
return
}
// Fallback to default models
if platform == "openai" {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": openai.DefaultModels,
......@@ -333,7 +451,6 @@ func (h *GatewayHandler) Models(c *gin.Context) {
return
}
// Default: Claude models
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": claude.DefaultModels,
......
......@@ -83,6 +83,16 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
h.concurrencyService.DecrementWaitCount(ctx, userID)
}
// IncrementAccountWaitCount increments the wait count for an account
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
}
// DecrementAccountWaitCount decrements the wait count for an account
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
......@@ -126,7 +136,12 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// Determine if ping is needed (streaming + ping format defined)
......@@ -200,6 +215,11 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
}
}
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
}
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
......
......@@ -198,13 +198,17 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
......@@ -213,12 +217,48 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
handleGeminiFailoverExhausted(c, lastFailoverStatus)
return
}
account := selection.Account
// 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
return
}
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
stream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 5) forward (根据平台分流)
......@@ -231,6 +271,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
......
......@@ -20,6 +20,7 @@ type AdminHandlers struct {
System *admin.SystemHandler
Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
}
// Handlers contains all HTTP handlers
......
......@@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for {
// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 {
......@@ -156,14 +156,50 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// Forward request
......@@ -171,6 +207,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
......
......@@ -30,7 +30,6 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
Username *string `json:"username"`
Wechat *string `json:"wechat"`
}
// GetProfile handles getting user profile
......@@ -99,7 +98,6 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
svcReq := service.UpdateProfileRequest{
Username: req.Username,
Wechat: req.Wechat,
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
......
......@@ -23,6 +23,7 @@ func ProvideAdminHandlers(
systemHandler *admin.SystemHandler,
subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
......@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
System: systemHandler,
Subscription: subscriptionHandler,
Usage: usageHandler,
UserAttribute: userAttributeHandler,
}
}
......@@ -107,6 +109,7 @@ var ProviderSet = wire.NewSet(
ProvideSystemHandler,
admin.NewSubscriptionHandler,
admin.NewUsageHandler,
admin.NewUserAttributeHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
......
......@@ -57,6 +57,7 @@ var geminiModels = []string{
"gemini-2.5-flash-lite",
"gemini-3-flash",
"gemini-3-pro-low",
"gemini-3-pro-high",
}
func TestMain(m *testing.M) {
......@@ -641,6 +642,37 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
t.Logf("✅ thinking 模式工具调用测试通过, id=%v", result["id"])
}
// TestClaudeMessagesWithGeminiModel 测试在 Claude 端点使用 Gemini 模型
// 验证:通过 /v1/messages 端点传入 gemini 模型名的场景(含前缀映射)
// 仅在 Antigravity 模式下运行(ENDPOINT_PREFIX="/antigravity")
func TestClaudeMessagesWithGeminiModel(t *testing.T) {
if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行")
}
// 测试通过 Claude 端点调用 Gemini 模型
geminiViaClaude := []string{
"gemini-3-flash", // 直接支持
"gemini-3-pro-low", // 直接支持
"gemini-3-pro-high", // 直接支持
"gemini-3-pro", // 前缀映射 -> gemini-3-pro-high
"gemini-3-pro-preview", // 前缀映射 -> gemini-3-pro-high
}
for i, model := range geminiViaClaude {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_通过Claude端点", func(t *testing.T) {
testClaudeMessage(t, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_通过Claude端点_流式", func(t *testing.T) {
testClaudeMessage(t, model, true)
})
}
}
// TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景
// 验证:Gemini 模型接受没有 signature 的 thinking block
func TestClaudeMessagesWithNoSignature(t *testing.T) {
......@@ -738,3 +770,30 @@ func testClaudeWithNoSignature(t *testing.T, model string) {
}
t.Logf("✅ 无 signature thinking 处理测试通过, id=%v", result["id"])
}
// TestGeminiEndpointWithClaudeModel 测试通过 Gemini 端点调用 Claude 模型
// 仅在 Antigravity 模式下运行(ENDPOINT_PREFIX="/antigravity")
func TestGeminiEndpointWithClaudeModel(t *testing.T) {
if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行")
}
// 测试通过 Gemini 端点调用 Claude 模型
claudeViaGemini := []string{
"claude-sonnet-4-5",
"claude-opus-4-5-thinking",
}
for i, model := range claudeViaGemini {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_通过Gemini端点", func(t *testing.T) {
testGeminiGenerate(t, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) {
testGeminiGenerate(t, model, true)
})
}
}
......@@ -54,6 +54,9 @@ type CustomToolSpec struct {
InputSchema map[string]any `json:"input_schema"`
}
// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格)
type ClaudeCustomToolSpec = CustomToolSpec
// SystemBlock system prompt 数组形式的元素
type SystemBlock struct {
Type string `json:"type"`
......
......@@ -14,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
// 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string)
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 检测是否启用 thinking
requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等),
// 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。
isThinkingEnabled := requestedThinkingEnabled && allowDummyThought
// 1. 构建 contents
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil {
......@@ -31,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
// 3. 构建 generationConfig
generationConfig := buildGenerationConfig(claudeReq)
reqForGen := claudeReq
if requestedThinkingEnabled && !allowDummyThought {
log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel)
// shallow copy to avoid mutating caller's request
clone := *claudeReq
clone.Thinking = nil
reqForGen = &clone
}
generationConfig := buildGenerationConfig(reqForGen)
// 4. 构建 tools
tools := buildTools(claudeReq.Tools)
......@@ -148,8 +159,9 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
if !hasThoughtPart && len(parts) > 0 {
// 在开头添加 dummy thinking block
parts = append([]GeminiPart{{
Text: "Thinking...",
Thought: true,
Text: "Thinking...",
Thought: true,
ThoughtSignature: dummyThoughtSignature,
}}, parts...)
}
}
......@@ -171,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const dummyThoughtSignature = "skip_thought_signature_validator"
// isValidThoughtSignature 验证 thought signature 是否有效
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
func isValidThoughtSignature(signature string) bool {
// 空字符串无效
if signature == "" {
return false
}
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
// 参考 Claude API 文档和实际观察到的有效 signature
if len(signature) < 40 {
log.Printf("[Debug] Signature too short: len=%d", len(signature))
return false
}
// 检查是否是有效的 base64 字符
// base64 字符集: A-Z, a-z, 0-9, +, /, =
for i, c := range signature {
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') &&
(c < '0' || c > '9') && c != '+' && c != '/' && c != '=' {
log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c)
return false
}
}
return true
}
// buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
......@@ -199,22 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
}
case "thinking":
part := GeminiPart{
Text: block.Thinking,
Thought: true,
if allowDummyThought {
// Gemini 模型可以使用 dummy signature
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: dummyThoughtSignature,
})
continue
}
// 保留原有 signature(Claude 模型需要有效的 signature)
if block.Signature != "" {
part.ThoughtSignature = block.Signature
} else if !allowDummyThought {
// Claude 模型需要有效 signature,跳过无 signature 的 thinking block
log.Printf("Warning: skipping thinking block without signature for Claude model")
// Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。
signature := strings.TrimSpace(block.Signature)
if signature == "" || signature == dummyThoughtSignature {
log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)")
continue
} else {
// Gemini 模型使用 dummy signature
part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)
if !isValidThoughtSignature(signature) {
log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature))
}
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: signature,
})
case "image":
if block.Source != nil && block.Source.Type == "base64" {
......@@ -239,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
ID: block.ID,
},
}
// 保留原有 signature,或对 Gemini 模型使用 dummy signature
if block.Signature != "" {
part.ThoughtSignature = block.Signature
} else if allowDummyThought {
// 只有 Gemini 模型使用 dummy signature
// Claude 模型不设置 signature(避免验证问题)
if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)
......@@ -386,9 +433,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 普通工具
var funcDecls []GeminiFunctionDecl
for _, tool := range tools {
for i, tool := range tools {
// 跳过无效工具名称
if tool.Name == "" {
if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name")
continue
}
......@@ -397,10 +444,18 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
var inputSchema map[string]any
// 检查是否为 custom 类型工具 (MCP)
if tool.Type == "custom" && tool.Custom != nil {
// Custom 格式: 从 custom 字段获取 description 和 input_schema
if tool.Type == "custom" {
if tool.Custom == nil || tool.Custom.InputSchema == nil {
log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name)
continue
}
description = tool.Custom.Description
inputSchema = tool.Custom.InputSchema
// 调试日志:记录 custom 工具的 schema
if schemaJSON, err := json.Marshal(inputSchema); err == nil {
log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON))
}
} else {
// 标准格式: 从顶层字段获取
description = tool.Description
......@@ -409,7 +464,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 清理 JSON Schema
params := cleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值
if params == nil {
params = map[string]any{
......@@ -418,6 +472,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
}
// 调试日志:记录清理后的 schema
if paramsJSON, err := json.Marshal(params); err == nil {
log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON))
}
funcDecls = append(funcDecls, GeminiFunctionDecl{
Name: tool.Name,
Description: description,
......@@ -479,31 +538,64 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
}
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var excludedSchemaKeys = map[string]bool{
"$schema": true,
"$id": true,
"$ref": true,
"additionalProperties": true,
"minLength": true,
"maxLength": true,
"minItems": true,
"maxItems": true,
"uniqueItems": true,
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"pattern": true,
"format": true,
"default": true,
"strict": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// 元 schema 字段
"$schema": true,
"$id": true,
"$ref": true,
// 字符串验证(Gemini 不支持)
"minLength": true,
"maxLength": true,
"pattern": true,
// 数字验证(Claude API 通过 Vertex AI 不支持这些字段)
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
// 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems": true,
"minItems": true,
"maxItems": true,
// 组合 schema(Gemini 不支持)
"oneOf": true,
"anyOf": true,
"allOf": true,
"not": true,
"if": true,
"then": true,
"else": true,
"$defs": true,
"definitions": true,
// 对象验证(仅保留 properties/required/additionalProperties)
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
// 其他不支持的字段
"default": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// Claude 特有字段
"strict": true,
}
// cleanSchemaValue 递归清理 schema 值
......@@ -523,6 +615,31 @@ func cleanSchemaValue(value any) any {
continue
}
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if k == "format" {
if formatStr, ok := val.(string); ok {
// Gemini 只支持 date-time, date, time
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
result[k] = val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
log.Printf("[Debug] additionalProperties is bool: %v", boolVal)
} else {
// 如果是 schema 对象,转换为 false(更安全的默认值)
result[k] = false
log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val)
}
continue
}
// 递归清理所有值
result[k] = cleanSchemaValue(val)
}
......
package antigravity
import (
"encoding/json"
"testing"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
tests := []struct {
name string
content string
allowDummyThought bool
expectedParts int
description string
}{
{
name: "Claude model - skip thinking block without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 2, // 只有两个text block
description: "Claude模型应该跳过无signature的thinking block",
},
{
name: "Claude model - keep thinking block with signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 3, // 三个block都保留
description: "Claude模型应该保留有signature的thinking block",
},
{
name: "Gemini model - use dummy signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: true,
expectedParts: 3, // 三个block都保留,thinking使用dummy signature
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != tt.expectedParts {
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
}
})
}
}
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
func TestBuildTools_CustomTypeTools(t *testing.T) {
tests := []struct {
name string
tools []ClaudeTool
expectedLen int
description string
}{
{
name: "Standard tool format",
tools: []ClaudeTool{
{
Name: "get_weather",
Description: "Get weather information",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"location": map[string]any{"type": "string"},
},
},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []ClaudeTool{
{
Type: "custom",
Name: "mcp_tool",
Custom: &CustomToolSpec{
Description: "MCP tool description",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"param": map[string]any{"type": "string"},
},
},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从Custom字段读取description和input_schema",
},
{
name: "Mixed standard and custom tools",
tools: []ClaudeTool{
{
Name: "standard_tool",
Description: "Standard tool",
InputSchema: map[string]any{"type": "object"},
},
{
Type: "custom",
Name: "custom_tool",
Custom: &CustomToolSpec{
Description: "Custom tool",
InputSchema: map[string]any{"type": "object"},
},
},
},
expectedLen: 1, // 返回一个GeminiToolDeclaration,包含2个function declarations
description: "混合标准和custom工具应该都能正确转换",
},
{
name: "Invalid custom tool - nil Custom field",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
// Custom 为 nil
},
},
expectedLen: 0, // 应该被跳过
description: "Custom字段为nil的custom工具应该被跳过",
},
{
name: "Invalid custom tool - nil InputSchema",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
Custom: &CustomToolSpec{
Description: "Invalid",
// InputSchema 为 nil
},
},
},
expectedLen: 0, // 应该被跳过
description: "InputSchema为nil的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildTools(tt.tools)
if len(result) != tt.expectedLen {
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
}
// 验证function declarations存在
if len(result) > 0 && result[0].FunctionDeclarations != nil {
if len(result[0].FunctionDeclarations) != len(tt.tools) {
t.Errorf("%s: got %d function declarations, want %d",
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
}
}
})
}
}
......@@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
// Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",
......
package geminicli
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
// DriveStorageInfo represents Google Drive storage quota information
type DriveStorageInfo struct {
Limit int64 `json:"limit"` // Storage limit in bytes
Usage int64 `json:"usage"` // Current usage in bytes
}
// DriveClient interface for Google Drive API operations
type DriveClient interface {
GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error)
}
type driveClient struct{}
// NewDriveClient creates a new Drive API client
func NewDriveClient() DriveClient {
return &driveClient{}
}
// GetStorageQuota fetches storage quota from Google Drive API
func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error) {
const driveAPIURL = "https://www.googleapis.com/drive/v3/about?fields=storageQuota"
req, err := http.NewRequestWithContext(ctx, "GET", driveAPIURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
// Get HTTP client with proxy support
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 10 * time.Second,
})
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
}
sleepWithContext := func(d time.Duration) error {
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
// Retry logic with exponential backoff (+ jitter) for rate limits and transient failures
var resp *http.Response
maxRetries := 3
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for attempt := 0; attempt < maxRetries; attempt++ {
if ctx.Err() != nil {
return nil, fmt.Errorf("request cancelled: %w", ctx.Err())
}
resp, err = client.Do(req)
if err != nil {
// Network error retry
if attempt < maxRetries-1 {
backoff := time.Duration(1<<uint(attempt)) * time.Second
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
if err := sleepWithContext(backoff + jitter); err != nil {
return nil, fmt.Errorf("request cancelled: %w", err)
}
continue
}
return nil, fmt.Errorf("network error after %d attempts: %w", maxRetries, err)
}
// Success
if resp.StatusCode == http.StatusOK {
break
}
// Retry 429, 500, 502, 503 with exponential backoff + jitter
if (resp.StatusCode == http.StatusTooManyRequests ||
resp.StatusCode == http.StatusInternalServerError ||
resp.StatusCode == http.StatusBadGateway ||
resp.StatusCode == http.StatusServiceUnavailable) && attempt < maxRetries-1 {
if err := func() error {
defer func() { _ = resp.Body.Close() }()
backoff := time.Duration(1<<uint(attempt)) * time.Second
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
return sleepWithContext(backoff + jitter)
}(); err != nil {
return nil, fmt.Errorf("request cancelled: %w", err)
}
continue
}
break
}
if resp == nil {
return nil, fmt.Errorf("request failed: no response received")
}
if resp.StatusCode != http.StatusOK {
_ = resp.Body.Close()
statusText := http.StatusText(resp.StatusCode)
if statusText == "" {
statusText = resp.Status
}
fmt.Printf("[DriveClient] Drive API error: status=%d, msg=%s\n", resp.StatusCode, statusText)
// 只返回通用错误
return nil, fmt.Errorf("drive API error: status %d", resp.StatusCode)
}
defer func() { _ = resp.Body.Close() }()
// Parse response
var result struct {
StorageQuota struct {
Limit string `json:"limit"` // Can be string or number
Usage string `json:"usage"`
} `json:"storageQuota"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
// Parse limit and usage (handle both string and number formats)
var limit, usage int64
if result.StorageQuota.Limit != "" {
if val, err := strconv.ParseInt(result.StorageQuota.Limit, 10, 64); err == nil {
limit = val
}
}
if result.StorageQuota.Usage != "" {
if val, err := strconv.ParseInt(result.StorageQuota.Usage, 10, 64); err == nil {
usage = val
}
}
return &DriveStorageInfo{
Limit: limit,
Usage: usage,
}, nil
}
package geminicli
import "testing"
func TestDriveStorageInfo(t *testing.T) {
// 测试 DriveStorageInfo 结构体
info := &DriveStorageInfo{
Limit: 100 * 1024 * 1024 * 1024, // 100GB
Usage: 50 * 1024 * 1024 * 1024, // 50GB
}
if info.Limit != 100*1024*1024*1024 {
t.Errorf("Expected limit 100GB, got %d", info.Limit)
}
if info.Usage != 50*1024*1024*1024 {
t.Errorf("Expected usage 50GB, got %d", info.Usage)
}
}
......@@ -124,6 +124,90 @@ func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Acc
return &accounts[0], nil
}
func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
if len(ids) == 0 {
return []*service.Account{}, nil
}
// De-duplicate while preserving order of first occurrence.
uniqueIDs := make([]int64, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
uniqueIDs = append(uniqueIDs, id)
}
if len(uniqueIDs) == 0 {
return []*service.Account{}, nil
}
entAccounts, err := r.client.Account.
Query().
Where(dbaccount.IDIn(uniqueIDs...)).
WithProxy().
All(ctx)
if err != nil {
return nil, err
}
if len(entAccounts) == 0 {
return []*service.Account{}, nil
}
accountIDs := make([]int64, 0, len(entAccounts))
entByID := make(map[int64]*dbent.Account, len(entAccounts))
for _, acc := range entAccounts {
entByID[acc.ID] = acc
accountIDs = append(accountIDs, acc.ID)
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
}
outByID := make(map[int64]*service.Account, len(entAccounts))
for _, entAcc := range entAccounts {
out := accountEntityToService(entAcc)
if out == nil {
continue
}
// Prefer the preloaded proxy edge when available.
if entAcc.Edges.Proxy != nil {
out.Proxy = proxyEntityToService(entAcc.Edges.Proxy)
}
if groups, ok := groupsByAccount[entAcc.ID]; ok {
out.Groups = groups
}
if groupIDs, ok := groupIDsByAccount[entAcc.ID]; ok {
out.GroupIDs = groupIDs
}
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
out.AccountGroups = ags
}
outByID[entAcc.ID] = out
}
// Preserve input order (first occurrence), and ignore missing IDs.
out := make([]*service.Account, 0, len(uniqueIDs))
for _, id := range uniqueIDs {
if _, ok := entByID[id]; !ok {
continue
}
if acc, ok := outByID[id]; ok && acc != nil {
out = append(out, acc)
}
}
return out, nil
}
// ExistsByID 检查指定 ID 的账号是否存在。
// 相比 GetByID,此方法性能更优,因为:
// - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值
......
......@@ -294,7 +294,6 @@ func userEntityToService(u *dbent.User) *service.User {
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
......
......@@ -2,7 +2,9 @@ package repository
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
......@@ -27,6 +29,8 @@ const (
userSlotKeyPrefix = "concurrency:user:"
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:"
// 默认槽位过期时间(分钟),可通过配置覆盖
defaultSlotTTLMinutes = 15
......@@ -112,33 +116,112 @@ var (
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
return 1
`)
// incrementAccountWaitScript - account-level wait queue count
incrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
// getAccountsLoadBatchScript - batch load query (read-only)
// ARGV[1] = slot TTL (seconds, retained for compatibility)
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
getAccountsLoadBatchScript = redis.NewScript(`
local result = {}
local i = 2
while i <= #ARGV do
local accountID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:account:' .. accountID
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'wait:account:' .. accountID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, accountID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`)
// cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds)
cleanupExpiredSlotsScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
`)
)
type concurrencyCache struct {
rdb *redis.Client
slotTTLSeconds int // 槽位过期时间(秒)
rdb *redis.Client
slotTTLSeconds int // 槽位过期时间(秒)
waitQueueTTLSeconds int // 等待队列过期时间(秒)
}
// NewConcurrencyCache 创建并发控制缓存
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache {
// waitQueueTTLSeconds: 等待队列过期时间(秒),0 或负数使用 slot TTL
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
if slotTTLMinutes <= 0 {
slotTTLMinutes = defaultSlotTTLMinutes
}
if waitQueueTTLSeconds <= 0 {
waitQueueTTLSeconds = slotTTLMinutes * 60
}
return &concurrencyCache{
rdb: rdb,
slotTTLSeconds: slotTTLMinutes * 60,
rdb: rdb,
slotTTLSeconds: slotTTLMinutes * 60,
waitQueueTTLSeconds: waitQueueTTLSeconds,
}
}
......@@ -155,6 +238,10 @@ func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}
func accountWaitKey(accountID int64) string {
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
}
// Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
......@@ -225,3 +312,75 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
// Account wait queue operations
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
key := accountWaitKey(accountID)
result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
key := accountWaitKey(accountID)
val, err := c.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) {
return 0, err
}
if errors.Is(err, redis.Nil) {
return 0, nil
}
return val, nil
}
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
if len(accounts) == 0 {
return map[int64]*service.AccountLoadInfo{}, nil
}
args := []any{c.slotTTLSeconds}
for _, acc := range accounts {
args = append(args, acc.ID, acc.MaxConcurrency)
}
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
if err != nil {
return nil, err
}
loadMap := make(map[int64]*service.AccountLoadInfo)
for i := 0; i < len(result); i += 4 {
if i+3 >= len(result) {
break
}
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
loadMap[accountID] = &service.AccountLoadInfo{
AccountID: accountID,
CurrentConcurrency: currentConcurrency,
WaitingCount: waitingCount,
LoadRate: loadRate,
}
}
return loadMap, nil
}
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
key := accountSlotKey(accountID)
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
return err
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment