Commit e51a3288 authored by yangjianbo's avatar yangjianbo
Browse files

merge: 合并 test 分支到 test-dev,解决冲突

解决的冲突文件:
- wire_gen.go: 合并 ConcurrencyService/CRSSyncService 参数和 userAttributeHandler
- gateway_handler.go: 合并 pkg/errors 和 antigravity 导入
- gateway_service.go: 合并 validateUpstreamBaseURL 和 GetAvailableModels
- config.example.yaml: 合并 billing/turnstile 配置和额外 gateway 选项

🤖 Generated with [Claude Code](https://claude.com/claude-code

)
Co-Authored-By: default avatarClaude Opus 4.5 <noreply@anthropic.com>
parents 25e16326 8a50ca59
......@@ -46,8 +46,8 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
if oauthType == "" {
oauthType = "code_assist"
}
if oauthType != "code_assist" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
return
}
......@@ -92,8 +92,8 @@ func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
if oauthType == "" {
oauthType = "code_assist"
}
if oauthType != "code_assist" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
return
}
......
......@@ -16,13 +16,15 @@ import (
type SettingHandler struct {
settingService *service.SettingService
emailService *service.EmailService
turnstileService *service.TurnstileService
}
// NewSettingHandler 创建系统设置处理器
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService) *SettingHandler {
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService) *SettingHandler {
return &SettingHandler{
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
}
}
......@@ -118,6 +120,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SmtpPort = 587
}
// Turnstile 参数验证
if req.TurnstileEnabled {
// 检查必填字段
if req.TurnstileSiteKey == "" {
response.BadRequest(c, "Turnstile Site Key is required when enabled")
return
}
if req.TurnstileSecretKey == "" {
response.BadRequest(c, "Turnstile Secret Key is required when enabled")
return
}
// 获取当前设置,检查参数是否有变化
currentSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// 当 site_key 或 secret_key 任一变化时验证(避免配置错误导致无法登录)
siteKeyChanged := currentSettings.TurnstileSiteKey != req.TurnstileSiteKey
secretKeyChanged := currentSettings.TurnstileSecretKey != req.TurnstileSecretKey
if siteKeyChanged || secretKeyChanged {
if err := h.turnstileService.ValidateSecretKey(c.Request.Context(), req.TurnstileSecretKey); err != nil {
response.ErrorFrom(c, err)
return
}
}
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
......
......@@ -246,7 +246,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
}
// Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword})
if err != nil {
response.ErrorFrom(c, err)
return
......
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UserAttributeHandler handles user attribute management
type UserAttributeHandler struct {
attrService *service.UserAttributeService
}
// NewUserAttributeHandler creates a new handler
func NewUserAttributeHandler(attrService *service.UserAttributeService) *UserAttributeHandler {
return &UserAttributeHandler{attrService: attrService}
}
// --- Request/Response DTOs ---
// CreateAttributeDefinitionRequest represents create attribute definition request
type CreateAttributeDefinitionRequest struct {
Key string `json:"key" binding:"required,min=1,max=100"`
Name string `json:"name" binding:"required,min=1,max=255"`
Description string `json:"description"`
Type string `json:"type" binding:"required"`
Options []service.UserAttributeOption `json:"options"`
Required bool `json:"required"`
Validation service.UserAttributeValidation `json:"validation"`
Placeholder string `json:"placeholder"`
Enabled bool `json:"enabled"`
}
// UpdateAttributeDefinitionRequest represents update attribute definition request
type UpdateAttributeDefinitionRequest struct {
Name *string `json:"name"`
Description *string `json:"description"`
Type *string `json:"type"`
Options *[]service.UserAttributeOption `json:"options"`
Required *bool `json:"required"`
Validation *service.UserAttributeValidation `json:"validation"`
Placeholder *string `json:"placeholder"`
Enabled *bool `json:"enabled"`
}
// ReorderRequest represents reorder attribute definitions request
type ReorderRequest struct {
IDs []int64 `json:"ids" binding:"required"`
}
// UpdateUserAttributesRequest represents update user attributes request
type UpdateUserAttributesRequest struct {
Values map[int64]string `json:"values" binding:"required"`
}
// BatchGetUserAttributesRequest represents batch get user attributes request
type BatchGetUserAttributesRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
// BatchUserAttributesResponse represents batch user attributes response
type BatchUserAttributesResponse struct {
// Map of userID -> map of attributeID -> value
Attributes map[int64]map[int64]string `json:"attributes"`
}
// AttributeDefinitionResponse represents attribute definition response
type AttributeDefinitionResponse struct {
ID int64 `json:"id"`
Key string `json:"key"`
Name string `json:"name"`
Description string `json:"description"`
Type string `json:"type"`
Options []service.UserAttributeOption `json:"options"`
Required bool `json:"required"`
Validation service.UserAttributeValidation `json:"validation"`
Placeholder string `json:"placeholder"`
DisplayOrder int `json:"display_order"`
Enabled bool `json:"enabled"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// AttributeValueResponse represents attribute value response
type AttributeValueResponse struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
AttributeID int64 `json:"attribute_id"`
Value string `json:"value"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// --- Helpers ---
func defToResponse(def *service.UserAttributeDefinition) *AttributeDefinitionResponse {
return &AttributeDefinitionResponse{
ID: def.ID,
Key: def.Key,
Name: def.Name,
Description: def.Description,
Type: string(def.Type),
Options: def.Options,
Required: def.Required,
Validation: def.Validation,
Placeholder: def.Placeholder,
DisplayOrder: def.DisplayOrder,
Enabled: def.Enabled,
CreatedAt: def.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
UpdatedAt: def.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
}
}
func valueToResponse(val *service.UserAttributeValue) *AttributeValueResponse {
return &AttributeValueResponse{
ID: val.ID,
UserID: val.UserID,
AttributeID: val.AttributeID,
Value: val.Value,
CreatedAt: val.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
UpdatedAt: val.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
}
}
// --- Handlers ---
// ListDefinitions lists all attribute definitions
// GET /admin/user-attributes
func (h *UserAttributeHandler) ListDefinitions(c *gin.Context) {
enabledOnly := c.Query("enabled") == "true"
defs, err := h.attrService.ListDefinitions(c.Request.Context(), enabledOnly)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*AttributeDefinitionResponse, 0, len(defs))
for i := range defs {
out = append(out, defToResponse(&defs[i]))
}
response.Success(c, out)
}
// CreateDefinition creates a new attribute definition
// POST /admin/user-attributes
func (h *UserAttributeHandler) CreateDefinition(c *gin.Context) {
var req CreateAttributeDefinitionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
def, err := h.attrService.CreateDefinition(c.Request.Context(), service.CreateAttributeDefinitionInput{
Key: req.Key,
Name: req.Name,
Description: req.Description,
Type: service.UserAttributeType(req.Type),
Options: req.Options,
Required: req.Required,
Validation: req.Validation,
Placeholder: req.Placeholder,
Enabled: req.Enabled,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, defToResponse(def))
}
// UpdateDefinition updates an attribute definition
// PUT /admin/user-attributes/:id
func (h *UserAttributeHandler) UpdateDefinition(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid attribute ID")
return
}
var req UpdateAttributeDefinitionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := service.UpdateAttributeDefinitionInput{
Name: req.Name,
Description: req.Description,
Options: req.Options,
Required: req.Required,
Validation: req.Validation,
Placeholder: req.Placeholder,
Enabled: req.Enabled,
}
if req.Type != nil {
t := service.UserAttributeType(*req.Type)
input.Type = &t
}
def, err := h.attrService.UpdateDefinition(c.Request.Context(), id, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, defToResponse(def))
}
// DeleteDefinition deletes an attribute definition
// DELETE /admin/user-attributes/:id
func (h *UserAttributeHandler) DeleteDefinition(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid attribute ID")
return
}
if err := h.attrService.DeleteDefinition(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Attribute definition deleted successfully"})
}
// ReorderDefinitions reorders attribute definitions
// PUT /admin/user-attributes/reorder
func (h *UserAttributeHandler) ReorderDefinitions(c *gin.Context) {
var req ReorderRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Convert IDs array to orders map (position in array = display_order)
orders := make(map[int64]int, len(req.IDs))
for i, id := range req.IDs {
orders[id] = i
}
if err := h.attrService.ReorderDefinitions(c.Request.Context(), orders); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Reorder successful"})
}
// GetUserAttributes gets a user's attribute values
// GET /admin/users/:id/attributes
func (h *UserAttributeHandler) GetUserAttributes(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*AttributeValueResponse, 0, len(values))
for i := range values {
out = append(out, valueToResponse(&values[i]))
}
response.Success(c, out)
}
// UpdateUserAttributes updates a user's attribute values
// PUT /admin/users/:id/attributes
func (h *UserAttributeHandler) UpdateUserAttributes(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req UpdateUserAttributesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
inputs := make([]service.UpdateUserAttributeInput, 0, len(req.Values))
for attrID, value := range req.Values {
inputs = append(inputs, service.UpdateUserAttributeInput{
AttributeID: attrID,
Value: value,
})
}
if err := h.attrService.UpdateUserAttributes(c.Request.Context(), userID, inputs); err != nil {
response.ErrorFrom(c, err)
return
}
// Return updated values
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*AttributeValueResponse, 0, len(values))
for i := range values {
out = append(out, valueToResponse(&values[i]))
}
response.Success(c, out)
}
// GetBatchUserAttributes gets attribute values for multiple users
// POST /admin/user-attributes/batch
func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
var req BatchGetUserAttributesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.UserIDs) == 0 {
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
return
}
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, BatchUserAttributesResponse{Attributes: attrs})
}
......@@ -27,7 +27,6 @@ type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
......@@ -40,7 +39,6 @@ type UpdateUserRequest struct {
Email string `json:"email" binding:"omitempty,email"`
Password string `json:"password" binding:"omitempty,min=6"`
Username *string `json:"username"`
Wechat *string `json:"wechat"`
Notes *string `json:"notes"`
Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"`
......@@ -57,13 +55,22 @@ type UpdateBalanceRequest struct {
// List handles listing all users with pagination
// GET /api/v1/admin/users
// Query params:
// - status: filter by user status
// - role: filter by user role
// - search: search in email, username
// - attr[{id}]: filter by custom attribute value, e.g. attr[1]=company
func (h *UserHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := c.Query("status")
role := c.Query("role")
search := c.Query("search")
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
filters := service.UserListFilters{
Status: c.Query("status"),
Role: c.Query("role"),
Search: c.Query("search"),
Attributes: parseAttributeFilters(c),
}
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
if err != nil {
response.ErrorFrom(c, err)
return
......@@ -76,6 +83,29 @@ func (h *UserHandler) List(c *gin.Context) {
response.Paginated(c, out, total, page, pageSize)
}
// parseAttributeFilters extracts attribute filters from query params
// Format: attr[{attributeID}]=value, e.g. attr[1]=company&attr[2]=developer
func parseAttributeFilters(c *gin.Context) map[int64]string {
result := make(map[int64]string)
// Get all query params and look for attr[*] pattern
for key, values := range c.Request.URL.Query() {
if len(values) == 0 || values[0] == "" {
continue
}
// Check if key matches pattern attr[{id}]
if len(key) > 5 && key[:5] == "attr[" && key[len(key)-1] == ']' {
idStr := key[5 : len(key)-1]
id, err := strconv.ParseInt(idStr, 10, 64)
if err == nil && id > 0 {
result[id] = values[0]
}
}
}
return result
}
// GetByID handles getting a user by ID
// GET /api/v1/admin/users/:id
func (h *UserHandler) GetByID(c *gin.Context) {
......@@ -107,7 +137,6 @@ func (h *UserHandler) Create(c *gin.Context) {
Email: req.Email,
Password: req.Password,
Username: req.Username,
Wechat: req.Wechat,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
......@@ -141,7 +170,6 @@ func (h *UserHandler) Update(c *gin.Context) {
Email: req.Email,
Password: req.Password,
Username: req.Username,
Wechat: req.Wechat,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
......
......@@ -10,7 +10,6 @@ func UserFromServiceShallow(u *service.User) *User {
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
Role: u.Role,
Balance: u.Balance,
......
......@@ -6,7 +6,6 @@ type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"`
Role string `json:"role"`
Balance float64 `json:"balance"`
......
......@@ -11,7 +11,8 @@ import (
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
......@@ -143,6 +144,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} else if apiKey.Group != nil {
platform = apiKey.Group.Platform
}
sessionKey := sessionHash
if platform == service.PlatformGemini && sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
if platform == service.PlatformGemini {
const maxAccountSwitches = 3
......@@ -151,7 +156,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus := 0
for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
......@@ -160,9 +165,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
......@@ -172,12 +181,47 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
......@@ -189,6 +233,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
......@@ -233,7 +280,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
......@@ -242,9 +289,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
......@@ -254,12 +305,47 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
......@@ -271,6 +357,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
......@@ -310,12 +399,42 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Models handles listing available models
// GET /v1/models
// Returns different model lists based on the API key's group platform
// Returns models based on account configurations (model_mapping whitelist)
// Falls back to default models if no whitelist is configured
func (h *GatewayHandler) Models(c *gin.Context) {
apiKey, _ := middleware2.GetApiKeyFromContext(c)
// Return OpenAI models for OpenAI platform groups
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
var groupID *int64
var platform string
if apiKey != nil && apiKey.Group != nil {
groupID = &apiKey.Group.ID
platform = apiKey.Group.Platform
}
// Get available models from account configurations (without platform filter)
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
if len(availableModels) > 0 {
// Build model list from whitelist
models := make([]claude.Model, 0, len(availableModels))
for _, modelID := range availableModels {
models = append(models, claude.Model{
ID: modelID,
Type: "model",
DisplayName: modelID,
CreatedAt: "2024-01-01T00:00:00Z",
})
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": models,
})
return
}
// Fallback to default models
if platform == "openai" {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": openai.DefaultModels,
......@@ -323,13 +442,21 @@ func (h *GatewayHandler) Models(c *gin.Context) {
return
}
// Default: Claude models
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": claude.DefaultModels,
})
}
// AntigravityModels 返回 Antigravity 支持的全部模型
// GET /antigravity/models
func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": antigravity.DefaultModels(),
})
}
// Usage handles getting account balance for CC Switch integration
// GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) {
......@@ -648,13 +775,13 @@ func sendMockWarmupResponse(c *gin.Context, model string) {
func billingErrorDetails(err error) (status int, code, message string) {
if errors.Is(err, service.ErrBillingServiceUnavailable) {
msg := infraerrors.Message(err)
msg := pkgerrors.Message(err)
if msg == "" {
msg = "Billing service temporarily unavailable. Please retry later."
}
return http.StatusServiceUnavailable, "billing_service_error", msg
}
msg := infraerrors.Message(err)
msg := pkgerrors.Message(err)
if msg == "" {
msg = err.Error()
}
......
......@@ -83,6 +83,16 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
h.concurrencyService.DecrementWaitCount(ctx, userID)
}
// IncrementAccountWaitCount increments the wait count for an account
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
}
// DecrementAccountWaitCount decrements the wait count for an account
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
......@@ -126,7 +136,12 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// Determine if ping is needed (streaming + ping format defined)
......@@ -200,6 +215,11 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
}
}
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
}
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
......
......@@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
......@@ -32,9 +33,9 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
return
}
// 强制 antigravity 模式:直接返回静态模型列表
// 强制 antigravity 模式:返回 antigravity 支持的模型列表
if forcePlatform == service.PlatformAntigravity {
c.JSON(http.StatusOK, gemini.FallbackModelsList())
c.JSON(http.StatusOK, antigravity.FallbackGeminiModelsList())
return
}
......@@ -84,9 +85,9 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
return
}
// 强制 antigravity 模式:直接返回静态模型信息
// 强制 antigravity 模式:返回 antigravity 模型信息
if forcePlatform == service.PlatformAntigravity {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
c.JSON(http.StatusOK, antigravity.FallbackGeminiModel(modelName))
return
}
......@@ -198,13 +199,17 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
......@@ -213,13 +218,49 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
handleGeminiFailoverExhausted(c, lastFailoverStatus)
return
}
account := selection.Account
// 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
return
}
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
stream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 5) forward (根据平台分流)
var result *service.ForwardResult
......@@ -231,6 +272,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
......
......@@ -20,6 +20,7 @@ type AdminHandlers struct {
System *admin.SystemHandler
Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
}
// Handlers contains all HTTP handlers
......
......@@ -147,7 +147,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for {
// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 {
......@@ -157,21 +157,60 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
......
......@@ -30,7 +30,6 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
Username *string `json:"username"`
Wechat *string `json:"wechat"`
}
// GetProfile handles getting user profile
......@@ -99,7 +98,6 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
svcReq := service.UpdateProfileRequest{
Username: req.Username,
Wechat: req.Wechat,
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
......
......@@ -23,6 +23,7 @@ func ProvideAdminHandlers(
systemHandler *admin.SystemHandler,
subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
......@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
System: systemHandler,
Subscription: subscriptionHandler,
Usage: usageHandler,
UserAttribute: userAttributeHandler,
}
}
......@@ -107,6 +109,7 @@ var ProviderSet = wire.NewSet(
ProvideSystemHandler,
admin.NewSubscriptionHandler,
admin.NewUsageHandler,
admin.NewUserAttributeHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
......
package infrastructure
import (
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
entsql "entgo.io/ent/dialect/sql"
)
// ProviderSet 是基础设施层的 Wire 依赖提供者集合。
//
// Wire 是 Google 开发的编译时依赖注入工具。ProviderSet 将相关的依赖提供函数
// 组织在一起,便于在应用程序启动时自动组装依赖关系。
//
// 包含的提供者:
// - ProvideEnt: 提供 Ent ORM 客户端
// - ProvideSQLDB: 提供底层 SQL 数据库连接
// - ProvideRedis: 提供 Redis 客户端
var ProviderSet = wire.NewSet(
ProvideEnt,
ProvideSQLDB,
ProvideRedis,
)
// ProvideEnt 为依赖注入提供 Ent 客户端。
//
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
//
// 依赖:config.Config
// 提供:*ent.Client
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
client, _, err := InitEnt(cfg)
return client, err
}
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
//
// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询),
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
//
// 设计说明:
// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
//
// 依赖:*ent.Client
// 提供:*sql.DB
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
if client == nil {
return nil, errors.New("nil ent client")
}
// 从 Ent 客户端获取底层驱动
drv, ok := client.Driver().(*entsql.Driver)
if !ok {
return nil, errors.New("ent driver does not expose *sql.DB")
}
// 返回驱动持有的 sql.DB 实例
return drv.DB(), nil
}
// ProvideRedis 为依赖注入提供 Redis 客户端。
//
// Redis 用于:
// - 分布式锁(如并发控制)
// - 缓存(如用户会话、API 响应缓存)
// - 速率限制
// - 实时统计数据
//
// 依赖:config.Config
// 提供:*redis.Client
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}
......@@ -57,6 +57,7 @@ var geminiModels = []string{
"gemini-2.5-flash-lite",
"gemini-3-flash",
"gemini-3-pro-low",
"gemini-3-pro-high",
}
func TestMain(m *testing.M) {
......@@ -641,6 +642,37 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
t.Logf("✅ thinking 模式工具调用测试通过, id=%v", result["id"])
}
// TestClaudeMessagesWithGeminiModel 测试在 Claude 端点使用 Gemini 模型
// 验证:通过 /v1/messages 端点传入 gemini 模型名的场景(含前缀映射)
// 仅在 Antigravity 模式下运行(ENDPOINT_PREFIX="/antigravity")
func TestClaudeMessagesWithGeminiModel(t *testing.T) {
if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行")
}
// 测试通过 Claude 端点调用 Gemini 模型
geminiViaClaude := []string{
"gemini-3-flash", // 直接支持
"gemini-3-pro-low", // 直接支持
"gemini-3-pro-high", // 直接支持
"gemini-3-pro", // 前缀映射 -> gemini-3-pro-high
"gemini-3-pro-preview", // 前缀映射 -> gemini-3-pro-high
}
for i, model := range geminiViaClaude {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_通过Claude端点", func(t *testing.T) {
testClaudeMessage(t, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_通过Claude端点_流式", func(t *testing.T) {
testClaudeMessage(t, model, true)
})
}
}
// TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景
// 验证:Gemini 模型接受没有 signature 的 thinking block
func TestClaudeMessagesWithNoSignature(t *testing.T) {
......@@ -738,3 +770,30 @@ func testClaudeWithNoSignature(t *testing.T, model string) {
}
t.Logf("✅ 无 signature thinking 处理测试通过, id=%v", result["id"])
}
// TestGeminiEndpointWithClaudeModel 测试通过 Gemini 端点调用 Claude 模型
// 仅在 Antigravity 模式下运行(ENDPOINT_PREFIX="/antigravity")
func TestGeminiEndpointWithClaudeModel(t *testing.T) {
if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行")
}
// 测试通过 Gemini 端点调用 Claude 模型
claudeViaGemini := []string{
"claude-sonnet-4-5",
"claude-opus-4-5-thinking",
}
for i, model := range claudeViaGemini {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_通过Gemini端点", func(t *testing.T) {
testGeminiGenerate(t, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) {
testGeminiGenerate(t, model, true)
})
}
}
......@@ -37,12 +37,26 @@ type ClaudeMetadata struct {
}
// ClaudeTool Claude 工具定义
// 支持两种格式:
// 1. 标准格式: { "name": "...", "description": "...", "input_schema": {...} }
// 2. Custom 格式 (MCP): { "type": "custom", "name": "...", "custom": { "description": "...", "input_schema": {...} } }
type ClaudeTool struct {
Type string `json:"type,omitempty"` // "custom" 或空(标准格式)
Name string `json:"name"`
Description string `json:"description,omitempty"` // 标准格式使用
InputSchema map[string]any `json:"input_schema,omitempty"` // 标准格式使用
Custom *CustomToolSpec `json:"custom,omitempty"` // custom 格式使用
}
// CustomToolSpec MCP custom 工具规格
type CustomToolSpec struct {
Description string `json:"description,omitempty"`
InputSchema map[string]any `json:"input_schema"`
}
// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格)
type ClaudeCustomToolSpec = CustomToolSpec
// SystemBlock system prompt 数组形式的元素
type SystemBlock struct {
Type string `json:"type"`
......@@ -124,3 +138,91 @@ type ErrorDetail struct {
Type string `json:"type"`
Message string `json:"message"`
}
// modelDef Antigravity 模型定义(内部使用)
type modelDef struct {
ID string
DisplayName string
CreatedAt string // 仅 Claude API 格式使用
}
// Antigravity 支持的 Claude 模型
var claudeModels = []modelDef{
{ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"},
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
}
// Antigravity 支持的 Gemini 模型
var geminiModels = []modelDef{
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
}
// ========== Claude API 格式 (/v1/models) ==========
// ClaudeModel Claude API 模型格式
type ClaudeModel struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels 返回 Claude API 格式的模型列表(Claude + Gemini)
func DefaultModels() []ClaudeModel {
all := append(claudeModels, geminiModels...)
result := make([]ClaudeModel, len(all))
for i, m := range all {
result[i] = ClaudeModel{ID: m.ID, Type: "model", DisplayName: m.DisplayName, CreatedAt: m.CreatedAt}
}
return result
}
// ========== Gemini v1beta 格式 (/v1beta/models) ==========
// GeminiModel Gemini v1beta 模型格式
type GeminiModel struct {
Name string `json:"name"`
DisplayName string `json:"displayName,omitempty"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
}
// GeminiModelsListResponse Gemini v1beta 模型列表响应
type GeminiModelsListResponse struct {
Models []GeminiModel `json:"models"`
}
var defaultGeminiMethods = []string{"generateContent", "streamGenerateContent"}
// DefaultGeminiModels 返回 Gemini v1beta 格式的模型列表(仅 Gemini 模型)
func DefaultGeminiModels() []GeminiModel {
result := make([]GeminiModel, len(geminiModels))
for i, m := range geminiModels {
result[i] = GeminiModel{Name: "models/" + m.ID, DisplayName: m.DisplayName, SupportedGenerationMethods: defaultGeminiMethods}
}
return result
}
// FallbackGeminiModelsList 返回 Gemini v1beta 格式的模型列表响应
func FallbackGeminiModelsList() GeminiModelsListResponse {
return GeminiModelsListResponse{Models: DefaultGeminiModels()}
}
// FallbackGeminiModel 返回单个模型信息(v1beta 格式)
func FallbackGeminiModel(model string) GeminiModel {
if model == "" {
return GeminiModel{Name: "models/unknown", SupportedGenerationMethods: defaultGeminiMethods}
}
name := model
if len(model) < 7 || model[:7] != "models/" {
name = "models/" + model
}
return GeminiModel{Name: name, SupportedGenerationMethods: defaultGeminiMethods}
}
......@@ -3,6 +3,7 @@ package antigravity
import (
"encoding/json"
"fmt"
"log"
"strings"
"github.com/google/uuid"
......@@ -13,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
// 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string)
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 检测是否启用 thinking
requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等),
// 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。
isThinkingEnabled := requestedThinkingEnabled && allowDummyThought
// 1. 构建 contents
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil {
......@@ -30,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
// 3. 构建 generationConfig
generationConfig := buildGenerationConfig(claudeReq)
reqForGen := claudeReq
if requestedThinkingEnabled && !allowDummyThought {
log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel)
// shallow copy to avoid mutating caller's request
clone := *claudeReq
clone.Thinking = nil
reqForGen = &clone
}
generationConfig := buildGenerationConfig(reqForGen)
// 4. 构建 tools
tools := buildTools(claudeReq.Tools)
......@@ -149,6 +161,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
parts = append([]GeminiPart{{
Text: "Thinking...",
Thought: true,
ThoughtSignature: dummyThoughtSignature,
}}, parts...)
}
}
......@@ -170,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const dummyThoughtSignature = "skip_thought_signature_validator"
// isValidThoughtSignature 验证 thought signature 是否有效
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
func isValidThoughtSignature(signature string) bool {
// 空字符串无效
if signature == "" {
return false
}
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
// 参考 Claude API 文档和实际观察到的有效 signature
if len(signature) < 40 {
log.Printf("[Debug] Signature too short: len=%d", len(signature))
return false
}
// 检查是否是有效的 base64 字符
// base64 字符集: A-Z, a-z, 0-9, +, /, =
for i, c := range signature {
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') &&
(c < '0' || c > '9') && c != '+' && c != '/' && c != '=' {
log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c)
return false
}
}
return true
}
// buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
......@@ -198,15 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
}
case "thinking":
part := GeminiPart{
if allowDummyThought {
// Gemini 模型可以使用 dummy signature
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: dummyThoughtSignature,
})
continue
}
// 保留原有 signature(Claude 模型需要有效的 signature)
if block.Signature != "" {
part.ThoughtSignature = block.Signature
// Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。
signature := strings.TrimSpace(block.Signature)
if signature == "" || signature == dummyThoughtSignature {
log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)")
continue
}
parts = append(parts, part)
if !isValidThoughtSignature(signature) {
log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature))
}
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: signature,
})
case "image":
if block.Source != nil && block.Source.Type == "base64" {
......@@ -231,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
ID: block.ID,
},
}
// 保留原有 signature,或对 Gemini 模型使用 dummy signature
if block.Signature != "" {
part.ThoughtSignature = block.Signature
} else if allowDummyThought {
// 只有 Gemini 模型使用 dummy signature
// Claude 模型不设置 signature(避免验证问题)
if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)
......@@ -378,13 +433,53 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 普通工具
var funcDecls []GeminiFunctionDecl
for _, tool := range tools {
for i, tool := range tools {
// 跳过无效工具名称
if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name")
continue
}
var description string
var inputSchema map[string]any
// 检查是否为 custom 类型工具 (MCP)
if tool.Type == "custom" {
if tool.Custom == nil || tool.Custom.InputSchema == nil {
log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name)
continue
}
description = tool.Custom.Description
inputSchema = tool.Custom.InputSchema
// 调试日志:记录 custom 工具的 schema
if schemaJSON, err := json.Marshal(inputSchema); err == nil {
log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON))
}
} else {
// 标准格式: 从顶层字段获取
description = tool.Description
inputSchema = tool.InputSchema
}
// 清理 JSON Schema
params := cleanJSONSchema(tool.InputSchema)
params := cleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值
if params == nil {
params = map[string]any{
"type": "OBJECT",
"properties": map[string]any{},
}
}
// 调试日志:记录清理后的 schema
if paramsJSON, err := json.Marshal(params); err == nil {
log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON))
}
funcDecls = append(funcDecls, GeminiFunctionDecl{
Name: tool.Name,
Description: tool.Description,
Description: description,
Parameters: params,
})
}
......@@ -443,24 +538,54 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
}
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var excludedSchemaKeys = map[string]bool{
// 元 schema 字段
"$schema": true,
"$id": true,
"$ref": true,
"additionalProperties": true,
// 字符串验证(Gemini 不支持)
"minLength": true,
"maxLength": true,
"minItems": true,
"maxItems": true,
"uniqueItems": true,
"pattern": true,
// 数字验证(Claude API 通过 Vertex AI 不支持这些字段)
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"pattern": true,
"format": true,
"multipleOf": true,
// 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems": true,
"minItems": true,
"maxItems": true,
// 组合 schema(Gemini 不支持)
"oneOf": true,
"anyOf": true,
"allOf": true,
"not": true,
"if": true,
"then": true,
"else": true,
"$defs": true,
"definitions": true,
// 对象验证(仅保留 properties/required/additionalProperties)
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
// 其他不支持的字段
"default": true,
"strict": true,
"const": true,
"examples": true,
"deprecated": true,
......@@ -468,6 +593,9 @@ var excludedSchemaKeys = map[string]bool{
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// Claude 特有字段
"strict": true,
}
// cleanSchemaValue 递归清理 schema 值
......@@ -487,6 +615,31 @@ func cleanSchemaValue(value any) any {
continue
}
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if k == "format" {
if formatStr, ok := val.(string); ok {
// Gemini 只支持 date-time, date, time
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
result[k] = val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
log.Printf("[Debug] additionalProperties is bool: %v", boolVal)
} else {
// 如果是 schema 对象,转换为 false(更安全的默认值)
result[k] = false
log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val)
}
continue
}
// 递归清理所有值
result[k] = cleanSchemaValue(val)
}
......
package antigravity
import (
"encoding/json"
"testing"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
tests := []struct {
name string
content string
allowDummyThought bool
expectedParts int
description string
}{
{
name: "Claude model - skip thinking block without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 2, // 只有两个text block
description: "Claude模型应该跳过无signature的thinking block",
},
{
name: "Claude model - keep thinking block with signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 3, // 三个block都保留
description: "Claude模型应该保留有signature的thinking block",
},
{
name: "Gemini model - use dummy signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: true,
expectedParts: 3, // 三个block都保留,thinking使用dummy signature
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != tt.expectedParts {
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
}
})
}
}
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
func TestBuildTools_CustomTypeTools(t *testing.T) {
tests := []struct {
name string
tools []ClaudeTool
expectedLen int
description string
}{
{
name: "Standard tool format",
tools: []ClaudeTool{
{
Name: "get_weather",
Description: "Get weather information",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"location": map[string]any{"type": "string"},
},
},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []ClaudeTool{
{
Type: "custom",
Name: "mcp_tool",
Custom: &CustomToolSpec{
Description: "MCP tool description",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"param": map[string]any{"type": "string"},
},
},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从Custom字段读取description和input_schema",
},
{
name: "Mixed standard and custom tools",
tools: []ClaudeTool{
{
Name: "standard_tool",
Description: "Standard tool",
InputSchema: map[string]any{"type": "object"},
},
{
Type: "custom",
Name: "custom_tool",
Custom: &CustomToolSpec{
Description: "Custom tool",
InputSchema: map[string]any{"type": "object"},
},
},
},
expectedLen: 1, // 返回一个GeminiToolDeclaration,包含2个function declarations
description: "混合标准和custom工具应该都能正确转换",
},
{
name: "Invalid custom tool - nil Custom field",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
// Custom 为 nil
},
},
expectedLen: 0, // 应该被跳过
description: "Custom字段为nil的custom工具应该被跳过",
},
{
name: "Invalid custom tool - nil InputSchema",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
Custom: &CustomToolSpec{
Description: "Invalid",
// InputSchema 为 nil
},
},
},
expectedLen: 0, // 应该被跳过
description: "InputSchema为nil的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildTools(tt.tools)
if len(result) != tt.expectedLen {
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
}
// 验证function declarations存在
if len(result) > 0 && result[0].FunctionDeclarations != nil {
if len(result[0].FunctionDeclarations) != len(tt.tools) {
t.Errorf("%s: got %d function declarations, want %d",
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
}
}
})
}
}
......@@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
// Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment