Unverified Commit ff8b1b4a authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #467 from slovx2/main

Antigravity 相关BUG修复及调度优化
parents c0c9c984 4cce21b1
......@@ -14,6 +14,9 @@ const (
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
RetryCount Key = "ctx_retry_count"
// AccountSwitchCount 表示请求过程中发生的账号切换次数
AccountSwitchCount Key = "ctx_account_switch_count"
// IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
// Group 认证后的分组信息,由 API Key 认证中间件设置
......
......@@ -142,8 +142,11 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice4k,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
group.FieldFallbackGroupIDOnInvalidRequest,
group.FieldModelRoutingEnabled,
group.FieldModelRouting,
group.FieldMcpXMLInject,
group.FieldSupportedModelScopes,
)
}).
Only(ctx)
......@@ -459,28 +462,31 @@ func groupEntityToService(g *dbent.Group) *service.Group {
return nil
}
return &service.Group{
ID: g.ID,
Name: g.Name,
Description: derefString(g.Description),
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
Hydrated: true,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
ID: g.ID,
Name: g.Name,
Description: derefString(g.Description),
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
Hydrated: true,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.McpXMLInject,
SupportedModelScopes: g.SupportedModelScopes,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
}
......
......@@ -50,13 +50,18 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
builder = builder.SetModelRouting(groupIn.ModelRouting)
}
// 设置支持的模型系列(始终设置,空数组表示不限制)
builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes)
created, err := builder.Save(ctx)
if err == nil {
groupIn.ID = created.ID
......@@ -87,7 +92,6 @@ func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.G
if err != nil {
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
return groupEntityToService(m), nil
}
......@@ -108,7 +112,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject)
// 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
......@@ -116,6 +121,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
} else {
builder = builder.ClearFallbackGroupID()
}
// 处理 FallbackGroupIDOnInvalidRequest:nil 时清除,否则设置
if groupIn.FallbackGroupIDOnInvalidRequest != nil {
builder = builder.SetFallbackGroupIDOnInvalidRequest(*groupIn.FallbackGroupIDOnInvalidRequest)
} else {
builder = builder.ClearFallbackGroupIDOnInvalidRequest()
}
// 处理 ModelRouting:nil 时清除,否则设置
if groupIn.ModelRouting != nil {
......@@ -124,6 +135,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
builder = builder.ClearModelRouting()
}
// 处理 SupportedModelScopes(始终设置,空数组表示不限制)
builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes)
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
......
......@@ -43,6 +43,7 @@ INSERT INTO ops_system_metrics (
upstream_529_count,
token_consumed,
account_switch_count,
qps,
tps,
......@@ -81,14 +82,14 @@ INSERT INTO ops_system_metrics (
$1,$2,$3,$4,
$5,$6,$7,$8,
$9,$10,$11,
$12,$13,$14,
$15,$16,$17,$18,$19,$20,
$21,$22,$23,$24,$25,$26,
$27,$28,$29,$30,
$31,$32,
$33,$34,
$35,$36,$37,
$38,$39
$12,$13,$14,$15,
$16,$17,$18,$19,$20,$21,
$22,$23,$24,$25,$26,$27,
$28,$29,$30,$31,
$32,$33,
$34,$35,
$36,$37,$38,
$39,$40
)`
_, err := r.db.ExecContext(
......@@ -109,6 +110,7 @@ INSERT INTO ops_system_metrics (
input.Upstream529Count,
input.TokenConsumed,
input.AccountSwitchCount,
opsNullFloat64(input.QPS),
opsNullFloat64(input.TPS),
......@@ -177,7 +179,8 @@ SELECT
db_conn_waiting,
goroutine_count,
concurrency_queue_depth
concurrency_queue_depth,
account_switch_count
FROM ops_system_metrics
WHERE window_minutes = $1
AND platform IS NULL
......@@ -199,6 +202,7 @@ LIMIT 1`
var dbWaiting sql.NullInt64
var goroutines sql.NullInt64
var queueDepth sql.NullInt64
var accountSwitchCount sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan(
&out.ID,
......@@ -217,6 +221,7 @@ LIMIT 1`
&dbWaiting,
&goroutines,
&queueDepth,
&accountSwitchCount,
); err != nil {
return nil, err
}
......@@ -273,6 +278,10 @@ LIMIT 1`
v := int(queueDepth.Int64)
out.ConcurrencyQueueDepth = &v
}
if accountSwitchCount.Valid {
v := accountSwitchCount.Int64
out.AccountSwitchCount = &v
}
return &out, nil
}
......
......@@ -56,18 +56,44 @@ error_buckets AS (
AND COALESCE(status_code, 0) >= 400
GROUP BY 1
),
switch_buckets AS (
SELECT ` + errorBucketExpr + ` AS bucket,
COALESCE(SUM(CASE
WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1
ELSE 0
END), 0) AS switch_count
FROM ops_error_logs
CROSS JOIN LATERAL jsonb_array_elements(
COALESCE(NULLIF(upstream_errors, 'null'::jsonb), '[]'::jsonb)
) AS ev
` + errorWhere + `
AND upstream_errors IS NOT NULL
GROUP BY 1
),
combined AS (
SELECT COALESCE(u.bucket, e.bucket) AS bucket,
COALESCE(u.success_count, 0) AS success_count,
COALESCE(e.error_count, 0) AS error_count,
COALESCE(u.token_consumed, 0) AS token_consumed
FROM usage_buckets u
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
SELECT
bucket,
SUM(success_count) AS success_count,
SUM(error_count) AS error_count,
SUM(token_consumed) AS token_consumed,
SUM(switch_count) AS switch_count
FROM (
SELECT bucket, success_count, 0 AS error_count, token_consumed, 0 AS switch_count
FROM usage_buckets
UNION ALL
SELECT bucket, 0, error_count, 0, 0
FROM error_buckets
UNION ALL
SELECT bucket, 0, 0, 0, switch_count
FROM switch_buckets
) t
GROUP BY bucket
)
SELECT
bucket,
(success_count + error_count) AS request_count,
token_consumed
token_consumed,
switch_count
FROM combined
ORDER BY bucket ASC`
......@@ -84,13 +110,18 @@ ORDER BY bucket ASC`
var bucket time.Time
var requests int64
var tokens sql.NullInt64
if err := rows.Scan(&bucket, &requests, &tokens); err != nil {
var switches sql.NullInt64
if err := rows.Scan(&bucket, &requests, &tokens, &switches); err != nil {
return nil, err
}
tokenConsumed := int64(0)
if tokens.Valid {
tokenConsumed = tokens.Int64
}
switchCount := int64(0)
if switches.Valid {
switchCount = switches.Int64
}
denom := float64(bucketSeconds)
if denom <= 0 {
......@@ -103,6 +134,7 @@ ORDER BY bucket ASC`
BucketStart: bucket.UTC(),
RequestCount: requests,
TokenConsumed: tokenConsumed,
SwitchCount: switchCount,
QPS: qps,
TPS: tps,
})
......@@ -385,6 +417,7 @@ func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points []
BucketStart: cursor,
RequestCount: 0,
TokenConsumed: 0,
SwitchCount: 0,
QPS: 0,
TPS: 0,
})
......
......@@ -186,6 +186,7 @@ func TestAPIContracts(t *testing.T) {
"image_price_4k": null,
"claude_code_only": false,
"fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
......@@ -607,7 +608,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
......
......@@ -111,9 +111,14 @@ type CreateGroupInput struct {
ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64
ModelRoutingEnabled bool // 是否启用模型路由
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
......@@ -135,9 +140,14 @@ type UpdateGroupInput struct {
ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64
ModelRoutingEnabled *bool // 是否启用模型路由
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
......@@ -594,6 +604,22 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return nil, err
}
}
fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest
if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 {
fallbackOnInvalidRequest = nil
}
// 校验无效请求兜底分组
if fallbackOnInvalidRequest != nil {
if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil {
return nil, err
}
}
// MCPXMLInject:默认为 true,仅当显式传入 false 时关闭
mcpXMLInject := true
if input.MCPXMLInject != nil {
mcpXMLInject = *input.MCPXMLInject
}
// 如果指定了复制账号的源分组,先获取账号 ID 列表
var accountIDsToCopy []int64
......@@ -628,22 +654,25 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
}
group := &Group{
Name: input.Name,
Description: input.Description,
Platform: platform,
RateMultiplier: input.RateMultiplier,
IsExclusive: input.IsExclusive,
Status: StatusActive,
SubscriptionType: subscriptionType,
DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: monthlyLimit,
ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K,
ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID,
ModelRouting: input.ModelRouting,
Name: input.Name,
Description: input.Description,
Platform: platform,
RateMultiplier: input.RateMultiplier,
IsExclusive: input.IsExclusive,
Status: StatusActive,
SubscriptionType: subscriptionType,
DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: monthlyLimit,
ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K,
ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject,
SupportedModelScopes: input.SupportedModelScopes,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
......@@ -714,6 +743,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
}
}
// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性
// currentGroupID: 当前分组 ID(新建时为 0)
// platform/subscriptionType: 当前分组的有效平台/订阅类型
// fallbackGroupID: 兜底分组 ID
func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error {
if platform != PlatformAnthropic && platform != PlatformAntigravity {
return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups")
}
if subscriptionType == SubscriptionTypeSubscription {
return fmt.Errorf("subscription groups cannot set invalid request fallback")
}
if currentGroupID > 0 && currentGroupID == fallbackGroupID {
return fmt.Errorf("cannot set self as invalid request fallback group")
}
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID)
if err != nil {
return fmt.Errorf("fallback group not found: %w", err)
}
if fallbackGroup.Platform != PlatformAnthropic {
return fmt.Errorf("fallback group must be anthropic platform")
}
if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription {
return fmt.Errorf("fallback group cannot be subscription type")
}
if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
return fmt.Errorf("fallback group cannot have invalid request fallback configured")
}
return nil
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
......@@ -780,6 +840,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.FallbackGroupID = nil
}
}
fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest
if input.FallbackGroupIDOnInvalidRequest != nil {
if *input.FallbackGroupIDOnInvalidRequest > 0 {
fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest
} else {
fallbackOnInvalidRequest = nil
}
}
if fallbackOnInvalidRequest != nil {
if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil {
return nil, err
}
}
group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest
// 模型路由配置
if input.ModelRouting != nil {
......@@ -788,6 +862,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.ModelRoutingEnabled != nil {
group.ModelRoutingEnabled = *input.ModelRoutingEnabled
}
if input.MCPXMLInject != nil {
group.MCPXMLInject = *input.MCPXMLInject
}
// 支持的模型系列(仅 antigravity 平台使用)
if input.SupportedModelScopes != nil {
group.SupportedModelScopes = *input.SupportedModelScopes
}
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
......
......@@ -394,3 +394,382 @@ func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _
func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
type groupRepoStubForInvalidRequestFallback struct {
groups map[int64]*Group
created *Group
updated *Group
}
func (s *groupRepoStubForInvalidRequestFallback) Create(_ context.Context, g *Group) error {
s.created = g
return nil
}
func (s *groupRepoStubForInvalidRequestFallback) Update(_ context.Context, g *Group) error {
s.updated = g
return nil
}
func (s *groupRepoStubForInvalidRequestFallback) GetByID(ctx context.Context, id int64) (*Group, error) {
return s.GetByIDLite(ctx, id)
}
func (s *groupRepoStubForInvalidRequestFallback) GetByIDLite(_ context.Context, id int64) (*Group, error) {
if g, ok := s.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (s *groupRepoStubForInvalidRequestFallback) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call")
}
func (s *groupRepoStubForInvalidRequestFallback) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
panic("unexpected DeleteCascade call")
}
func (s *groupRepoStubForInvalidRequestFallback) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *groupRepoStubForInvalidRequestFallback) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *groupRepoStubForInvalidRequestFallback) ListActive(_ context.Context) ([]Group, error) {
panic("unexpected ListActive call")
}
func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
panic("unexpected ListActiveByPlatform call")
}
func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
panic("unexpected GetAccountCount call")
}
func (s *groupRepoStubForInvalidRequestFallback) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (s *groupRepoStubForInvalidRequestFallback) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformOpenAI,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
require.Nil(t, repo.created)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeSubscription,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
require.Nil(t, repo.created)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
tests := []struct {
name string
fallback *Group
wantMessage string
}{
{
name: "openai_target",
fallback: &Group{ID: 10, Platform: PlatformOpenAI, SubscriptionType: SubscriptionTypeStandard},
wantMessage: "fallback group must be anthropic platform",
},
{
name: "antigravity_target",
fallback: &Group{ID: 10, Platform: PlatformAntigravity, SubscriptionType: SubscriptionTypeStandard},
wantMessage: "fallback group must be anthropic platform",
},
{
name: "subscription_group",
fallback: &Group{ID: 10, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
wantMessage: "fallback group cannot be subscription type",
},
{
name: "nested_fallback",
fallback: &Group{
ID: 10,
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: func() *int64 { v := int64(99); return &v }(),
},
wantMessage: "fallback group cannot have invalid request fallback configured",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
fallbackID := tc.fallback.ID
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: tc.fallback,
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), tc.wantMessage)
require.Nil(t, repo.created)
})
}
}
func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "fallback group not found")
require.Nil(t, repo.created)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAntigravity,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.Equal(t, fallbackID, *repo.created.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
zero := int64(0)
repo := &groupRepoStubForInvalidRequestFallback{}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &zero,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.Nil(t, repo.created.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
FallbackGroupIDOnInvalidRequest: &fallbackID,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
Platform: PlatformOpenAI,
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
FallbackGroupIDOnInvalidRequest: &fallbackID,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
SubscriptionType: SubscriptionTypeSubscription,
})
require.Error(t, err)
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
FallbackGroupIDOnInvalidRequest: &fallbackID,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
clear := int64(0)
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
Platform: PlatformOpenAI,
FallbackGroupIDOnInvalidRequest: &clear,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Nil(t, repo.updated.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "fallback group cannot be subscription type")
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAntigravity,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
}
......@@ -13,23 +13,34 @@ import (
"net"
"net/http"
"os"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const (
antigravityStickySessionTTL = time.Hour
antigravityMaxRetries = 3
antigravityRetryBaseDelay = 1 * time.Second
antigravityRetryMaxDelay = 16 * time.Second
antigravityStickySessionTTL = time.Hour
antigravityDefaultMaxRetries = 3
antigravityRetryBaseDelay = 1 * time.Second
antigravityRetryMaxDelay = 16 * time.Second
)
const antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
const (
antigravityMaxRetriesEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES"
antigravityMaxRetriesAfterSwitchEnv = "GATEWAY_ANTIGRAVITY_AFTER_SWITCHMAX_RETRIES"
antigravityMaxRetriesClaudeEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_CLAUDE"
antigravityMaxRetriesGeminiTextEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_TEXT"
antigravityMaxRetriesGeminiImageEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_IMAGE"
antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
)
// antigravityRetryLoopParams 重试循环的参数
type antigravityRetryLoopParams struct {
......@@ -41,6 +52,7 @@ type antigravityRetryLoopParams struct {
action string
body []byte
quotaScope AntigravityQuotaScope
maxRetries int
c *gin.Context
httpUpstream HTTPUpstream
settingService *SettingService
......@@ -52,11 +64,28 @@ type antigravityRetryLoopResult struct {
resp *http.Response
}
// PromptTooLongError 表示上游明确返回 prompt too long
type PromptTooLongError struct {
StatusCode int
RequestID string
Body []byte
}
func (e *PromptTooLongError) Error() string {
return fmt.Sprintf("prompt too long: status=%d", e.StatusCode)
}
// antigravityRetryLoop 执行带 URL fallback 的重试循环
func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
baseURLs := antigravity.ForwardBaseURLs()
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLsWithBase(baseURLs)
if len(availableURLs) == 0 {
availableURLs = antigravity.BaseURLs
availableURLs = baseURLs
}
maxRetries := p.maxRetries
if maxRetries <= 0 {
maxRetries = antigravityDefaultMaxRetries
}
var resp *http.Response
......@@ -76,7 +105,7 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
urlFallbackLoop:
for urlIdx, baseURL := range availableURLs {
usedBaseURL = baseURL
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
for attempt := 1; attempt <= maxRetries; attempt++ {
select {
case <-p.ctx.Done():
log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
......@@ -109,8 +138,8 @@ urlFallbackLoop:
log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
continue urlFallbackLoop
}
if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
if attempt < maxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, maxRetries, err)
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
......@@ -134,7 +163,7 @@ urlFallbackLoop:
}
// 账户/模型配额限流,重试 3 次(指数退避)
if attempt < antigravityMaxRetries {
if attempt < maxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
......@@ -147,7 +176,7 @@ urlFallbackLoop:
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),
})
log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, maxRetries, truncateForLog(respBody, 200))
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
......@@ -171,7 +200,7 @@ urlFallbackLoop:
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if attempt < antigravityMaxRetries {
if attempt < maxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
......@@ -184,7 +213,7 @@ urlFallbackLoop:
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),
})
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, maxRetries, truncateForLog(respBody, 500))
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err()
......@@ -390,6 +419,11 @@ type TestConnectionResult struct {
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
// 上游透传账号使用专用测试方法
if account.Type == AccountTypeUpstream {
return s.testUpstreamConnection(ctx, account, modelID)
}
// 获取 token
if s.tokenProvider == nil {
return nil, errors.New("antigravity token provider not configured")
......@@ -484,6 +518,87 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return nil, lastErr
}
// testUpstreamConnection 测试上游透传账号连接
func (s *AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if baseURL == "" || apiKey == "" {
return nil, errors.New("upstream account missing base_url or api_key")
}
baseURL = strings.TrimSuffix(baseURL, "/")
// 使用 Claude 模型进行测试
if modelID == "" {
modelID = "claude-sonnet-4-20250514"
}
// 构建最小测试请求
testReq := map[string]any{
"model": modelID,
"max_tokens": 1,
"messages": []map[string]any{
{"role": "user", "content": "."},
},
}
requestBody, err := json.Marshal(testReq)
if err != nil {
return nil, fmt.Errorf("构建请求失败: %w", err)
}
// 构建 HTTP 请求
upstreamURL := baseURL + "/v1/messages"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(requestBody))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("x-api-key", apiKey)
req.Header.Set("anthropic-version", "2023-06-01")
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, upstreamURL)
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return nil, fmt.Errorf("请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
}
// 提取响应文本
var respData map[string]any
text := ""
if json.Unmarshal(respBody, &respData) == nil {
if content, ok := respData["content"].([]any); ok && len(content) > 0 {
if block, ok := content[0].(map[string]any); ok {
if t, ok := block["text"].(string); ok {
text = t
}
}
}
}
return &TestConnectionResult{
Text: text,
MappedModel: modelID,
}, nil
}
// buildGeminiTestRequest 构建 Gemini 格式测试请求
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) {
......@@ -534,6 +649,10 @@ func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Contex
}
opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx)
opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx)
if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && group != nil {
opts.EnableMCPXML = group.MCPXMLInject
}
return opts
}
......@@ -702,6 +821,11 @@ func isModelNotFoundError(statusCode int, body []byte) bool {
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
// 上游透传账号直接转发,不走 OAuth token 刷新
if account.Type == AccountTypeUpstream {
return s.ForwardUpstream(ctx, c, account, body)
}
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
......@@ -718,6 +842,12 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
originalModel := claudeReq.Model
mappedModel := s.getMappedModel(account, claudeReq.Model)
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
billingModel := originalModel
if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" {
billingModel = mappedModel
}
afterSwitch := antigravityHasAccountSwitch(ctx)
maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch)
// 获取 access_token
if s.tokenProvider == nil {
......@@ -766,6 +896,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
handleError: s.handleUpstreamError,
maxRetries: maxRetries,
})
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
......@@ -842,6 +973,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
handleError: s.handleUpstreamError,
maxRetries: maxRetries,
})
if retryErr != nil {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
......@@ -917,6 +1049,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 处理错误响应(重试后仍失败或不触发重试)
if resp.StatusCode >= 400 {
if resp.StatusCode == http.StatusBadRequest {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
log.Printf("%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s", prefix, isPromptTooLongError(respBody), upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, 500))
}
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "prompt_too_long",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &PromptTooLongError{
StatusCode: resp.StatusCode,
RequestID: resp.Header.Get("x-request-id"),
Body: respBody,
}
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
......@@ -978,7 +1143,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Model: billingModel, // 计费模型(可按映射模型覆盖)
Stream: claudeReq.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
......@@ -1006,21 +1171,55 @@ func isSignatureRelatedError(respBody []byte) bool {
return false
}
func isPromptTooLongError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if msg == "" {
msg = strings.ToLower(string(respBody))
}
return strings.Contains(msg, "prompt is too long")
}
func extractAntigravityErrorMessage(body []byte) string {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return ""
}
parseNestedMessage := func(msg string) string {
trimmed := strings.TrimSpace(msg)
if trimmed == "" || !strings.HasPrefix(trimmed, "{") {
return ""
}
var nested map[string]any
if err := json.Unmarshal([]byte(trimmed), &nested); err != nil {
return ""
}
if errObj, ok := nested["error"].(map[string]any); ok {
if innerMsg, ok := errObj["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
}
if innerMsg, ok := nested["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
return ""
}
// Google-style: {"error": {"message": "..."}}
if errObj, ok := payload["error"].(map[string]any); ok {
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
return innerMsg
}
return msg
}
}
// Fallback: top-level message
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
return innerMsg
}
return msg
}
......@@ -1248,6 +1447,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
return changed, nil
}
// ForwardUpstream 透传请求到上游 Antigravity 服务
// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
// 获取上游配置
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if baseURL == "" || apiKey == "" {
return nil, fmt.Errorf("upstream account missing base_url or api_key")
}
baseURL = strings.TrimSuffix(baseURL, "/")
// 解析请求获取模型信息
var claudeReq antigravity.ClaudeRequest
if err := json.Unmarshal(body, &claudeReq); err != nil {
return nil, fmt.Errorf("parse claude request: %w", err)
}
if strings.TrimSpace(claudeReq.Model) == "" {
return nil, fmt.Errorf("missing model")
}
originalModel := claudeReq.Model
billingModel := originalModel
// 构建上游请求 URL
upstreamURL := baseURL + "/v1/messages"
// 创建请求
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("create upstream request: %w", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("x-api-key", apiKey) // Claude API 兼容
// 透传 Claude 相关 headers
if v := c.GetHeader("anthropic-version"); v != "" {
req.Header.Set("anthropic-version", v)
}
if v := c.GetHeader("anthropic-beta"); v != "" {
req.Header.Set("anthropic-beta", v)
}
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
log.Printf("%s upstream request failed: %v", prefix, err)
return nil, fmt.Errorf("upstream request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 429 错误时标记账号限流
if resp.StatusCode == http.StatusTooManyRequests {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude)
}
// 透传上游错误
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(resp.StatusCode)
_, _ = c.Writer.Write(respBody)
return &ForwardResult{
Model: billingModel,
}, nil
}
// 处理成功响应(流式/非流式)
var usage *ClaudeUsage
var firstTokenMs *int
if claudeReq.Stream {
// 流式响应:透传
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
usage, firstTokenMs = s.streamUpstreamResponse(c, resp, startTime)
} else {
// 非流式响应:直接透传
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read upstream response: %w", err)
}
// 提取 usage
usage = s.extractClaudeUsage(respBody)
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(http.StatusOK)
_, _ = c.Writer.Write(respBody)
}
// 构建计费结果
duration := time.Since(startTime)
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
return &ForwardResult{
Model: billingModel,
Stream: claudeReq.Stream,
Duration: duration,
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
CacheCreationInputTokens: usage.CacheCreationInputTokens,
},
}, nil
}
// streamUpstreamResponse 透传上游流式响应并提取 usage
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*ClaudeUsage, *int) {
usage := &ClaudeUsage{}
var firstTokenMs *int
var firstTokenRecorded bool
scanner := bufio.NewScanner(resp.Body)
buf := make([]byte, 0, 64*1024)
scanner.Buffer(buf, 1024*1024)
for scanner.Scan() {
line := scanner.Bytes()
// 记录首 token 时间
if !firstTokenRecorded && len(line) > 0 {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
firstTokenRecorded = true
}
// 尝试从 message_delta 或 message_stop 事件提取 usage
if bytes.HasPrefix(line, []byte("data: ")) {
dataStr := bytes.TrimPrefix(line, []byte("data: "))
var event map[string]any
if json.Unmarshal(dataStr, &event) == nil {
if u, ok := event["usage"].(map[string]any); ok {
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {
usage.InputTokens = int(v)
}
if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 {
usage.OutputTokens = int(v)
}
if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheReadInputTokens = int(v)
}
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheCreationInputTokens = int(v)
}
}
}
}
// 透传行
_, _ = c.Writer.Write(line)
_, _ = c.Writer.Write([]byte("\n"))
c.Writer.Flush()
}
return usage, firstTokenMs
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage {
usage := &ClaudeUsage{}
var resp map[string]any
if json.Unmarshal(body, &resp) != nil {
return usage
}
if u, ok := resp["usage"].(map[string]any); ok {
if v, ok := u["input_tokens"].(float64); ok {
usage.InputTokens = int(v)
}
if v, ok := u["output_tokens"].(float64); ok {
usage.OutputTokens = int(v)
}
if v, ok := u["cache_read_input_tokens"].(float64); ok {
usage.CacheReadInputTokens = int(v)
}
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
usage.CacheCreationInputTokens = int(v)
}
}
return usage
}
// ForwardGemini 转发 Gemini 协议请求
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now()
......@@ -1287,6 +1688,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
mappedModel := s.getMappedModel(account, originalModel)
billingModel := originalModel
if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" {
billingModel = mappedModel
}
afterSwitch := antigravityHasAccountSwitch(ctx)
maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch)
// 获取 access_token
if s.tokenProvider == nil {
......@@ -1306,8 +1713,15 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
proxyURL = account.Proxy.URL()
}
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
filteredBody, err := filterEmptyPartsFromGeminiRequest(body)
if err != nil {
log.Printf("[Antigravity] Failed to filter empty parts: %v", err)
filteredBody = body
}
// Antigravity 上游要求必须包含身份提示词,注入到请求中
injectedBody, err := injectIdentityPatchToGeminiRequest(body)
injectedBody, err := injectIdentityPatchToGeminiRequest(filteredBody)
if err != nil {
return nil, err
}
......@@ -1344,6 +1758,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
httpUpstream: s.httpUpstream,
settingService: s.settingService,
handleError: s.handleUpstreamError,
maxRetries: maxRetries,
})
if err != nil {
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
......@@ -1493,7 +1908,7 @@ handleSuccess:
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel,
Model: billingModel,
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
......@@ -1544,6 +1959,81 @@ func antigravityUseScopeRateLimit() bool {
return true
}
func antigravityHasAccountSwitch(ctx context.Context) bool {
if ctx == nil {
return false
}
if v, ok := ctx.Value(ctxkey.AccountSwitchCount).(int); ok {
return v > 0
}
return false
}
func antigravityMaxRetries() int {
raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesEnv))
if raw == "" {
return antigravityDefaultMaxRetries
}
value, err := strconv.Atoi(raw)
if err != nil || value <= 0 {
return antigravityDefaultMaxRetries
}
return value
}
func antigravityMaxRetriesAfterSwitch() int {
raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesAfterSwitchEnv))
if raw == "" {
return antigravityMaxRetries()
}
value, err := strconv.Atoi(raw)
if err != nil || value <= 0 {
return antigravityMaxRetries()
}
return value
}
// antigravityMaxRetriesForModel 根据模型类型获取重试次数
// 优先使用模型细分配置,未设置则回退到平台级配置
func antigravityMaxRetriesForModel(model string, afterSwitch bool) int {
var envKey string
if strings.HasPrefix(model, "claude-") {
envKey = antigravityMaxRetriesClaudeEnv
} else if isImageGenerationModel(model) {
envKey = antigravityMaxRetriesGeminiImageEnv
} else if strings.HasPrefix(model, "gemini-") {
envKey = antigravityMaxRetriesGeminiTextEnv
}
if envKey != "" {
if raw := strings.TrimSpace(os.Getenv(envKey)); raw != "" {
if value, err := strconv.Atoi(raw); err == nil && value > 0 {
return value
}
}
}
if afterSwitch {
return antigravityMaxRetriesAfterSwitch()
}
return antigravityMaxRetries()
}
func antigravityUseMappedModelForBilling() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityBillingModelEnv)))
return v == "1" || v == "true" || v == "yes" || v == "on"
}
func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
raw := strings.TrimSpace(os.Getenv(antigravityFallbackSecondsEnv))
if raw == "" {
return 0, false
}
seconds, err := strconv.Atoi(raw)
if err != nil || seconds <= 0 {
return 0, false
}
return time.Duration(seconds) * time.Second, true
}
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if statusCode == 429 {
......@@ -1556,6 +2046,9 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes
}
defaultDur := time.Duration(fallbackMinutes) * time.Minute
if fallbackDur, ok := antigravityFallbackCooldownSeconds(); ok {
defaultDur = fallbackDur
}
ra := time.Now().Add(defaultDur)
if useScopeLimit {
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
......@@ -2193,6 +2686,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
}
func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body)
}
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
statusStr := "UNKNOWN"
switch status {
......@@ -2618,3 +3115,55 @@ func cleanGeminiRequest(body []byte) ([]byte, error) {
return json.Marshal(payload)
}
// filterEmptyPartsFromGeminiRequest 过滤 Gemini 请求中 parts 为空的消息
// Gemini API 不接受 parts 为空数组的消息,会返回 400 错误
func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return nil, err
}
contents, ok := payload["contents"].([]any)
if !ok || len(contents) == 0 {
return body, nil
}
filtered := make([]any, 0, len(contents))
modified := false
for _, c := range contents {
contentMap, ok := c.(map[string]any)
if !ok {
filtered = append(filtered, c)
continue
}
parts, hasParts := contentMap["parts"]
if !hasParts {
filtered = append(filtered, c)
continue
}
partsSlice, ok := parts.([]any)
if !ok {
filtered = append(filtered, c)
continue
}
// 跳过 parts 为空数组的消息
if len(partsSlice) == 0 {
modified = true
continue
}
filtered = append(filtered, c)
}
if !modified {
return body, nil
}
payload["contents"] = filtered
return json.Marshal(payload)
}
package service
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
......@@ -81,3 +87,106 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
require.Equal(t, "secret plan", blocks[0]["text"])
require.Equal(t, "tool_use", blocks[1]["type"])
}
func TestIsPromptTooLongError(t *testing.T) {
require.True(t, isPromptTooLongError([]byte(`{"error":{"message":"Prompt is too long"}}`)))
require.True(t, isPromptTooLongError([]byte(`{"message":"Prompt is too long"}`)))
require.False(t, isPromptTooLongError([]byte(`{"error":{"message":"other"}}`)))
}
type httpUpstreamStub struct {
resp *http.Response
err error
}
func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return s.resp, s.err
}
func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
return s.resp, s.err
}
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"model": "claude-opus-4-5",
"messages": []map[string]any{
{"role": "user", "content": "hi"},
},
"max_tokens": 1,
"stream": false,
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request = req
respBody := []byte(`{"error":{"message":"Prompt is too long"}}`)
resp := &http.Response{
StatusCode: http.StatusBadRequest,
Header: http.Header{"X-Request-Id": []string{"req-1"}},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
}
account := &Account{
ID: 1,
Name: "acc-1",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
}
result, err := svc.Forward(context.Background(), c, account, body)
require.Nil(t, result)
var promptErr *PromptTooLongError
require.ErrorAs(t, err, &promptErr)
require.Equal(t, http.StatusBadRequest, promptErr.StatusCode)
require.Equal(t, "req-1", promptErr.RequestID)
require.NotEmpty(t, promptErr.Body)
raw, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := raw.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.Len(t, events, 1)
require.Equal(t, "prompt_too_long", events[0].Kind)
}
func TestAntigravityMaxRetriesForModel_AfterSwitch(t *testing.T) {
t.Setenv(antigravityMaxRetriesEnv, "4")
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "7")
t.Setenv(antigravityMaxRetriesClaudeEnv, "")
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "")
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "")
got := antigravityMaxRetriesForModel("claude-sonnet-4-5", false)
require.Equal(t, 4, got)
got = antigravityMaxRetriesForModel("claude-sonnet-4-5", true)
require.Equal(t, 7, got)
}
func TestAntigravityMaxRetriesForModel_AfterSwitchFallback(t *testing.T) {
t.Setenv(antigravityMaxRetriesEnv, "5")
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "")
t.Setenv(antigravityMaxRetriesClaudeEnv, "")
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "")
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "")
got := antigravityMaxRetriesForModel("gemini-2.5-flash", true)
require.Equal(t, 5, got)
}
package service
import (
"slices"
"strings"
"time"
)
......@@ -16,6 +17,21 @@ const (
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
)
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool {
if len(supportedScopes) == 0 {
// 未配置时默认全部支持
return true
}
supported := slices.Contains(supportedScopes, string(scope))
return supported
}
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
return resolveAntigravityQuotaScope(requestedModel)
}
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
model := normalizeAntigravityModelName(requestedModel)
......
......@@ -32,25 +32,30 @@ type APIKeyAuthUserSnapshot struct {
// APIKeyAuthGroupSnapshot 分组快照
type APIKeyAuthGroupSnapshot struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
Status string `json:"status"`
SubscriptionType string `json:"subscription_type"`
RateMultiplier float64 `json:"rate_multiplier"`
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
Status string `json:"status"`
SubscriptionType string `json:"subscription_type"`
RateMultiplier float64 `json:"rate_multiplier"`
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty.
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
......
......@@ -226,22 +226,25 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
}
if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID,
Name: apiKey.Group.Name,
Platform: apiKey.Group.Platform,
Status: apiKey.Group.Status,
SubscriptionType: apiKey.Group.SubscriptionType,
RateMultiplier: apiKey.Group.RateMultiplier,
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID,
ModelRouting: apiKey.Group.ModelRouting,
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
ID: apiKey.Group.ID,
Name: apiKey.Group.Name,
Platform: apiKey.Group.Platform,
Status: apiKey.Group.Status,
SubscriptionType: apiKey.Group.SubscriptionType,
RateMultiplier: apiKey.Group.RateMultiplier,
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
ModelRouting: apiKey.Group.ModelRouting,
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
MCPXMLInject: apiKey.Group.MCPXMLInject,
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
}
}
return snapshot
......@@ -272,23 +275,26 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
}
if snapshot.Group != nil {
apiKey.Group = &Group{
ID: snapshot.Group.ID,
Name: snapshot.Group.Name,
Platform: snapshot.Group.Platform,
Status: snapshot.Group.Status,
Hydrated: true,
SubscriptionType: snapshot.Group.SubscriptionType,
RateMultiplier: snapshot.Group.RateMultiplier,
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID,
ModelRouting: snapshot.Group.ModelRouting,
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
ID: snapshot.Group.ID,
Name: snapshot.Group.Name,
Platform: snapshot.Group.Platform,
Status: snapshot.Group.Status,
Hydrated: true,
SubscriptionType: snapshot.Group.SubscriptionType,
RateMultiplier: snapshot.Group.RateMultiplier,
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
ModelRouting: snapshot.Group.ModelRouting,
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
MCPXMLInject: snapshot.Group.MCPXMLInject,
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
}
}
return apiKey
......
......@@ -185,7 +185,6 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err)
}
}
// 应用优惠码(如果提供且功能已启用)
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
......
......@@ -31,6 +31,7 @@ const (
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
)
// Redeem type constants
......
......@@ -70,6 +70,15 @@ func shortSessionHash(sessionHash string) string {
return sessionHash[:8]
}
func normalizeClaudeModelForAnthropic(requestedModel string) string {
for _, prefix := range anthropicPrefixMappings {
if strings.HasPrefix(requestedModel, prefix) {
return prefix
}
}
return requestedModel
}
func redactAuthHeaderValue(v string) string {
v = strings.TrimSpace(v)
if v == "" {
......@@ -252,11 +261,20 @@ var (
"You are a file search specialist for Claude Code", // Explore Agent 版
"You are a helpful AI assistant tasked with summarizing conversations", // Compact 版
}
anthropicPrefixMappings = []string{
"claude-opus-4-5",
"claude-haiku-4-5",
"claude-sonnet-4-5",
}
)
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
var ErrModelScopeNotSupported = errors.New("model scope not supported by this group")
// allowedHeaders 白名单headers(参考CRS项目)
var allowedHeaders = map[string]bool{
"accept": true,
......@@ -1135,6 +1153,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
}
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, err
......@@ -1632,6 +1657,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
return group, nil
}
func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
return s.resolveGroupByID(ctx, groupID)
}
func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 {
if groupID == nil || requestedModel == "" || platform != PlatformAnthropic {
return nil
......@@ -1697,7 +1726,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID
}
// 强制平台模式不检查 Claude Code 限制
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" {
return nil, groupID, nil
}
......@@ -2026,6 +2055,13 @@ func shuffleWithinPriority(accounts []*Account) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
......@@ -2461,6 +2497,9 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
// Antigravity 平台使用专门的模型支持检查
return IsAntigravityModelSupported(requestedModel)
}
if account.Platform == PlatformAnthropic {
requestedModel = normalizeClaudeModelForAnthropic(requestedModel)
}
// Gemini API Key 账户直接透传,由上游判断模型是否支持
if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey {
return true
......@@ -2910,16 +2949,28 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 强制执行 cache_control 块数量限制(最多 4 个)
body = enforceCacheControlLimit(body)
// 应用模型映射(仅对apikey类型账号)
// 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射)
mappedModel := reqModel
mappingSource := ""
if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel)
mappedModel = account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// 替换请求体中的模型名
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
mappingSource = "account"
}
}
if mappingSource == "" && account.Platform == PlatformAnthropic {
normalized := normalizeClaudeModelForAnthropic(reqModel)
if normalized != reqModel {
mappedModel = normalized
mappingSource = "prefix"
}
}
if mappedModel != reqModel {
// 替换请求体中的模型名
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
}
// 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account)
......@@ -4842,16 +4893,28 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
return nil
}
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeAPIKey {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
// 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射)
if reqModel != "" {
mappedModel := reqModel
mappingSource := ""
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(reqModel)
if mappedModel != reqModel {
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
mappingSource = "account"
}
}
if mappingSource == "" && account.Platform == PlatformAnthropic {
normalized := normalizeClaudeModelForAnthropic(reqModel)
if normalized != reqModel {
mappedModel = normalized
mappingSource = "prefix"
}
}
if mappedModel != reqModel {
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource)
}
}
// 获取凭证
......@@ -5103,6 +5166,27 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
return normalized, nil
}
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error {
scope, ok := ResolveAntigravityQuotaScope(requestedModel)
if !ok {
return nil // 无法解析 scope,跳过检查
}
group, err := s.resolveGroupByID(ctx, groupID)
if err != nil {
return nil // 查询失败时放行
}
if group == nil {
return nil // 分组不存在时放行
}
if !IsScopeSupported(group.SupportedModelScopes, scope) {
return ErrModelScopeNotSupported
}
return nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
......
......@@ -977,6 +977,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
}
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
if filteredBody, err := filterEmptyPartsFromGeminiRequest(body); err == nil {
body = filteredBody
}
switch action {
case "generateContent", "streamGenerateContent", "countTokens":
// ok
......
......@@ -29,6 +29,8 @@ type Group struct {
// Claude Code 客户端限制
ClaudeCodeOnly bool
FallbackGroupID *int64
// 无效请求兜底分组(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
......@@ -36,6 +38,13 @@ type Group struct {
ModelRouting map[string][]int64
ModelRoutingEnabled bool
// MCP XML 协议注入开关(仅 antigravity 平台使用)
MCPXMLInject bool
// 支持的模型系列(仅 antigravity 平台使用)
// 可选值: claude, gemini_text, gemini_image
SupportedModelScopes []string
CreatedAt time.Time
UpdatedAt time.Time
......
......@@ -285,6 +285,11 @@ func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error {
return fmt.Errorf("query error counts: %w", err)
}
accountSwitchCount, err := c.queryAccountSwitchCount(ctx, windowStart, windowEnd)
if err != nil {
return fmt.Errorf("query account switch counts: %w", err)
}
windowSeconds := windowEnd.Sub(windowStart).Seconds()
if windowSeconds <= 0 {
windowSeconds = 60
......@@ -309,9 +314,10 @@ func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error {
Upstream429Count: upstream429,
Upstream529Count: upstream529,
TokenConsumed: tokenConsumed,
QPS: float64Ptr(roundTo1DP(qps)),
TPS: float64Ptr(roundTo1DP(tps)),
TokenConsumed: tokenConsumed,
AccountSwitchCount: accountSwitchCount,
QPS: float64Ptr(roundTo1DP(qps)),
TPS: float64Ptr(roundTo1DP(tps)),
DurationP50Ms: duration.p50,
DurationP90Ms: duration.p90,
......@@ -551,6 +557,27 @@ WHERE created_at >= $1 AND created_at < $2`
return errorTotal, businessLimited, errorSLA, upstreamExcl429529, upstream429, upstream529, nil
}
func (c *OpsMetricsCollector) queryAccountSwitchCount(ctx context.Context, start, end time.Time) (int64, error) {
q := `
SELECT
COALESCE(SUM(CASE
WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1
ELSE 0
END), 0) AS switch_count
FROM ops_error_logs o
CROSS JOIN LATERAL jsonb_array_elements(
COALESCE(NULLIF(o.upstream_errors, 'null'::jsonb), '[]'::jsonb)
) AS ev
WHERE o.created_at >= $1 AND o.created_at < $2
AND o.is_count_tokens = FALSE`
var count int64
if err := c.db.QueryRowContext(ctx, q, start, end).Scan(&count); err != nil {
return 0, err
}
return count, nil
}
type opsCollectedSystemStats struct {
cpuUsagePercent *float64
memoryUsedMB *int64
......
......@@ -161,7 +161,8 @@ type OpsInsertSystemMetricsInput struct {
Upstream429Count int64
Upstream529Count int64
TokenConsumed int64
TokenConsumed int64
AccountSwitchCount int64
QPS *float64
TPS *float64
......@@ -223,8 +224,9 @@ type OpsSystemMetricsSnapshot struct {
DBConnIdle *int `json:"db_conn_idle"`
DBConnWaiting *int `json:"db_conn_waiting"`
GoroutineCount *int `json:"goroutine_count"`
ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"`
GoroutineCount *int `json:"goroutine_count"`
ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"`
AccountSwitchCount *int64 `json:"account_switch_count"`
}
type OpsUpsertJobHeartbeatInput struct {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment