Unverified Commit dd96ada3 authored by 程序猿MT's avatar 程序猿MT Committed by GitHub
Browse files

Merge branch 'Wei-Shaw:main' into main

parents 31fe0178 8f397548
...@@ -319,6 +319,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ( ...@@ -319,6 +319,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) (
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
return 0, errors.New("not implemented")
}
type stubUserSubscriptionRepo struct { type stubUserSubscriptionRepo struct {
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
updateStatus func(ctx context.Context, subscriptionID int64, status string) error updateStatus func(ctx context.Context, subscriptionID int64, status string) error
......
...@@ -111,9 +111,14 @@ type CreateGroupInput struct { ...@@ -111,9 +111,14 @@ type CreateGroupInput struct {
ImagePrice4K *float64 ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 ModelRouting map[string][]int64
ModelRoutingEnabled bool // 是否启用模型路由 ModelRoutingEnabled bool // 是否启用模型路由
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
// 从指定分组复制账号(创建分组后在同一事务内绑定) // 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64 CopyAccountsFromGroupIDs []int64
} }
...@@ -135,9 +140,14 @@ type UpdateGroupInput struct { ...@@ -135,9 +140,14 @@ type UpdateGroupInput struct {
ImagePrice4K *float64 ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 ModelRouting map[string][]int64
ModelRoutingEnabled *bool // 是否启用模型路由 ModelRoutingEnabled *bool // 是否启用模型路由
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 CopyAccountsFromGroupIDs []int64
} }
...@@ -594,6 +604,22 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -594,6 +604,22 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return nil, err return nil, err
} }
} }
fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest
if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 {
fallbackOnInvalidRequest = nil
}
// 校验无效请求兜底分组
if fallbackOnInvalidRequest != nil {
if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil {
return nil, err
}
}
// MCPXMLInject:默认为 true,仅当显式传入 false 时关闭
mcpXMLInject := true
if input.MCPXMLInject != nil {
mcpXMLInject = *input.MCPXMLInject
}
// 如果指定了复制账号的源分组,先获取账号 ID 列表 // 如果指定了复制账号的源分组,先获取账号 ID 列表
var accountIDsToCopy []int64 var accountIDsToCopy []int64
...@@ -628,22 +654,25 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -628,22 +654,25 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
} }
group := &Group{ group := &Group{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
Platform: platform, Platform: platform,
RateMultiplier: input.RateMultiplier, RateMultiplier: input.RateMultiplier,
IsExclusive: input.IsExclusive, IsExclusive: input.IsExclusive,
Status: StatusActive, Status: StatusActive,
SubscriptionType: subscriptionType, SubscriptionType: subscriptionType,
DailyLimitUSD: dailyLimit, DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: weeklyLimit, WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: monthlyLimit, MonthlyLimitUSD: monthlyLimit,
ImagePrice1K: imagePrice1K, ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K, ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K, ImagePrice4K: imagePrice4K,
ClaudeCodeOnly: input.ClaudeCodeOnly, ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID, FallbackGroupID: input.FallbackGroupID,
ModelRouting: input.ModelRouting, FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject,
SupportedModelScopes: input.SupportedModelScopes,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
...@@ -714,6 +743,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro ...@@ -714,6 +743,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
} }
} }
// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性
// currentGroupID: 当前分组 ID(新建时为 0)
// platform/subscriptionType: 当前分组的有效平台/订阅类型
// fallbackGroupID: 兜底分组 ID
func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error {
if platform != PlatformAnthropic && platform != PlatformAntigravity {
return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups")
}
if subscriptionType == SubscriptionTypeSubscription {
return fmt.Errorf("subscription groups cannot set invalid request fallback")
}
if currentGroupID > 0 && currentGroupID == fallbackGroupID {
return fmt.Errorf("cannot set self as invalid request fallback group")
}
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID)
if err != nil {
return fmt.Errorf("fallback group not found: %w", err)
}
if fallbackGroup.Platform != PlatformAnthropic {
return fmt.Errorf("fallback group must be anthropic platform")
}
if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription {
return fmt.Errorf("fallback group cannot be subscription type")
}
if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
return fmt.Errorf("fallback group cannot have invalid request fallback configured")
}
return nil
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
...@@ -780,6 +840,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -780,6 +840,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.FallbackGroupID = nil group.FallbackGroupID = nil
} }
} }
fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest
if input.FallbackGroupIDOnInvalidRequest != nil {
if *input.FallbackGroupIDOnInvalidRequest > 0 {
fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest
} else {
fallbackOnInvalidRequest = nil
}
}
if fallbackOnInvalidRequest != nil {
if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil {
return nil, err
}
}
group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest
// 模型路由配置 // 模型路由配置
if input.ModelRouting != nil { if input.ModelRouting != nil {
...@@ -788,6 +862,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -788,6 +862,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.ModelRoutingEnabled != nil { if input.ModelRoutingEnabled != nil {
group.ModelRoutingEnabled = *input.ModelRoutingEnabled group.ModelRoutingEnabled = *input.ModelRoutingEnabled
} }
if input.MCPXMLInject != nil {
group.MCPXMLInject = *input.MCPXMLInject
}
// 支持的模型系列(仅 antigravity 平台使用)
if input.SupportedModelScopes != nil {
group.SupportedModelScopes = *input.SupportedModelScopes
}
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err return nil, err
......
...@@ -394,3 +394,382 @@ func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ ...@@ -394,3 +394,382 @@ func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _
func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) { func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call") panic("unexpected GetAccountIDsByGroupIDs call")
} }
type groupRepoStubForInvalidRequestFallback struct {
groups map[int64]*Group
created *Group
updated *Group
}
func (s *groupRepoStubForInvalidRequestFallback) Create(_ context.Context, g *Group) error {
s.created = g
return nil
}
func (s *groupRepoStubForInvalidRequestFallback) Update(_ context.Context, g *Group) error {
s.updated = g
return nil
}
func (s *groupRepoStubForInvalidRequestFallback) GetByID(ctx context.Context, id int64) (*Group, error) {
return s.GetByIDLite(ctx, id)
}
func (s *groupRepoStubForInvalidRequestFallback) GetByIDLite(_ context.Context, id int64) (*Group, error) {
if g, ok := s.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (s *groupRepoStubForInvalidRequestFallback) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call")
}
func (s *groupRepoStubForInvalidRequestFallback) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
panic("unexpected DeleteCascade call")
}
func (s *groupRepoStubForInvalidRequestFallback) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *groupRepoStubForInvalidRequestFallback) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *groupRepoStubForInvalidRequestFallback) ListActive(_ context.Context) ([]Group, error) {
panic("unexpected ListActive call")
}
func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
panic("unexpected ListActiveByPlatform call")
}
func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
panic("unexpected GetAccountCount call")
}
func (s *groupRepoStubForInvalidRequestFallback) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (s *groupRepoStubForInvalidRequestFallback) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformOpenAI,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
require.Nil(t, repo.created)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeSubscription,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
require.Nil(t, repo.created)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
tests := []struct {
name string
fallback *Group
wantMessage string
}{
{
name: "openai_target",
fallback: &Group{ID: 10, Platform: PlatformOpenAI, SubscriptionType: SubscriptionTypeStandard},
wantMessage: "fallback group must be anthropic platform",
},
{
name: "antigravity_target",
fallback: &Group{ID: 10, Platform: PlatformAntigravity, SubscriptionType: SubscriptionTypeStandard},
wantMessage: "fallback group must be anthropic platform",
},
{
name: "subscription_group",
fallback: &Group{ID: 10, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
wantMessage: "fallback group cannot be subscription type",
},
{
name: "nested_fallback",
fallback: &Group{
ID: 10,
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: func() *int64 { v := int64(99); return &v }(),
},
wantMessage: "fallback group cannot have invalid request fallback configured",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
fallbackID := tc.fallback.ID
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: tc.fallback,
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), tc.wantMessage)
require.Nil(t, repo.created)
})
}
}
func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "fallback group not found")
require.Nil(t, repo.created)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAntigravity,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.Equal(t, fallbackID, *repo.created.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
zero := int64(0)
repo := &groupRepoStubForInvalidRequestFallback{}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &zero,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.Nil(t, repo.created.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
FallbackGroupIDOnInvalidRequest: &fallbackID,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
Platform: PlatformOpenAI,
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
FallbackGroupIDOnInvalidRequest: &fallbackID,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
SubscriptionType: SubscriptionTypeSubscription,
})
require.Error(t, err)
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
FallbackGroupIDOnInvalidRequest: &fallbackID,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
clear := int64(0)
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
Platform: PlatformOpenAI,
FallbackGroupIDOnInvalidRequest: &clear,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Nil(t, repo.updated.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "fallback group cannot be subscription type")
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAntigravity,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
}
...@@ -13,23 +13,34 @@ import ( ...@@ -13,23 +13,34 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
) )
const ( const (
antigravityStickySessionTTL = time.Hour antigravityStickySessionTTL = time.Hour
antigravityMaxRetries = 3 antigravityDefaultMaxRetries = 3
antigravityRetryBaseDelay = 1 * time.Second antigravityRetryBaseDelay = 1 * time.Second
antigravityRetryMaxDelay = 16 * time.Second antigravityRetryMaxDelay = 16 * time.Second
) )
const antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT" const (
antigravityMaxRetriesEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES"
antigravityMaxRetriesAfterSwitchEnv = "GATEWAY_ANTIGRAVITY_AFTER_SWITCHMAX_RETRIES"
antigravityMaxRetriesClaudeEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_CLAUDE"
antigravityMaxRetriesGeminiTextEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_TEXT"
antigravityMaxRetriesGeminiImageEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_IMAGE"
antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
)
// antigravityRetryLoopParams 重试循环的参数 // antigravityRetryLoopParams 重试循环的参数
type antigravityRetryLoopParams struct { type antigravityRetryLoopParams struct {
...@@ -41,6 +52,7 @@ type antigravityRetryLoopParams struct { ...@@ -41,6 +52,7 @@ type antigravityRetryLoopParams struct {
action string action string
body []byte body []byte
quotaScope AntigravityQuotaScope quotaScope AntigravityQuotaScope
maxRetries int
c *gin.Context c *gin.Context
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
settingService *SettingService settingService *SettingService
...@@ -52,11 +64,28 @@ type antigravityRetryLoopResult struct { ...@@ -52,11 +64,28 @@ type antigravityRetryLoopResult struct {
resp *http.Response resp *http.Response
} }
// PromptTooLongError 表示上游明确返回 prompt too long
type PromptTooLongError struct {
StatusCode int
RequestID string
Body []byte
}
func (e *PromptTooLongError) Error() string {
return fmt.Sprintf("prompt too long: status=%d", e.StatusCode)
}
// antigravityRetryLoop 执行带 URL fallback 的重试循环 // antigravityRetryLoop 执行带 URL fallback 的重试循环
func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) { func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() baseURLs := antigravity.ForwardBaseURLs()
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLsWithBase(baseURLs)
if len(availableURLs) == 0 { if len(availableURLs) == 0 {
availableURLs = antigravity.BaseURLs availableURLs = baseURLs
}
maxRetries := p.maxRetries
if maxRetries <= 0 {
maxRetries = antigravityDefaultMaxRetries
} }
var resp *http.Response var resp *http.Response
...@@ -76,7 +105,7 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe ...@@ -76,7 +105,7 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
urlFallbackLoop: urlFallbackLoop:
for urlIdx, baseURL := range availableURLs { for urlIdx, baseURL := range availableURLs {
usedBaseURL = baseURL usedBaseURL = baseURL
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { for attempt := 1; attempt <= maxRetries; attempt++ {
select { select {
case <-p.ctx.Done(): case <-p.ctx.Done():
log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err()) log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
...@@ -109,8 +138,8 @@ urlFallbackLoop: ...@@ -109,8 +138,8 @@ urlFallbackLoop:
log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
continue urlFallbackLoop continue urlFallbackLoop
} }
if attempt < antigravityMaxRetries { if attempt < maxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err) log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, maxRetries, err)
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix) log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err() return nil, p.ctx.Err()
...@@ -134,7 +163,7 @@ urlFallbackLoop: ...@@ -134,7 +163,7 @@ urlFallbackLoop:
} }
// 账户/模型配额限流,重试 3 次(指数退避) // 账户/模型配额限流,重试 3 次(指数退避)
if attempt < antigravityMaxRetries { if attempt < maxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
...@@ -147,7 +176,7 @@ urlFallbackLoop: ...@@ -147,7 +176,7 @@ urlFallbackLoop:
Message: upstreamMsg, Message: upstreamMsg,
Detail: getUpstreamDetail(respBody), Detail: getUpstreamDetail(respBody),
}) })
log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, maxRetries, truncateForLog(respBody, 200))
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix) log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err() return nil, p.ctx.Err()
...@@ -171,7 +200,7 @@ urlFallbackLoop: ...@@ -171,7 +200,7 @@ urlFallbackLoop:
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close() _ = resp.Body.Close()
if attempt < antigravityMaxRetries { if attempt < maxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
...@@ -184,7 +213,7 @@ urlFallbackLoop: ...@@ -184,7 +213,7 @@ urlFallbackLoop:
Message: upstreamMsg, Message: upstreamMsg,
Detail: getUpstreamDetail(respBody), Detail: getUpstreamDetail(respBody),
}) })
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, maxRetries, truncateForLog(respBody, 500))
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix) log.Printf("%s status=context_canceled_during_backoff", p.prefix)
return nil, p.ctx.Err() return nil, p.ctx.Err()
...@@ -390,6 +419,11 @@ type TestConnectionResult struct { ...@@ -390,6 +419,11 @@ type TestConnectionResult struct {
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费) // TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择 // 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
// 上游透传账号使用专用测试方法
if account.Type == AccountTypeUpstream {
return s.testUpstreamConnection(ctx, account, modelID)
}
// 获取 token // 获取 token
if s.tokenProvider == nil { if s.tokenProvider == nil {
return nil, errors.New("antigravity token provider not configured") return nil, errors.New("antigravity token provider not configured")
...@@ -484,6 +518,87 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account ...@@ -484,6 +518,87 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return nil, lastErr return nil, lastErr
} }
// testUpstreamConnection 测试上游透传账号连接
func (s *AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if baseURL == "" || apiKey == "" {
return nil, errors.New("upstream account missing base_url or api_key")
}
baseURL = strings.TrimSuffix(baseURL, "/")
// 使用 Claude 模型进行测试
if modelID == "" {
modelID = "claude-sonnet-4-20250514"
}
// 构建最小测试请求
testReq := map[string]any{
"model": modelID,
"max_tokens": 1,
"messages": []map[string]any{
{"role": "user", "content": "."},
},
}
requestBody, err := json.Marshal(testReq)
if err != nil {
return nil, fmt.Errorf("构建请求失败: %w", err)
}
// 构建 HTTP 请求
upstreamURL := baseURL + "/v1/messages"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(requestBody))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("x-api-key", apiKey)
req.Header.Set("anthropic-version", "2023-06-01")
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, upstreamURL)
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return nil, fmt.Errorf("请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
}
// 提取响应文本
var respData map[string]any
text := ""
if json.Unmarshal(respBody, &respData) == nil {
if content, ok := respData["content"].([]any); ok && len(content) > 0 {
if block, ok := content[0].(map[string]any); ok {
if t, ok := block["text"].(string); ok {
text = t
}
}
}
}
return &TestConnectionResult{
Text: text,
MappedModel: modelID,
}, nil
}
// buildGeminiTestRequest 构建 Gemini 格式测试请求 // buildGeminiTestRequest 构建 Gemini 格式测试请求
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1 // 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) { func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) {
...@@ -534,6 +649,10 @@ func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Contex ...@@ -534,6 +649,10 @@ func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Contex
} }
opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx) opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx)
opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx) opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx)
if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && group != nil {
opts.EnableMCPXML = group.MCPXMLInject
}
return opts return opts
} }
...@@ -702,6 +821,11 @@ func isModelNotFoundError(statusCode int, body []byte) bool { ...@@ -702,6 +821,11 @@ func isModelNotFoundError(statusCode int, body []byte) bool {
// Forward 转发 Claude 协议请求(Claude → Gemini 转换) // Forward 转发 Claude 协议请求(Claude → Gemini 转换)
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
// 上游透传账号直接转发,不走 OAuth token 刷新
if account.Type == AccountTypeUpstream {
return s.ForwardUpstream(ctx, c, account, body)
}
startTime := time.Now() startTime := time.Now()
sessionID := getSessionID(c) sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name) prefix := logPrefix(sessionID, account.Name)
...@@ -718,6 +842,12 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -718,6 +842,12 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
originalModel := claudeReq.Model originalModel := claudeReq.Model
mappedModel := s.getMappedModel(account, claudeReq.Model) mappedModel := s.getMappedModel(account, claudeReq.Model)
quotaScope, _ := resolveAntigravityQuotaScope(originalModel) quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
billingModel := originalModel
if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" {
billingModel = mappedModel
}
afterSwitch := antigravityHasAccountSwitch(ctx)
maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch)
// 获取 access_token // 获取 access_token
if s.tokenProvider == nil { if s.tokenProvider == nil {
...@@ -766,6 +896,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -766,6 +896,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
httpUpstream: s.httpUpstream, httpUpstream: s.httpUpstream,
settingService: s.settingService, settingService: s.settingService,
handleError: s.handleUpstreamError, handleError: s.handleUpstreamError,
maxRetries: maxRetries,
}) })
if err != nil { if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
...@@ -842,6 +973,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -842,6 +973,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
httpUpstream: s.httpUpstream, httpUpstream: s.httpUpstream,
settingService: s.settingService, settingService: s.settingService,
handleError: s.handleUpstreamError, handleError: s.handleUpstreamError,
maxRetries: maxRetries,
}) })
if retryErr != nil { if retryErr != nil {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
...@@ -917,6 +1049,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -917,6 +1049,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 处理错误响应(重试后仍失败或不触发重试) // 处理错误响应(重试后仍失败或不触发重试)
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if resp.StatusCode == http.StatusBadRequest {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
log.Printf("%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s", prefix, isPromptTooLongError(respBody), upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, 500))
}
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "prompt_too_long",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &PromptTooLongError{
StatusCode: resp.StatusCode,
RequestID: resp.Header.Get("x-request-id"),
Body: respBody,
}
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
if s.shouldFailoverUpstreamError(resp.StatusCode) { if s.shouldFailoverUpstreamError(resp.StatusCode) {
...@@ -978,7 +1143,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -978,7 +1143,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志 Model: billingModel, // 计费模型(可按映射模型覆盖)
Stream: claudeReq.Stream, Stream: claudeReq.Stream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
...@@ -1003,24 +1168,64 @@ func isSignatureRelatedError(respBody []byte) bool { ...@@ -1003,24 +1168,64 @@ func isSignatureRelatedError(respBody []byte) bool {
return true return true
} }
// Detect thinking block modification errors:
// "thinking or redacted_thinking blocks in the latest assistant message cannot be modified"
if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
return true
}
return false return false
} }
func isPromptTooLongError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if msg == "" {
msg = strings.ToLower(string(respBody))
}
return strings.Contains(msg, "prompt is too long")
}
func extractAntigravityErrorMessage(body []byte) string { func extractAntigravityErrorMessage(body []byte) string {
var payload map[string]any var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil { if err := json.Unmarshal(body, &payload); err != nil {
return "" return ""
} }
parseNestedMessage := func(msg string) string {
trimmed := strings.TrimSpace(msg)
if trimmed == "" || !strings.HasPrefix(trimmed, "{") {
return ""
}
var nested map[string]any
if err := json.Unmarshal([]byte(trimmed), &nested); err != nil {
return ""
}
if errObj, ok := nested["error"].(map[string]any); ok {
if innerMsg, ok := errObj["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
}
if innerMsg, ok := nested["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
return ""
}
// Google-style: {"error": {"message": "..."}} // Google-style: {"error": {"message": "..."}}
if errObj, ok := payload["error"].(map[string]any); ok { if errObj, ok := payload["error"].(map[string]any); ok {
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" { if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
return innerMsg
}
return msg return msg
} }
} }
// Fallback: top-level message // Fallback: top-level message
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" { if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
return innerMsg
}
return msg return msg
} }
...@@ -1248,6 +1453,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque ...@@ -1248,6 +1453,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
return changed, nil return changed, nil
} }
// ForwardUpstream 透传请求到上游 Antigravity 服务
// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
// 获取上游配置
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if baseURL == "" || apiKey == "" {
return nil, fmt.Errorf("upstream account missing base_url or api_key")
}
baseURL = strings.TrimSuffix(baseURL, "/")
// 解析请求获取模型信息
var claudeReq antigravity.ClaudeRequest
if err := json.Unmarshal(body, &claudeReq); err != nil {
return nil, fmt.Errorf("parse claude request: %w", err)
}
if strings.TrimSpace(claudeReq.Model) == "" {
return nil, fmt.Errorf("missing model")
}
originalModel := claudeReq.Model
billingModel := originalModel
// 构建上游请求 URL
upstreamURL := baseURL + "/v1/messages"
// 创建请求
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("create upstream request: %w", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("x-api-key", apiKey) // Claude API 兼容
// 透传 Claude 相关 headers
if v := c.GetHeader("anthropic-version"); v != "" {
req.Header.Set("anthropic-version", v)
}
if v := c.GetHeader("anthropic-beta"); v != "" {
req.Header.Set("anthropic-beta", v)
}
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
log.Printf("%s upstream request failed: %v", prefix, err)
return nil, fmt.Errorf("upstream request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 429 错误时标记账号限流
if resp.StatusCode == http.StatusTooManyRequests {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude)
}
// 透传上游错误
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(resp.StatusCode)
_, _ = c.Writer.Write(respBody)
return &ForwardResult{
Model: billingModel,
}, nil
}
// 处理成功响应(流式/非流式)
var usage *ClaudeUsage
var firstTokenMs *int
if claudeReq.Stream {
// 流式响应:透传
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
usage, firstTokenMs = s.streamUpstreamResponse(c, resp, startTime)
} else {
// 非流式响应:直接透传
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read upstream response: %w", err)
}
// 提取 usage
usage = s.extractClaudeUsage(respBody)
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(http.StatusOK)
_, _ = c.Writer.Write(respBody)
}
// 构建计费结果
duration := time.Since(startTime)
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
return &ForwardResult{
Model: billingModel,
Stream: claudeReq.Stream,
Duration: duration,
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
CacheCreationInputTokens: usage.CacheCreationInputTokens,
},
}, nil
}
// streamUpstreamResponse 透传上游流式响应并提取 usage
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*ClaudeUsage, *int) {
usage := &ClaudeUsage{}
var firstTokenMs *int
var firstTokenRecorded bool
scanner := bufio.NewScanner(resp.Body)
buf := make([]byte, 0, 64*1024)
scanner.Buffer(buf, 1024*1024)
for scanner.Scan() {
line := scanner.Bytes()
// 记录首 token 时间
if !firstTokenRecorded && len(line) > 0 {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
firstTokenRecorded = true
}
// 尝试从 message_delta 或 message_stop 事件提取 usage
if bytes.HasPrefix(line, []byte("data: ")) {
dataStr := bytes.TrimPrefix(line, []byte("data: "))
var event map[string]any
if json.Unmarshal(dataStr, &event) == nil {
if u, ok := event["usage"].(map[string]any); ok {
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {
usage.InputTokens = int(v)
}
if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 {
usage.OutputTokens = int(v)
}
if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheReadInputTokens = int(v)
}
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheCreationInputTokens = int(v)
}
}
}
}
// 透传行
_, _ = c.Writer.Write(line)
_, _ = c.Writer.Write([]byte("\n"))
c.Writer.Flush()
}
return usage, firstTokenMs
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage {
usage := &ClaudeUsage{}
var resp map[string]any
if json.Unmarshal(body, &resp) != nil {
return usage
}
if u, ok := resp["usage"].(map[string]any); ok {
if v, ok := u["input_tokens"].(float64); ok {
usage.InputTokens = int(v)
}
if v, ok := u["output_tokens"].(float64); ok {
usage.OutputTokens = int(v)
}
if v, ok := u["cache_read_input_tokens"].(float64); ok {
usage.CacheReadInputTokens = int(v)
}
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
usage.CacheCreationInputTokens = int(v)
}
}
return usage
}
// ForwardGemini 转发 Gemini 协议请求 // ForwardGemini 转发 Gemini 协议请求
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) { func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
...@@ -1287,6 +1694,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -1287,6 +1694,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
} }
mappedModel := s.getMappedModel(account, originalModel) mappedModel := s.getMappedModel(account, originalModel)
billingModel := originalModel
if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" {
billingModel = mappedModel
}
afterSwitch := antigravityHasAccountSwitch(ctx)
maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch)
// 获取 access_token // 获取 access_token
if s.tokenProvider == nil { if s.tokenProvider == nil {
...@@ -1306,8 +1719,15 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -1306,8 +1719,15 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
filteredBody, err := filterEmptyPartsFromGeminiRequest(body)
if err != nil {
log.Printf("[Antigravity] Failed to filter empty parts: %v", err)
filteredBody = body
}
// Antigravity 上游要求必须包含身份提示词,注入到请求中 // Antigravity 上游要求必须包含身份提示词,注入到请求中
injectedBody, err := injectIdentityPatchToGeminiRequest(body) injectedBody, err := injectIdentityPatchToGeminiRequest(filteredBody)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -1344,6 +1764,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -1344,6 +1764,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
httpUpstream: s.httpUpstream, httpUpstream: s.httpUpstream,
settingService: s.settingService, settingService: s.settingService,
handleError: s.handleUpstreamError, handleError: s.handleUpstreamError,
maxRetries: maxRetries,
}) })
if err != nil { if err != nil {
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
...@@ -1493,7 +1914,7 @@ handleSuccess: ...@@ -1493,7 +1914,7 @@ handleSuccess:
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: originalModel, Model: billingModel,
Stream: stream, Stream: stream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
...@@ -1544,6 +1965,81 @@ func antigravityUseScopeRateLimit() bool { ...@@ -1544,6 +1965,81 @@ func antigravityUseScopeRateLimit() bool {
return true return true
} }
func antigravityHasAccountSwitch(ctx context.Context) bool {
if ctx == nil {
return false
}
if v, ok := ctx.Value(ctxkey.AccountSwitchCount).(int); ok {
return v > 0
}
return false
}
func antigravityMaxRetries() int {
raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesEnv))
if raw == "" {
return antigravityDefaultMaxRetries
}
value, err := strconv.Atoi(raw)
if err != nil || value <= 0 {
return antigravityDefaultMaxRetries
}
return value
}
func antigravityMaxRetriesAfterSwitch() int {
raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesAfterSwitchEnv))
if raw == "" {
return antigravityMaxRetries()
}
value, err := strconv.Atoi(raw)
if err != nil || value <= 0 {
return antigravityMaxRetries()
}
return value
}
// antigravityMaxRetriesForModel 根据模型类型获取重试次数
// 优先使用模型细分配置,未设置则回退到平台级配置
func antigravityMaxRetriesForModel(model string, afterSwitch bool) int {
var envKey string
if strings.HasPrefix(model, "claude-") {
envKey = antigravityMaxRetriesClaudeEnv
} else if isImageGenerationModel(model) {
envKey = antigravityMaxRetriesGeminiImageEnv
} else if strings.HasPrefix(model, "gemini-") {
envKey = antigravityMaxRetriesGeminiTextEnv
}
if envKey != "" {
if raw := strings.TrimSpace(os.Getenv(envKey)); raw != "" {
if value, err := strconv.Atoi(raw); err == nil && value > 0 {
return value
}
}
}
if afterSwitch {
return antigravityMaxRetriesAfterSwitch()
}
return antigravityMaxRetries()
}
func antigravityUseMappedModelForBilling() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityBillingModelEnv)))
return v == "1" || v == "true" || v == "yes" || v == "on"
}
func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
raw := strings.TrimSpace(os.Getenv(antigravityFallbackSecondsEnv))
if raw == "" {
return 0, false
}
seconds, err := strconv.Atoi(raw)
if err != nil || seconds <= 0 {
return 0, false
}
return time.Duration(seconds) * time.Second, true
}
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) { func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
// 429 使用 Gemini 格式解析(从 body 解析重置时间) // 429 使用 Gemini 格式解析(从 body 解析重置时间)
if statusCode == 429 { if statusCode == 429 {
...@@ -1556,6 +2052,9 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre ...@@ -1556,6 +2052,9 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes
} }
defaultDur := time.Duration(fallbackMinutes) * time.Minute defaultDur := time.Duration(fallbackMinutes) * time.Minute
if fallbackDur, ok := antigravityFallbackCooldownSeconds(); ok {
defaultDur = fallbackDur
}
ra := time.Now().Add(defaultDur) ra := time.Now().Add(defaultDur)
if useScopeLimit { if useScopeLimit {
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur) log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
...@@ -2193,6 +2692,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou ...@@ -2193,6 +2692,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg) return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
} }
func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body)
}
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error { func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
statusStr := "UNKNOWN" statusStr := "UNKNOWN"
switch status { switch status {
...@@ -2618,3 +3121,55 @@ func cleanGeminiRequest(body []byte) ([]byte, error) { ...@@ -2618,3 +3121,55 @@ func cleanGeminiRequest(body []byte) ([]byte, error) {
return json.Marshal(payload) return json.Marshal(payload)
} }
// filterEmptyPartsFromGeminiRequest 过滤 Gemini 请求中 parts 为空的消息
// Gemini API 不接受 parts 为空数组的消息,会返回 400 错误
func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return nil, err
}
contents, ok := payload["contents"].([]any)
if !ok || len(contents) == 0 {
return body, nil
}
filtered := make([]any, 0, len(contents))
modified := false
for _, c := range contents {
contentMap, ok := c.(map[string]any)
if !ok {
filtered = append(filtered, c)
continue
}
parts, hasParts := contentMap["parts"]
if !hasParts {
filtered = append(filtered, c)
continue
}
partsSlice, ok := parts.([]any)
if !ok {
filtered = append(filtered, c)
continue
}
// 跳过 parts 为空数组的消息
if len(partsSlice) == 0 {
modified = true
continue
}
filtered = append(filtered, c)
}
if !modified {
return body, nil
}
payload["contents"] = filtered
return json.Marshal(payload)
}
package service package service
import ( import (
"bytes"
"context"
"encoding/json" "encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -81,3 +87,106 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) { ...@@ -81,3 +87,106 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
require.Equal(t, "secret plan", blocks[0]["text"]) require.Equal(t, "secret plan", blocks[0]["text"])
require.Equal(t, "tool_use", blocks[1]["type"]) require.Equal(t, "tool_use", blocks[1]["type"])
} }
func TestIsPromptTooLongError(t *testing.T) {
require.True(t, isPromptTooLongError([]byte(`{"error":{"message":"Prompt is too long"}}`)))
require.True(t, isPromptTooLongError([]byte(`{"message":"Prompt is too long"}`)))
require.False(t, isPromptTooLongError([]byte(`{"error":{"message":"other"}}`)))
}
type httpUpstreamStub struct {
resp *http.Response
err error
}
func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return s.resp, s.err
}
func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
return s.resp, s.err
}
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"model": "claude-opus-4-5",
"messages": []map[string]any{
{"role": "user", "content": "hi"},
},
"max_tokens": 1,
"stream": false,
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request = req
respBody := []byte(`{"error":{"message":"Prompt is too long"}}`)
resp := &http.Response{
StatusCode: http.StatusBadRequest,
Header: http.Header{"X-Request-Id": []string{"req-1"}},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
}
account := &Account{
ID: 1,
Name: "acc-1",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
}
result, err := svc.Forward(context.Background(), c, account, body)
require.Nil(t, result)
var promptErr *PromptTooLongError
require.ErrorAs(t, err, &promptErr)
require.Equal(t, http.StatusBadRequest, promptErr.StatusCode)
require.Equal(t, "req-1", promptErr.RequestID)
require.NotEmpty(t, promptErr.Body)
raw, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := raw.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.Len(t, events, 1)
require.Equal(t, "prompt_too_long", events[0].Kind)
}
func TestAntigravityMaxRetriesForModel_AfterSwitch(t *testing.T) {
t.Setenv(antigravityMaxRetriesEnv, "4")
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "7")
t.Setenv(antigravityMaxRetriesClaudeEnv, "")
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "")
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "")
got := antigravityMaxRetriesForModel("claude-sonnet-4-5", false)
require.Equal(t, 4, got)
got = antigravityMaxRetriesForModel("claude-sonnet-4-5", true)
require.Equal(t, 7, got)
}
func TestAntigravityMaxRetriesForModel_AfterSwitchFallback(t *testing.T) {
t.Setenv(antigravityMaxRetriesEnv, "5")
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "")
t.Setenv(antigravityMaxRetriesClaudeEnv, "")
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "")
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "")
got := antigravityMaxRetriesForModel("gemini-2.5-flash", true)
require.Equal(t, 5, got)
}
package service package service
import ( import (
"slices"
"strings" "strings"
"time" "time"
) )
...@@ -16,6 +17,21 @@ const ( ...@@ -16,6 +17,21 @@ const (
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image" AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
) )
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool {
if len(supportedScopes) == 0 {
// 未配置时默认全部支持
return true
}
supported := slices.Contains(supportedScopes, string(scope))
return supported
}
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
return resolveAntigravityQuotaScope(requestedModel)
}
// resolveAntigravityQuotaScope 根据模型名称解析配额域 // resolveAntigravityQuotaScope 根据模型名称解析配额域
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) { func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
model := normalizeAntigravityModelName(requestedModel) model := normalizeAntigravityModelName(requestedModel)
......
...@@ -2,6 +2,14 @@ package service ...@@ -2,6 +2,14 @@ package service
import "time" import "time"
// API Key status constants
const (
StatusAPIKeyActive = "active"
StatusAPIKeyDisabled = "disabled"
StatusAPIKeyQuotaExhausted = "quota_exhausted"
StatusAPIKeyExpired = "expired"
)
type APIKey struct { type APIKey struct {
ID int64 ID int64
UserID int64 UserID int64
...@@ -15,8 +23,53 @@ type APIKey struct { ...@@ -15,8 +23,53 @@ type APIKey struct {
UpdatedAt time.Time UpdatedAt time.Time
User *User User *User
Group *Group Group *Group
// Quota fields
Quota float64 // Quota limit in USD (0 = unlimited)
QuotaUsed float64 // Used quota amount
ExpiresAt *time.Time // Expiration time (nil = never expires)
} }
func (k *APIKey) IsActive() bool { func (k *APIKey) IsActive() bool {
return k.Status == StatusActive return k.Status == StatusActive
} }
// IsExpired checks if the API key has expired
func (k *APIKey) IsExpired() bool {
if k.ExpiresAt == nil {
return false
}
return time.Now().After(*k.ExpiresAt)
}
// IsQuotaExhausted checks if the API key quota is exhausted
func (k *APIKey) IsQuotaExhausted() bool {
if k.Quota <= 0 {
return false // unlimited
}
return k.QuotaUsed >= k.Quota
}
// GetQuotaRemaining returns remaining quota (-1 for unlimited)
func (k *APIKey) GetQuotaRemaining() float64 {
if k.Quota <= 0 {
return -1 // unlimited
}
remaining := k.Quota - k.QuotaUsed
if remaining < 0 {
return 0
}
return remaining
}
// GetDaysUntilExpiry returns days until expiry (-1 for never expires)
func (k *APIKey) GetDaysUntilExpiry() int {
if k.ExpiresAt == nil {
return -1 // never expires
}
duration := time.Until(*k.ExpiresAt)
if duration < 0 {
return 0
}
return int(duration.Hours() / 24)
}
package service package service
import "time"
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段) // APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type APIKeyAuthSnapshot struct { type APIKeyAuthSnapshot struct {
APIKeyID int64 `json:"api_key_id"` APIKeyID int64 `json:"api_key_id"`
...@@ -10,6 +12,13 @@ type APIKeyAuthSnapshot struct { ...@@ -10,6 +12,13 @@ type APIKeyAuthSnapshot struct {
IPBlacklist []string `json:"ip_blacklist,omitempty"` IPBlacklist []string `json:"ip_blacklist,omitempty"`
User APIKeyAuthUserSnapshot `json:"user"` User APIKeyAuthUserSnapshot `json:"user"`
Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"` Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
// Quota fields for API Key independent quota feature
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
QuotaUsed float64 `json:"quota_used"` // Used quota amount
// Expiration field for API Key expiration feature
ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires)
} }
// APIKeyAuthUserSnapshot 用户快照 // APIKeyAuthUserSnapshot 用户快照
...@@ -23,25 +32,30 @@ type APIKeyAuthUserSnapshot struct { ...@@ -23,25 +32,30 @@ type APIKeyAuthUserSnapshot struct {
// APIKeyAuthGroupSnapshot 分组快照 // APIKeyAuthGroupSnapshot 分组快照
type APIKeyAuthGroupSnapshot struct { type APIKeyAuthGroupSnapshot struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Platform string `json:"platform"` Platform string `json:"platform"`
Status string `json:"status"` Status string `json:"status"`
SubscriptionType string `json:"subscription_type"` SubscriptionType string `json:"subscription_type"`
RateMultiplier float64 `json:"rate_multiplier"` RateMultiplier float64 `json:"rate_multiplier"`
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
ImagePrice1K *float64 `json:"image_price_1k,omitempty"` ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"` ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot. // Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty. // Only anthropic groups use these fields; others may leave them empty.
ModelRouting map[string][]int64 `json:"model_routing,omitempty"` ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
ModelRoutingEnabled bool `json:"model_routing_enabled"` ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
} }
// APIKeyAuthCacheEntry 缓存条目,支持负缓存 // APIKeyAuthCacheEntry 缓存条目,支持负缓存
......
...@@ -213,6 +213,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ...@@ -213,6 +213,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
Status: apiKey.Status, Status: apiKey.Status,
IPWhitelist: apiKey.IPWhitelist, IPWhitelist: apiKey.IPWhitelist,
IPBlacklist: apiKey.IPBlacklist, IPBlacklist: apiKey.IPBlacklist,
Quota: apiKey.Quota,
QuotaUsed: apiKey.QuotaUsed,
ExpiresAt: apiKey.ExpiresAt,
User: APIKeyAuthUserSnapshot{ User: APIKeyAuthUserSnapshot{
ID: apiKey.User.ID, ID: apiKey.User.ID,
Status: apiKey.User.Status, Status: apiKey.User.Status,
...@@ -223,22 +226,25 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ...@@ -223,22 +226,25 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
} }
if apiKey.Group != nil { if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{ snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID, ID: apiKey.Group.ID,
Name: apiKey.Group.Name, Name: apiKey.Group.Name,
Platform: apiKey.Group.Platform, Platform: apiKey.Group.Platform,
Status: apiKey.Group.Status, Status: apiKey.Group.Status,
SubscriptionType: apiKey.Group.SubscriptionType, SubscriptionType: apiKey.Group.SubscriptionType,
RateMultiplier: apiKey.Group.RateMultiplier, RateMultiplier: apiKey.Group.RateMultiplier,
DailyLimitUSD: apiKey.Group.DailyLimitUSD, DailyLimitUSD: apiKey.Group.DailyLimitUSD,
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K, ImagePrice4K: apiKey.Group.ImagePrice4K,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID, FallbackGroupID: apiKey.Group.FallbackGroupID,
ModelRouting: apiKey.Group.ModelRouting, FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, ModelRouting: apiKey.Group.ModelRouting,
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
MCPXMLInject: apiKey.Group.MCPXMLInject,
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
} }
} }
return snapshot return snapshot
...@@ -256,6 +262,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ...@@ -256,6 +262,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
Status: snapshot.Status, Status: snapshot.Status,
IPWhitelist: snapshot.IPWhitelist, IPWhitelist: snapshot.IPWhitelist,
IPBlacklist: snapshot.IPBlacklist, IPBlacklist: snapshot.IPBlacklist,
Quota: snapshot.Quota,
QuotaUsed: snapshot.QuotaUsed,
ExpiresAt: snapshot.ExpiresAt,
User: &User{ User: &User{
ID: snapshot.User.ID, ID: snapshot.User.ID,
Status: snapshot.User.Status, Status: snapshot.User.Status,
...@@ -266,23 +275,26 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ...@@ -266,23 +275,26 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
} }
if snapshot.Group != nil { if snapshot.Group != nil {
apiKey.Group = &Group{ apiKey.Group = &Group{
ID: snapshot.Group.ID, ID: snapshot.Group.ID,
Name: snapshot.Group.Name, Name: snapshot.Group.Name,
Platform: snapshot.Group.Platform, Platform: snapshot.Group.Platform,
Status: snapshot.Group.Status, Status: snapshot.Group.Status,
Hydrated: true, Hydrated: true,
SubscriptionType: snapshot.Group.SubscriptionType, SubscriptionType: snapshot.Group.SubscriptionType,
RateMultiplier: snapshot.Group.RateMultiplier, RateMultiplier: snapshot.Group.RateMultiplier,
DailyLimitUSD: snapshot.Group.DailyLimitUSD, DailyLimitUSD: snapshot.Group.DailyLimitUSD,
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K, ImagePrice4K: snapshot.Group.ImagePrice4K,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID, FallbackGroupID: snapshot.Group.FallbackGroupID,
ModelRouting: snapshot.Group.ModelRouting, FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, ModelRouting: snapshot.Group.ModelRouting,
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
MCPXMLInject: snapshot.Group.MCPXMLInject,
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
} }
} }
return apiKey return apiKey
......
...@@ -24,6 +24,10 @@ var ( ...@@ -24,6 +24,10 @@ var (
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens") ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later") ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern") ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
// ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key has expired")
ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期")
// ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted")
ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完")
) )
const ( const (
...@@ -51,6 +55,9 @@ type APIKeyRepository interface { ...@@ -51,6 +55,9 @@ type APIKeyRepository interface {
CountByGroupID(ctx context.Context, groupID int64) (int64, error) CountByGroupID(ctx context.Context, groupID int64) (int64, error)
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
// Quota methods
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error)
} }
// APIKeyCache defines cache operations for API key service // APIKeyCache defines cache operations for API key service
...@@ -85,6 +92,10 @@ type CreateAPIKeyRequest struct { ...@@ -85,6 +92,10 @@ type CreateAPIKeyRequest struct {
CustomKey *string `json:"custom_key"` // 可选的自定义key CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
// Quota fields
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires)
} }
// UpdateAPIKeyRequest 更新API Key请求 // UpdateAPIKeyRequest 更新API Key请求
...@@ -94,6 +105,12 @@ type UpdateAPIKeyRequest struct { ...@@ -94,6 +105,12 @@ type UpdateAPIKeyRequest struct {
Status *string `json:"status"` Status *string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空) IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空) IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
// Quota fields
Quota *float64 `json:"quota"` // Quota limit in USD (nil = no change, 0 = unlimited)
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change)
ClearExpiration bool `json:"-"` // Clear expiration (internal use)
ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0
} }
// APIKeyService API Key服务 // APIKeyService API Key服务
...@@ -289,6 +306,14 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK ...@@ -289,6 +306,14 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
Status: StatusActive, Status: StatusActive,
IPWhitelist: req.IPWhitelist, IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist, IPBlacklist: req.IPBlacklist,
Quota: req.Quota,
QuotaUsed: 0,
}
// Set expiration time if specified
if req.ExpiresInDays != nil && *req.ExpiresInDays > 0 {
expiresAt := time.Now().AddDate(0, 0, *req.ExpiresInDays)
apiKey.ExpiresAt = &expiresAt
} }
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil { if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
...@@ -436,6 +461,35 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req ...@@ -436,6 +461,35 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
} }
} }
// Update quota fields
if req.Quota != nil {
apiKey.Quota = *req.Quota
// If quota is increased and status was quota_exhausted, reactivate
if apiKey.Status == StatusAPIKeyQuotaExhausted && *req.Quota > apiKey.QuotaUsed {
apiKey.Status = StatusActive
}
}
if req.ResetQuota != nil && *req.ResetQuota {
apiKey.QuotaUsed = 0
// If resetting quota and status was quota_exhausted, reactivate
if apiKey.Status == StatusAPIKeyQuotaExhausted {
apiKey.Status = StatusActive
}
}
if req.ClearExpiration {
apiKey.ExpiresAt = nil
// If clearing expiry and status was expired, reactivate
if apiKey.Status == StatusAPIKeyExpired {
apiKey.Status = StatusActive
}
} else if req.ExpiresAt != nil {
apiKey.ExpiresAt = req.ExpiresAt
// If extending expiry and status was expired, reactivate
if apiKey.Status == StatusAPIKeyExpired && time.Now().Before(*req.ExpiresAt) {
apiKey.Status = StatusActive
}
}
// 更新 IP 限制(空数组会清空设置) // 更新 IP 限制(空数组会清空设置)
apiKey.IPWhitelist = req.IPWhitelist apiKey.IPWhitelist = req.IPWhitelist
apiKey.IPBlacklist = req.IPBlacklist apiKey.IPBlacklist = req.IPBlacklist
...@@ -572,3 +626,51 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword ...@@ -572,3 +626,51 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
} }
return keys, nil return keys, nil
} }
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
// Returns nil if valid, error if invalid
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {
// Check expiration
if apiKey.IsExpired() {
return ErrAPIKeyExpired
}
// Check quota
if apiKey.IsQuotaExhausted() {
return ErrAPIKeyQuotaExhausted
}
return nil
}
// UpdateQuotaUsed updates the quota_used field after a request
// Also checks if quota is exhausted and updates status accordingly
func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
if cost <= 0 {
return nil
}
// Use repository to atomically increment quota_used
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
if err != nil {
return fmt.Errorf("increment quota used: %w", err)
}
// Check if quota is now exhausted and update status if needed
apiKey, err := s.apiKeyRepo.GetByID(ctx, apiKeyID)
if err != nil {
return nil // Don't fail the request, just log
}
// If quota is set and now exhausted, update status
if apiKey.Quota > 0 && newQuotaUsed >= apiKey.Quota {
apiKey.Status = StatusAPIKeyQuotaExhausted
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil // Don't fail the request
}
// Invalidate cache so next request sees the new status
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
}
return nil
}
...@@ -99,6 +99,10 @@ func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([] ...@@ -99,6 +99,10 @@ func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]
return s.listKeysByGroupID(ctx, groupID) return s.listKeysByGroupID(ctx, groupID)
} }
func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
panic("unexpected IncrementQuotaUsed call")
}
type authCacheStub struct { type authCacheStub struct {
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
setAuthKeys []string setAuthKeys []string
......
...@@ -118,6 +118,10 @@ func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ( ...@@ -118,6 +118,10 @@ func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) (
panic("unexpected ListKeysByGroupID call") panic("unexpected ListKeysByGroupID call")
} }
func (s *apiKeyRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
panic("unexpected IncrementQuotaUsed call")
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。 // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。
// //
......
...@@ -185,7 +185,6 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -185,7 +185,6 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err) log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err)
} }
} }
// 应用优惠码(如果提供且功能已启用) // 应用优惠码(如果提供且功能已启用)
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) { if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil { if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
......
...@@ -31,6 +31,7 @@ const ( ...@@ -31,6 +31,7 @@ const (
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
) )
// Redeem type constants // Redeem type constants
......
...@@ -257,6 +257,9 @@ var ( ...@@ -257,6 +257,9 @@ var (
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
var ErrModelScopeNotSupported = errors.New("model scope not supported by this group")
// allowedHeaders 白名单headers(参考CRS项目) // allowedHeaders 白名单headers(参考CRS项目)
var allowedHeaders = map[string]bool{ var allowedHeaders = map[string]bool{
"accept": true, "accept": true,
...@@ -585,12 +588,18 @@ func (s *GatewayService) hashContent(content string) string { ...@@ -585,12 +588,18 @@ func (s *GatewayService) hashContent(content string) string {
} }
// replaceModelInBody 替换请求体中的model字段 // replaceModelInBody 替换请求体中的model字段
// 使用 json.RawMessage 保留其他字段的原始字节,避免 thinking 块等内容被修改
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
var req map[string]any var req map[string]json.RawMessage
if err := json.Unmarshal(body, &req); err != nil { if err := json.Unmarshal(body, &req); err != nil {
return body return body
} }
req["model"] = newModel // 只序列化 model 字段
modelBytes, err := json.Marshal(newModel)
if err != nil {
return body
}
req["model"] = modelBytes
newBody, err := json.Marshal(req) newBody, err := json.Marshal(req)
if err != nil { if err != nil {
return body return body
...@@ -787,12 +796,21 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu ...@@ -787,12 +796,21 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if len(body) == 0 { if len(body) == 0 {
return body, modelID, nil return body, modelID, nil
} }
// 使用 json.RawMessage 保留 messages 的原始字节,避免 thinking 块被修改
var reqRaw map[string]json.RawMessage
if err := json.Unmarshal(body, &reqRaw); err != nil {
return body, modelID, nil
}
// 同时解析为 map[string]any 用于修改非 messages 字段
var req map[string]any var req map[string]any
if err := json.Unmarshal(body, &req); err != nil { if err := json.Unmarshal(body, &req); err != nil {
return body, modelID, nil return body, modelID, nil
} }
toolNameMap := make(map[string]string) toolNameMap := make(map[string]string)
modified := false
if system, ok := req["system"]; ok { if system, ok := req["system"]; ok {
switch v := system.(type) { switch v := system.(type) {
...@@ -800,6 +818,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu ...@@ -800,6 +818,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
sanitized := sanitizeSystemText(v) sanitized := sanitizeSystemText(v)
if sanitized != v { if sanitized != v {
req["system"] = sanitized req["system"] = sanitized
modified = true
} }
case []any: case []any:
for _, item := range v { for _, item := range v {
...@@ -817,6 +836,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu ...@@ -817,6 +836,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
sanitized := sanitizeSystemText(text) sanitized := sanitizeSystemText(text)
if sanitized != text { if sanitized != text {
block["text"] = sanitized block["text"] = sanitized
modified = true
} }
} }
} }
...@@ -827,6 +847,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu ...@@ -827,6 +847,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if normalized != rawModel { if normalized != rawModel {
req["model"] = normalized req["model"] = normalized
modelID = normalized modelID = normalized
modified = true
} }
} }
...@@ -842,16 +863,19 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu ...@@ -842,16 +863,19 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalized := normalizeToolNameForClaude(name, toolNameMap) normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized != "" && normalized != name { if normalized != "" && normalized != name {
toolMap["name"] = normalized toolMap["name"] = normalized
modified = true
} }
} }
if desc, ok := toolMap["description"].(string); ok { if desc, ok := toolMap["description"].(string); ok {
sanitized := sanitizeToolDescription(desc) sanitized := sanitizeToolDescription(desc)
if sanitized != desc { if sanitized != desc {
toolMap["description"] = sanitized toolMap["description"] = sanitized
modified = true
} }
} }
if schema, ok := toolMap["input_schema"]; ok { if schema, ok := toolMap["input_schema"]; ok {
normalizeToolInputSchema(schema, toolNameMap) normalizeToolInputSchema(schema, toolNameMap)
modified = true
} }
tools[idx] = toolMap tools[idx] = toolMap
} }
...@@ -880,11 +904,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu ...@@ -880,11 +904,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalizedTools[normalized] = value normalizedTools[normalized] = value
} }
req["tools"] = normalizedTools req["tools"] = normalizedTools
modified = true
} }
} else { } else {
req["tools"] = []any{} req["tools"] = []any{}
modified = true
} }
// 处理 messages 中的 tool_use 块,但保留包含 thinking 块的消息的原始字节
messagesModified := false
if messages, ok := req["messages"].([]any); ok { if messages, ok := req["messages"].([]any); ok {
for _, msg := range messages { for _, msg := range messages {
msgMap, ok := msg.(map[string]any) msgMap, ok := msg.(map[string]any)
...@@ -895,6 +923,24 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu ...@@ -895,6 +923,24 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if !ok { if !ok {
continue continue
} }
// 检查此消息是否包含 thinking 块
hasThinking := false
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
continue
}
blockType, _ := blockMap["type"].(string)
if blockType == "thinking" || blockType == "redacted_thinking" {
hasThinking = true
break
}
}
// 如果包含 thinking 块,跳过此消息的修改
if hasThinking {
continue
}
// 只修改不包含 thinking 块的消息中的 tool_use
for _, block := range content { for _, block := range content {
blockMap, ok := block.(map[string]any) blockMap, ok := block.(map[string]any)
if !ok { if !ok {
...@@ -907,6 +953,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu ...@@ -907,6 +953,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalized := normalizeToolNameForClaude(name, toolNameMap) normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized != "" && normalized != name { if normalized != "" && normalized != name {
blockMap["name"] = normalized blockMap["name"] = normalized
messagesModified = true
} }
} }
} }
...@@ -916,6 +963,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu ...@@ -916,6 +963,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if opts.stripSystemCacheControl { if opts.stripSystemCacheControl {
if system, ok := req["system"]; ok { if system, ok := req["system"]; ok {
_ = stripCacheControlFromSystemBlocks(system) _ = stripCacheControlFromSystemBlocks(system)
modified = true
} }
} }
...@@ -927,12 +975,46 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu ...@@ -927,12 +975,46 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
} }
if existing, ok := metadata["user_id"].(string); !ok || existing == "" { if existing, ok := metadata["user_id"].(string); !ok || existing == "" {
metadata["user_id"] = opts.metadataUserID metadata["user_id"] = opts.metadataUserID
modified = true
} }
} }
delete(req, "temperature") if _, hasTemp := req["temperature"]; hasTemp {
delete(req, "tool_choice") delete(req, "temperature")
modified = true
}
if _, hasChoice := req["tool_choice"]; hasChoice {
delete(req, "tool_choice")
modified = true
}
if !modified && !messagesModified {
return body, modelID, toolNameMap
}
// 如果 messages 没有被修改,保留原始 messages 字节
if !messagesModified {
// 序列化非 messages 字段
newBody, err := json.Marshal(req)
if err != nil {
return body, modelID, toolNameMap
}
// 替换回原始的 messages
var newReq map[string]json.RawMessage
if err := json.Unmarshal(newBody, &newReq); err != nil {
return newBody, modelID, toolNameMap
}
if origMessages, ok := reqRaw["messages"]; ok {
newReq["messages"] = origMessages
}
finalBody, err := json.Marshal(newReq)
if err != nil {
return newBody, modelID, toolNameMap
}
return finalBody, modelID, toolNameMap
}
// messages 被修改了,需要完整序列化
newBody, err := json.Marshal(req) newBody, err := json.Marshal(req)
if err != nil { if err != nil {
return body, modelID, toolNameMap return body, modelID, toolNameMap
...@@ -1135,6 +1217,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1135,6 +1217,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
} }
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -1632,6 +1721,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (* ...@@ -1632,6 +1721,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
return group, nil return group, nil
} }
func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
return s.resolveGroupByID(ctx, groupID)
}
func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 { func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 {
if groupID == nil || requestedModel == "" || platform != PlatformAnthropic { if groupID == nil || requestedModel == "" || platform != PlatformAnthropic {
return nil return nil
...@@ -1697,7 +1790,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID ...@@ -1697,7 +1790,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID
} }
// 强制平台模式不检查 Claude Code 限制 // 强制平台模式不检查 Claude Code 限制
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform { if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" {
return nil, groupID, nil return nil, groupID, nil
} }
...@@ -2026,6 +2119,13 @@ func shuffleWithinPriority(accounts []*Account) { ...@@ -2026,6 +2119,13 @@ func shuffleWithinPriority(accounts []*Account) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离) // selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
preferOAuth := platform == PlatformGemini preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
...@@ -2461,6 +2561,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo ...@@ -2461,6 +2561,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
// Antigravity 平台使用专门的模型支持检查 // Antigravity 平台使用专门的模型支持检查
return IsAntigravityModelSupported(requestedModel) return IsAntigravityModelSupported(requestedModel)
} }
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel)
}
// Gemini API Key 账户直接透传,由上游判断模型是否支持 // Gemini API Key 账户直接透传,由上游判断模型是否支持
if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey { if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey {
return true return true
...@@ -2910,16 +3014,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2910,16 +3014,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 强制执行 cache_control 块数量限制(最多 4 个) // 强制执行 cache_control 块数量限制(最多 4 个)
body = enforceCacheControlLimit(body) body = enforceCacheControlLimit(body)
// 应用模型映射(仅对apikey类型账号) // 应用模型映射:
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
mappedModel := reqModel
mappingSource := ""
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel) mappedModel = account.GetMappedModel(reqModel)
if mappedModel != reqModel { if mappedModel != reqModel {
// 替换请求体中的模型名 mappingSource = "account"
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
} }
} }
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(reqModel)
if normalized != reqModel {
mappedModel = normalized
mappingSource = "prefix"
}
}
if mappedModel != reqModel {
// 替换请求体中的模型名
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
}
// 获取凭证 // 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account) token, tokenType, err := s.GetAccessToken(ctx, account)
...@@ -3621,6 +3739,13 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { ...@@ -3621,6 +3739,13 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
return true return true
} }
// 检测 thinking block 被修改的错误
// 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified"
if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
log.Printf("[SignatureCheck] Detected thinking block modification error")
return true
}
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的) // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
// 例如: "all messages must have non-empty content" // 例如: "all messages must have non-empty content"
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") { if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") {
...@@ -4489,13 +4614,19 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap ...@@ -4489,13 +4614,19 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap
// RecordUsageInput 记录使用量的输入参数 // RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct { type RecordUsageInput struct {
Result *ForwardResult Result *ForwardResult
APIKey *APIKey APIKey *APIKey
User *User User *User
Account *Account Account *Account
Subscription *UserSubscription // 可选:订阅信息 Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址 IPAddress string // 请求的客户端 IP 地址
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
}
// APIKeyQuotaUpdater defines the interface for updating API Key quota
type APIKeyQuotaUpdater interface {
UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error
} }
// RecordUsage 记录使用量并扣费(或更新订阅用量) // RecordUsage 记录使用量并扣费(或更新订阅用量)
...@@ -4635,6 +4766,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -4635,6 +4766,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} }
} }
// 更新 API Key 配额(如果设置了配额限制)
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
log.Printf("Update API key quota failed: %v", err)
}
}
// Schedule batch update for account last_used_at // Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID) s.deferredService.ScheduleLastUsedUpdate(account.ID)
...@@ -4652,6 +4790,7 @@ type RecordUsageLongContextInput struct { ...@@ -4652,6 +4790,7 @@ type RecordUsageLongContextInput struct {
IPAddress string // 请求的客户端 IP 地址 IPAddress string // 请求的客户端 IP 地址
LongContextThreshold int // 长上下文阈值(如 200000) LongContextThreshold int // 长上下文阈值(如 200000)
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
APIKeyService *APIKeyService // API Key 配额服务(可选)
} }
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
...@@ -4788,6 +4927,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -4788,6 +4927,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
} }
// 异步更新余额缓存 // 异步更新余额缓存
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
// API Key 独立配额扣费
if input.APIKeyService != nil && apiKey.Quota > 0 {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
log.Printf("Add API key quota used failed: %v", err)
}
}
} }
} }
...@@ -4822,16 +4967,30 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -4822,16 +4967,30 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
return nil return nil
} }
// 应用模型映射(仅对 apikey 类型账号) // 应用模型映射:
if account.Type == AccountTypeAPIKey { // - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
if reqModel != "" { // - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
mappedModel := account.GetMappedModel(reqModel) if reqModel != "" {
mappedModel := reqModel
mappingSource := ""
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(reqModel)
if mappedModel != reqModel { if mappedModel != reqModel {
body = s.replaceModelInBody(body, mappedModel) mappingSource = "account"
reqModel = mappedModel
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
} }
} }
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(reqModel)
if normalized != reqModel {
mappedModel = normalized
mappingSource = "prefix"
}
}
if mappedModel != reqModel {
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource)
}
} }
// 获取凭证 // 获取凭证
...@@ -5083,6 +5242,27 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { ...@@ -5083,6 +5242,27 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
return normalized, nil return normalized, nil
} }
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error {
scope, ok := ResolveAntigravityQuotaScope(requestedModel)
if !ok {
return nil // 无法解析 scope,跳过检查
}
group, err := s.resolveGroupByID(ctx, groupID)
if err != nil {
return nil // 查询失败时放行
}
if group == nil {
return nil // 分组不存在时放行
}
if !IsScopeSupported(group.SupportedModelScopes, scope) {
return ErrModelScopeNotSupported
}
return nil
}
// GetAvailableModels returns the list of models available for a group // GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group // It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
......
...@@ -977,6 +977,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -977,6 +977,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
} }
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
if filteredBody, err := filterEmptyPartsFromGeminiRequest(body); err == nil {
body = filteredBody
}
switch action { switch action {
case "generateContent", "streamGenerateContent", "countTokens": case "generateContent", "streamGenerateContent", "countTokens":
// ok // ok
......
...@@ -2,20 +2,22 @@ package service ...@@ -2,20 +2,22 @@ package service
import ( import (
"encoding/json" "encoding/json"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
) )
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段, // CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中替换 thoughtSignature 字段为 dummy 签名
// 以避免跨账号签名验证错误。 // 以避免跨账号签名验证错误。
// //
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature // 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名 // 会导致新账号的签名验证失败。通过替换为 dummy 签名,跳过签名验证
// //
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests // CleanGeminiNativeThoughtSignatures replaces thoughtSignature fields with dummy signature
// to avoid cross-account signature validation errors. // in Gemini native API requests to avoid cross-account signature validation errors.
// //
// When sticky session switches accounts (e.g., original account becomes unavailable), // When sticky session switches accounts (e.g., original account becomes unavailable),
// thoughtSignatures from the old account will cause validation failures on the new account. // thoughtSignatures from the old account will cause validation failures on the new account.
// By removing these signatures, we allow the new account to generate valid signatures. // By replacing with dummy signature, we skip signature validation.
func CleanGeminiNativeThoughtSignatures(body []byte) []byte { func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
if len(body) == 0 { if len(body) == 0 {
return body return body
...@@ -28,11 +30,11 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte { ...@@ -28,11 +30,11 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
return body return body
} }
// 递归清理 thoughtSignature // 递归替换 thoughtSignature 为 dummy 签名
cleaned := cleanThoughtSignaturesRecursive(data) replaced := replaceThoughtSignaturesRecursive(data)
// 重新序列化 // 重新序列化
result, err := json.Marshal(cleaned) result, err := json.Marshal(replaced)
if err != nil { if err != nil {
// 如果序列化失败,返回原始 body // 如果序列化失败,返回原始 body
return body return body
...@@ -41,19 +43,20 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte { ...@@ -41,19 +43,20 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
return result return result
} }
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段 // replaceThoughtSignaturesRecursive 递归遍历数据结构,所有 thoughtSignature 字段替换为 dummy 签名
func cleanThoughtSignaturesRecursive(data any) any { func replaceThoughtSignaturesRecursive(data any) any {
switch v := data.(type) { switch v := data.(type) {
case map[string]any: case map[string]any:
// 创建新的 map,移除 thoughtSignature // 创建新的 map,替换 thoughtSignature 为 dummy 签名
result := make(map[string]any, len(v)) result := make(map[string]any, len(v))
for key, value := range v { for key, value := range v {
// 跳过 thoughtSignature 字段 // 替换 thoughtSignature 字段为 dummy 签名
if key == "thoughtSignature" { if key == "thoughtSignature" {
result[key] = antigravity.DummyThoughtSignature
continue continue
} }
// 递归处理嵌套结构 // 递归处理嵌套结构
result[key] = cleanThoughtSignaturesRecursive(value) result[key] = replaceThoughtSignaturesRecursive(value)
} }
return result return result
...@@ -61,7 +64,7 @@ func cleanThoughtSignaturesRecursive(data any) any { ...@@ -61,7 +64,7 @@ func cleanThoughtSignaturesRecursive(data any) any {
// 递归处理数组中的每个元素 // 递归处理数组中的每个元素
result := make([]any, len(v)) result := make([]any, len(v))
for i, item := range v { for i, item := range v {
result[i] = cleanThoughtSignaturesRecursive(item) result[i] = replaceThoughtSignaturesRecursive(item)
} }
return result return result
......
...@@ -29,6 +29,8 @@ type Group struct { ...@@ -29,6 +29,8 @@ type Group struct {
// Claude Code 客户端限制 // Claude Code 客户端限制
ClaudeCodeOnly bool ClaudeCodeOnly bool
FallbackGroupID *int64 FallbackGroupID *int64
// 无效请求兜底分组(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置 // 模型路由配置
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*") // key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
...@@ -36,6 +38,13 @@ type Group struct { ...@@ -36,6 +38,13 @@ type Group struct {
ModelRouting map[string][]int64 ModelRouting map[string][]int64
ModelRoutingEnabled bool ModelRoutingEnabled bool
// MCP XML 协议注入开关(仅 antigravity 平台使用)
MCPXMLInject bool
// 支持的模型系列(仅 antigravity 平台使用)
// 可选值: claude, gemini_text, gemini_image
SupportedModelScopes []string
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
......
...@@ -169,22 +169,31 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) { ...@@ -169,22 +169,31 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
// RewriteUserID 重写body中的metadata.user_id // RewriteUserID 重写body中的metadata.user_id
// 输入格式:user_{clientId}_account__session_{sessionUUID} // 输入格式:user_{clientId}_account__session_{sessionUUID}
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash} // 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash}
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) { func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) {
if len(body) == 0 || accountUUID == "" || cachedClientID == "" { if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
return body, nil return body, nil
} }
// 解析JSON // 使用 RawMessage 保留其他字段的原始字节
var reqMap map[string]any var reqMap map[string]json.RawMessage
if err := json.Unmarshal(body, &reqMap); err != nil { if err := json.Unmarshal(body, &reqMap); err != nil {
return body, nil return body, nil
} }
metadata, ok := reqMap["metadata"].(map[string]any) // 解析 metadata 字段
metadataRaw, ok := reqMap["metadata"]
if !ok { if !ok {
return body, nil return body, nil
} }
var metadata map[string]any
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
return body, nil
}
userID, ok := metadata["user_id"].(string) userID, ok := metadata["user_id"].(string)
if !ok || userID == "" { if !ok || userID == "" {
return body, nil return body, nil
...@@ -207,7 +216,13 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI ...@@ -207,7 +216,13 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash) newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash)
metadata["user_id"] = newUserID metadata["user_id"] = newUserID
reqMap["metadata"] = metadata
// 只重新序列化 metadata 字段
newMetadataRaw, err := json.Marshal(metadata)
if err != nil {
return body, nil
}
reqMap["metadata"] = newMetadataRaw
return json.Marshal(reqMap) return json.Marshal(reqMap)
} }
...@@ -215,6 +230,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI ...@@ -215,6 +230,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装 // RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
// 如果账号启用了会话ID伪装(session_id_masking_enabled), // 如果账号启用了会话ID伪装(session_id_masking_enabled),
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变) // 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) { func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
// 先执行常规的 RewriteUserID 逻辑 // 先执行常规的 RewriteUserID 逻辑
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID) newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
...@@ -227,17 +245,23 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b ...@@ -227,17 +245,23 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
return newBody, nil return newBody, nil
} }
// 解析重写后的 body,提取 user_id // 使用 RawMessage 保留其他字段的原始字节
var reqMap map[string]any var reqMap map[string]json.RawMessage
if err := json.Unmarshal(newBody, &reqMap); err != nil { if err := json.Unmarshal(newBody, &reqMap); err != nil {
return newBody, nil return newBody, nil
} }
metadata, ok := reqMap["metadata"].(map[string]any) // 解析 metadata 字段
metadataRaw, ok := reqMap["metadata"]
if !ok { if !ok {
return newBody, nil return newBody, nil
} }
var metadata map[string]any
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
return newBody, nil
}
userID, ok := metadata["user_id"].(string) userID, ok := metadata["user_id"].(string)
if !ok || userID == "" { if !ok || userID == "" {
return newBody, nil return newBody, nil
...@@ -278,7 +302,13 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b ...@@ -278,7 +302,13 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
) )
metadata["user_id"] = newUserID metadata["user_id"] = newUserID
reqMap["metadata"] = metadata
// 只重新序列化 metadata 字段
newMetadataRaw, marshalErr := json.Marshal(metadata)
if marshalErr != nil {
return newBody, nil
}
reqMap["metadata"] = newMetadataRaw
return json.Marshal(reqMap) return json.Marshal(reqMap)
} }
......
...@@ -72,7 +72,7 @@ type opencodeCacheMetadata struct { ...@@ -72,7 +72,7 @@ type opencodeCacheMetadata struct {
LastChecked int64 `json:"lastChecked"` LastChecked int64 `json:"lastChecked"`
} }
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult {
result := codexTransformResult{} result := codexTransformResult{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。 // 工具续链需求会影响存储策略与 input 过滤逻辑。
needsToolContinuation := NeedsToolContinuation(reqBody) needsToolContinuation := NeedsToolContinuation(reqBody)
...@@ -118,22 +118,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { ...@@ -118,22 +118,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
result.PromptCacheKey = strings.TrimSpace(v) result.PromptCacheKey = strings.TrimSpace(v)
} }
instructions := strings.TrimSpace(getOpenCodeCodexHeader()) // instructions 处理逻辑:根据是否是 Codex CLI 分别调用不同方法
existingInstructions, _ := reqBody["instructions"].(string) if applyInstructions(reqBody, isCodexCLI) {
existingInstructions = strings.TrimSpace(existingInstructions) result.Modified = true
if instructions != "" {
if existingInstructions != instructions {
reqBody["instructions"] = instructions
result.Modified = true
}
} else if existingInstructions == "" {
// 未获取到 opencode 指令时,回退使用 Codex CLI 指令。
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
if codexInstructions != "" {
reqBody["instructions"] = codexInstructions
result.Modified = true
}
} }
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。 // 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
...@@ -276,6 +263,72 @@ func GetCodexCLIInstructions() string { ...@@ -276,6 +263,72 @@ func GetCodexCLIInstructions() string {
return getCodexCLIInstructions() return getCodexCLIInstructions()
} }
// applyInstructions 处理 instructions 字段
// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令)
// isCodexCLI=false: 优先使用 opencode 指令覆盖
func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
if isCodexCLI {
return applyCodexCLIInstructions(reqBody)
}
return applyOpenCodeInstructions(reqBody)
}
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
// 仅在 instructions 为空时添加 opencode 指令
func applyCodexCLIInstructions(reqBody map[string]any) bool {
if !isInstructionsEmpty(reqBody) {
return false // 已有有效 instructions,不修改
}
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
if instructions != "" {
reqBody["instructions"] = instructions
return true
}
return false
}
// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令
// 优先使用 opencode 指令覆盖
func applyOpenCodeInstructions(reqBody map[string]any) bool {
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
existingInstructions, _ := reqBody["instructions"].(string)
existingInstructions = strings.TrimSpace(existingInstructions)
if instructions != "" {
if existingInstructions != instructions {
reqBody["instructions"] = instructions
return true
}
} else if existingInstructions == "" {
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
if codexInstructions != "" {
reqBody["instructions"] = codexInstructions
return true
}
}
return false
}
// isInstructionsEmpty 检查 instructions 字段是否为空
// 处理以下情况:字段不存在、nil、空字符串、纯空白字符串
func isInstructionsEmpty(reqBody map[string]any) bool {
val, exists := reqBody["instructions"]
if !exists {
return true
}
if val == nil {
return true
}
str, ok := val.(string)
if !ok {
return true
}
return strings.TrimSpace(str) == ""
}
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。 // ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
func ReplaceWithCodexInstructions(reqBody map[string]any) bool { func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
codexInstructions := strings.TrimSpace(getCodexCLIInstructions()) codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment