Commit 7dddd065 authored by yangjianbo's avatar yangjianbo
Browse files
parents 25a0d49a e78c8646
...@@ -139,6 +139,14 @@ func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until tim ...@@ -139,6 +139,14 @@ func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until tim
panic("unexpected SetOverloaded call") panic("unexpected SetOverloaded call")
} }
func (s *accountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
panic("unexpected SetTempUnschedulable call")
}
func (s *accountRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error {
panic("unexpected ClearTempUnschedulable call")
}
func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error { func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic("unexpected ClearRateLimit call") panic("unexpected ClearRateLimit call")
} }
......
...@@ -398,7 +398,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account ...@@ -398,7 +398,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
} }
// For API Key accounts with model mapping, map the model // For API Key accounts with model mapping, map the model
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
mapping := account.GetModelMapping() mapping := account.GetModelMapping()
if len(mapping) > 0 { if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists { if mappedModel, exists := mapping[testModelID]; exists {
...@@ -422,7 +422,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account ...@@ -422,7 +422,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
var err error var err error
switch account.Type { switch account.Type {
case AccountTypeApiKey: case AccountTypeAPIKey:
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload) req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case AccountTypeOAuth: case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload) req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
......
...@@ -12,16 +12,18 @@ import ( ...@@ -12,16 +12,18 @@ import (
) )
type UsageLogRepository interface { type UsageLogRepository interface {
Create(ctx context.Context, log *UsageLog) error // Create creates a usage log and returns whether it was actually inserted.
// inserted is false when the insert was skipped due to conflict (idempotent retries).
Create(ctx context.Context, log *UsageLog) (inserted bool, err error)
GetByID(ctx context.Context, id int64) (*UsageLog, error) GetByID(ctx context.Context, id int64) (*UsageLog, error)
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
...@@ -32,10 +34,10 @@ type UsageLogRepository interface { ...@@ -32,10 +34,10 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
// User dashboard stats // User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
...@@ -51,7 +53,7 @@ type UsageLogRepository interface { ...@@ -51,7 +53,7 @@ type UsageLogRepository interface {
// Aggregated stats (optimized) // Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
...@@ -105,6 +107,8 @@ type UsageProgress struct { ...@@ -105,6 +107,8 @@ type UsageProgress struct {
ResetsAt *time.Time `json:"resets_at"` // 重置时间 ResetsAt *time.Time `json:"resets_at"` // 重置时间
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数 RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量) WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
UsedRequests int64 `json:"used_requests,omitempty"`
LimitRequests int64 `json:"limit_requests,omitempty"`
} }
// AntigravityModelQuota Antigravity 单个模型的配额信息 // AntigravityModelQuota Antigravity 单个模型的配额信息
...@@ -115,12 +119,16 @@ type AntigravityModelQuota struct { ...@@ -115,12 +119,16 @@ type AntigravityModelQuota struct {
// UsageInfo 账号使用量信息 // UsageInfo 账号使用量信息
type UsageInfo struct { type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间 UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口 FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口 SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口 SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额 GeminiSharedDaily *UsageProgress `json:"gemini_shared_daily,omitempty"` // Gemini shared pool RPD (Google One / Code Assist)
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额 GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
GeminiSharedMinute *UsageProgress `json:"gemini_shared_minute,omitempty"` // Gemini shared pool RPM (Google One / Code Assist)
GeminiProMinute *UsageProgress `json:"gemini_pro_minute,omitempty"` // Gemini Pro RPM
GeminiFlashMinute *UsageProgress `json:"gemini_flash_minute,omitempty"` // Gemini Flash RPM
// Antigravity 多模型配额 // Antigravity 多模型配额
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"` AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
...@@ -256,17 +264,44 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou ...@@ -256,17 +264,44 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
return usage, nil return usage, nil
} }
start := geminiDailyWindowStart(now) dayStart := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID) stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err) return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
} }
totals := geminiAggregateUsage(stats) dayTotals := geminiAggregateUsage(stats)
resetAt := geminiDailyResetTime(now) dailyResetAt := geminiDailyResetTime(now)
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now) // Daily window (RPD)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now) if quota.SharedRPD > 0 {
totalReq := dayTotals.ProRequests + dayTotals.FlashRequests
totalTokens := dayTotals.ProTokens + dayTotals.FlashTokens
totalCost := dayTotals.ProCost + dayTotals.FlashCost
usage.GeminiSharedDaily = buildGeminiUsageProgress(totalReq, quota.SharedRPD, dailyResetAt, totalTokens, totalCost, now)
} else {
usage.GeminiProDaily = buildGeminiUsageProgress(dayTotals.ProRequests, quota.ProRPD, dailyResetAt, dayTotals.ProTokens, dayTotals.ProCost, now)
usage.GeminiFlashDaily = buildGeminiUsageProgress(dayTotals.FlashRequests, quota.FlashRPD, dailyResetAt, dayTotals.FlashTokens, dayTotals.FlashCost, now)
}
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID)
if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
}
minuteTotals := geminiAggregateUsage(minuteStats)
if quota.SharedRPM > 0 {
totalReq := minuteTotals.ProRequests + minuteTotals.FlashRequests
totalTokens := minuteTotals.ProTokens + minuteTotals.FlashTokens
totalCost := minuteTotals.ProCost + minuteTotals.FlashCost
usage.GeminiSharedMinute = buildGeminiUsageProgress(totalReq, quota.SharedRPM, minuteResetAt, totalTokens, totalCost, now)
} else {
usage.GeminiProMinute = buildGeminiUsageProgress(minuteTotals.ProRequests, quota.ProRPM, minuteResetAt, minuteTotals.ProTokens, minuteTotals.ProCost, now)
usage.GeminiFlashMinute = buildGeminiUsageProgress(minuteTotals.FlashRequests, quota.FlashRPM, minuteResetAt, minuteTotals.FlashTokens, minuteTotals.FlashCost, now)
}
return usage, nil return usage, nil
} }
...@@ -506,6 +541,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn ...@@ -506,6 +541,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
} }
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress { func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
// limit <= 0 means "no local quota window" (unknown or unlimited).
if limit <= 0 { if limit <= 0 {
return nil return nil
} }
...@@ -519,6 +555,8 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64 ...@@ -519,6 +555,8 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
Utilization: utilization, Utilization: utilization,
ResetsAt: &resetCopy, ResetsAt: &resetCopy,
RemainingSeconds: remainingSeconds, RemainingSeconds: remainingSeconds,
UsedRequests: used,
LimitRequests: limit,
WindowStats: &WindowStats{ WindowStats: &WindowStats{
Requests: used, Requests: used,
Tokens: tokens, Tokens: tokens,
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
...@@ -19,7 +20,7 @@ type AdminService interface { ...@@ -19,7 +20,7 @@ type AdminService interface {
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management // Group management
...@@ -30,7 +31,7 @@ type AdminService interface { ...@@ -30,7 +31,7 @@ type AdminService interface {
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
...@@ -65,7 +66,7 @@ type AdminService interface { ...@@ -65,7 +66,7 @@ type AdminService interface {
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
} }
// Input types for admin operations // CreateUserInput represents input for creating a new user via admin operations.
type CreateUserInput struct { type CreateUserInput struct {
Email string Email string
Password string Password string
...@@ -122,18 +123,22 @@ type CreateAccountInput struct { ...@@ -122,18 +123,22 @@ type CreateAccountInput struct {
Concurrency int Concurrency int
Priority int Priority int
GroupIDs []int64 GroupIDs []int64
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
} }
type UpdateAccountInput struct { type UpdateAccountInput struct {
Name string Name string
Type string // Account type: oauth, setup-token, apikey Type string // Account type: oauth, setup-token, apikey
Credentials map[string]any Credentials map[string]any
Extra map[string]any Extra map[string]any
ProxyID *int64 ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0" Priority *int // 使用指针区分"未提供"和"设置为0"
Status string Status string
GroupIDs *[]int64 GroupIDs *[]int64
SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险)
} }
// BulkUpdateAccountsInput describes the payload for bulk updating accounts. // BulkUpdateAccountsInput describes the payload for bulk updating accounts.
...@@ -147,6 +152,9 @@ type BulkUpdateAccountsInput struct { ...@@ -147,6 +152,9 @@ type BulkUpdateAccountsInput struct {
GroupIDs *[]int64 GroupIDs *[]int64
Credentials map[string]any Credentials map[string]any
Extra map[string]any Extra map[string]any
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
} }
// BulkUpdateAccountResult captures the result for a single account update. // BulkUpdateAccountResult captures the result for a single account update.
...@@ -220,7 +228,7 @@ type adminServiceImpl struct { ...@@ -220,7 +228,7 @@ type adminServiceImpl struct {
groupRepo GroupRepository groupRepo GroupRepository
accountRepo AccountRepository accountRepo AccountRepository
proxyRepo ProxyRepository proxyRepo ProxyRepository
apiKeyRepo ApiKeyRepository apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber proxyProber ProxyExitInfoProber
...@@ -232,7 +240,7 @@ func NewAdminService( ...@@ -232,7 +240,7 @@ func NewAdminService(
groupRepo GroupRepository, groupRepo GroupRepository,
accountRepo AccountRepository, accountRepo AccountRepository,
proxyRepo ProxyRepository, proxyRepo ProxyRepository,
apiKeyRepo ApiKeyRepository, apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository, redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber, proxyProber ProxyExitInfoProber,
...@@ -430,7 +438,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, ...@@ -430,7 +438,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil return user, nil
} }
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) { func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
...@@ -583,7 +591,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { ...@@ -583,7 +591,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
return nil return nil
} }
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) { func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil { if err != nil {
...@@ -620,6 +628,29 @@ func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([ ...@@ -620,6 +628,29 @@ func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([
} }
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) { func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
// 绑定分组
groupIDs := input.GroupIDs
// 如果没有指定分组,自动绑定对应平台的默认分组
if len(groupIDs) == 0 {
defaultGroupName := input.Platform + "-default"
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
if err == nil {
for _, g := range groups {
if g.Name == defaultGroupName {
groupIDs = []int64{g.ID}
break
}
}
}
}
// 检查混合渠道风险(除非用户已确认)
if len(groupIDs) > 0 && !input.SkipMixedChannelCheck {
if err := s.checkMixedChannelRisk(ctx, 0, input.Platform, groupIDs); err != nil {
return nil, err
}
}
account := &Account{ account := &Account{
Name: input.Name, Name: input.Name,
Platform: input.Platform, Platform: input.Platform,
...@@ -637,22 +668,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou ...@@ -637,22 +668,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
} }
// 绑定分组 // 绑定分组
groupIDs := input.GroupIDs
// 如果没有指定分组,自动绑定对应平台的默认分组
if len(groupIDs) == 0 {
defaultGroupName := input.Platform + "-default"
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
if err == nil {
for _, g := range groups {
if g.Name == defaultGroupName {
groupIDs = []int64{g.ID}
log.Printf("[CreateAccount] Auto-binding account %d to default group %s (ID: %d)", account.ID, defaultGroupName, g.ID)
break
}
}
}
}
if len(groupIDs) > 0 { if len(groupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
return nil, err return nil, err
...@@ -703,6 +718,13 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -703,6 +718,13 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
} }
// 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck {
if err := s.checkMixedChannelRisk(ctx, account.ID, account.Platform, *input.GroupIDs); err != nil {
return nil, err
}
}
} }
if err := s.accountRepo.Update(ctx, account); err != nil { if err := s.accountRepo.Update(ctx, account); err != nil {
...@@ -731,6 +753,20 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -731,6 +753,20 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
return result, nil return result, nil
} }
// Preload account platforms for mixed channel risk checks if group bindings are requested.
platformByID := map[int64]string{}
if input.GroupIDs != nil && !input.SkipMixedChannelCheck {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
return nil, err
}
for _, account := range accounts {
if account != nil {
platformByID[account.ID] = account.Platform
}
}
}
// Prepare bulk updates for columns and JSONB fields. // Prepare bulk updates for columns and JSONB fields.
repoUpdates := AccountBulkUpdate{ repoUpdates := AccountBulkUpdate{
Credentials: input.Credentials, Credentials: input.Credentials,
...@@ -762,6 +798,29 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -762,6 +798,29 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry := BulkUpdateAccountResult{AccountID: accountID} entry := BulkUpdateAccountResult{AccountID: accountID}
if input.GroupIDs != nil { if input.GroupIDs != nil {
// 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck {
platform := platformByID[accountID]
if platform == "" {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
platform = account.Platform
}
if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
}
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil { if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
entry.Success = false entry.Success = false
entry.Error = err.Error() entry.Error = err.Error()
...@@ -1006,3 +1065,77 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR ...@@ -1006,3 +1065,77 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
Country: exitInfo.Country, Country: exitInfo.Country,
}, nil }, nil
} }
// checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic)
// 如果存在混合,返回错误提示用户确认
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
// 判断当前账号的渠道类型(基于 platform 字段,而不是 type 字段)
currentPlatform := getAccountPlatform(currentAccountPlatform)
if currentPlatform == "" {
// 不是 Antigravity 或 Anthropic,无需检查
return nil
}
// 检查每个分组中的其他账号
for _, groupID := range groupIDs {
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil {
return fmt.Errorf("get accounts in group %d: %w", groupID, err)
}
// 检查是否存在不同渠道的账号
for _, account := range accounts {
if currentAccountID > 0 && account.ID == currentAccountID {
continue // 跳过当前账号
}
otherPlatform := getAccountPlatform(account.Platform)
if otherPlatform == "" {
continue // 不是 Antigravity 或 Anthropic,跳过
}
// 检测混合渠道
if currentPlatform != otherPlatform {
group, _ := s.groupRepo.GetByID(ctx, groupID)
groupName := fmt.Sprintf("Group %d", groupID)
if group != nil {
groupName = group.Name
}
return &MixedChannelError{
GroupID: groupID,
GroupName: groupName,
CurrentPlatform: currentPlatform,
OtherPlatform: otherPlatform,
}
}
}
}
return nil
}
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
func getAccountPlatform(accountPlatform string) string {
switch strings.ToLower(strings.TrimSpace(accountPlatform)) {
case PlatformAntigravity:
return "Antigravity"
case PlatformAnthropic, "claude":
return "Anthropic"
default:
return ""
}
}
// MixedChannelError 混合渠道错误
type MixedChannelError struct {
GroupID int64
GroupName string
CurrentPlatform string
OtherPlatform string
}
func (e *MixedChannelError) Error() string {
return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.",
e.GroupName, e.CurrentPlatform, e.OtherPlatform)
}
...@@ -14,7 +14,6 @@ import ( ...@@ -14,7 +14,6 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
...@@ -83,7 +82,7 @@ type AntigravityGatewayService struct { ...@@ -83,7 +82,7 @@ type AntigravityGatewayService struct {
tokenProvider *AntigravityTokenProvider tokenProvider *AntigravityTokenProvider
rateLimitService *RateLimitService rateLimitService *RateLimitService
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
cfg *config.Config settingService *SettingService
} }
func NewAntigravityGatewayService( func NewAntigravityGatewayService(
...@@ -92,14 +91,14 @@ func NewAntigravityGatewayService( ...@@ -92,14 +91,14 @@ func NewAntigravityGatewayService(
tokenProvider *AntigravityTokenProvider, tokenProvider *AntigravityTokenProvider,
rateLimitService *RateLimitService, rateLimitService *RateLimitService,
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
cfg *config.Config, settingService *SettingService,
) *AntigravityGatewayService { ) *AntigravityGatewayService {
return &AntigravityGatewayService{ return &AntigravityGatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
tokenProvider: tokenProvider, tokenProvider: tokenProvider,
rateLimitService: rateLimitService, rateLimitService: rateLimitService,
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
cfg: cfg, settingService: settingService,
} }
} }
...@@ -329,6 +328,22 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt ...@@ -329,6 +328,22 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt
return body, nil return body, nil
} }
// isModelNotFoundError 检测是否为模型不存在的 404 错误
func isModelNotFoundError(statusCode int, body []byte) bool {
if statusCode != 404 {
return false
}
bodyStr := strings.ToLower(string(body))
keywords := []string{"model not found", "unknown model", "not found"}
for _, keyword := range keywords {
if strings.Contains(bodyStr, keyword) {
return true
}
}
return true // 404 without specific message also treated as model not found
}
// Forward 转发 Claude 协议请求(Claude → Gemini 转换) // Forward 转发 Claude 协议请求(Claude → Gemini 转换)
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
...@@ -422,16 +437,56 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -422,16 +437,56 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
// 处理错误响应
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverUpstreamError(resp.StatusCode) { // 优先检测 thinking block 的 signature 相关错误(400)并重试一次:
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} // Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
retryClaudeReq := claudeReq
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq)
if stripErr == nil && stripped {
log.Printf("Antigravity account %d: detected signature-related 400, retrying once without thinking blocks", account.ID)
retryGeminiBody, txErr := antigravity.TransformClaudeToGemini(&retryClaudeReq, projectID, mappedModel)
if txErr == nil {
retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// Retry success: continue normal success flow with the new response.
if retryResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = retryResp
respBody = nil
} else {
// Retry still errored: replace error context with retry response.
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
respBody = retryBody
resp = retryResp
}
} else {
log.Printf("Antigravity account %d: signature retry request failed: %v", account.ID, retryErr)
}
}
}
}
} }
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody) // 处理错误响应(重试后仍失败或不触发重试)
if resp.StatusCode >= 400 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
}
} }
requestID := resp.Header.Get("x-request-id") requestID := resp.Header.Get("x-request-id")
...@@ -466,6 +521,122 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -466,6 +521,122 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}, nil }, nil
} }
func isSignatureRelatedError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if msg == "" {
// Fallback: best-effort scan of the raw payload.
msg = strings.ToLower(string(respBody))
}
// Keep this intentionally broad: different upstreams may use "signature" or "thought_signature".
return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature")
}
func extractAntigravityErrorMessage(body []byte) string {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return ""
}
// Google-style: {"error": {"message": "..."}}
if errObj, ok := payload["error"].(map[string]any); ok {
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
return msg
}
}
// Fallback: top-level message
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
return msg
}
return ""
}
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
// This preserves the thinking content while avoiding signature validation errors.
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
// It also disables top-level `thinking` to prevent dummy-thought injection during retry.
func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
if req == nil {
return false, nil
}
changed := false
if req.Thinking != nil {
req.Thinking = nil
changed = true
}
for i := range req.Messages {
raw := req.Messages[i].Content
if len(raw) == 0 {
continue
}
// If content is a string, nothing to strip.
var str string
if json.Unmarshal(raw, &str) == nil {
continue
}
// Otherwise treat as an array of blocks and convert thinking blocks to text.
var blocks []map[string]any
if err := json.Unmarshal(raw, &blocks); err != nil {
continue
}
filtered := make([]map[string]any, 0, len(blocks))
modifiedAny := false
for _, block := range blocks {
t, _ := block["type"].(string)
switch t {
case "thinking":
// Convert thinking to text, skip if empty
thinkingText, _ := block["thinking"].(string)
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
case "redacted_thinking":
// Remove redacted_thinking (cannot convert encrypted content)
modifiedAny = true
case "":
// Handle untyped block with "thinking" field
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
} else {
filtered = append(filtered, block)
}
default:
filtered = append(filtered, block)
}
}
if !modifiedAny {
continue
}
newRaw, err := json.Marshal(filtered)
if err != nil {
return changed, err
}
req.Messages[i].Content = newRaw
changed = true
}
return changed, nil
}
// ForwardGemini 转发 Gemini 协议请求 // ForwardGemini 转发 Gemini 协议请求
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) { func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
...@@ -579,14 +750,40 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -579,14 +750,40 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
// 处理错误响应 // 处理错误响应
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
isModelNotFoundError(resp.StatusCode, respBody) {
fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity)
if fallbackModel != "" && fallbackModel != mappedModel {
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
// 关闭原始响应,释放连接(respBody 已读取到内存)
_ = resp.Body.Close()
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, body)
if err == nil {
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
if err == nil {
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
if err == nil && fallbackResp.StatusCode < 400 {
resp = fallbackResp
} else if fallbackResp != nil {
_ = fallbackResp.Body.Close()
}
}
}
}
}
// fallback 成功:继续按正常响应处理
if resp.StatusCode < 400 {
goto handleSuccess
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverUpstreamError(resp.StatusCode) { if s.shouldFailoverUpstreamError(resp.StatusCode) {
...@@ -594,6 +791,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -594,6 +791,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
} }
// 解包并返回错误 // 解包并返回错误
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
unwrapped, _ := s.unwrapV1InternalResponse(respBody) unwrapped, _ := s.unwrapV1InternalResponse(respBody)
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
if contentType == "" { if contentType == "" {
...@@ -603,6 +804,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -603,6 +804,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
} }
handleSuccess:
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
var usage *ClaudeUsage var usage *ClaudeUsage
var firstTokenMs *int var firstTokenMs *int
...@@ -713,8 +920,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context ...@@ -713,8 +920,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
} }
scanner.Buffer(make([]byte, 64*1024), maxLineSize) scanner.Buffer(make([]byte, 64*1024), maxLineSize)
usage := &ClaudeUsage{} usage := &ClaudeUsage{}
...@@ -753,8 +960,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context ...@@ -753,8 +960,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
// 上游数据间隔超时保护(防止上游挂起长期占用连接) // 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval := time.Duration(0) streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
} }
var intervalTicker *time.Ticker var intervalTicker *time.Ticker
if streamInterval > 0 { if streamInterval > 0 {
...@@ -990,8 +1197,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context ...@@ -990,8 +1197,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
} }
scanner.Buffer(make([]byte, 64*1024), maxLineSize) scanner.Buffer(make([]byte, 64*1024), maxLineSize)
...@@ -1040,8 +1247,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context ...@@ -1040,8 +1247,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
defer close(done) defer close(done)
streamInterval := time.Duration(0) streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
} }
var intervalTicker *time.Ticker var intervalTicker *time.Ticker
if streamInterval > 0 { if streamInterval > 0 {
......
...@@ -2,7 +2,7 @@ package service ...@@ -2,7 +2,7 @@ package service
import "time" import "time"
type ApiKey struct { type APIKey struct {
ID int64 ID int64
UserID int64 UserID int64
Key string Key string
...@@ -15,6 +15,6 @@ type ApiKey struct { ...@@ -15,6 +15,6 @@ type ApiKey struct {
Group *Group Group *Group
} }
func (k *ApiKey) IsActive() bool { func (k *APIKey) IsActive() bool {
return k.Status == StatusActive return k.Status == StatusActive
} }
...@@ -14,39 +14,39 @@ import ( ...@@ -14,39 +14,39 @@ import (
) )
var ( var (
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found") ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group") ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists") ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters") ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens") ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later") ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
) )
const ( const (
apiKeyMaxErrorsPerHour = 20 apiKeyMaxErrorsPerHour = 20
) )
type ApiKeyRepository interface { type APIKeyRepository interface {
Create(ctx context.Context, key *ApiKey) error Create(ctx context.Context, key *APIKey) error
GetByID(ctx context.Context, id int64) (*ApiKey, error) GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证 // GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error) GetOwnerID(ctx context.Context, id int64) (int64, error)
GetByKey(ctx context.Context, key string) (*ApiKey, error) GetByKey(ctx context.Context, key string) (*APIKey, error)
Update(ctx context.Context, key *ApiKey) error Update(ctx context.Context, key *APIKey) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error) CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error) ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error) CountByGroupID(ctx context.Context, groupID int64) (int64, error)
} }
// ApiKeyCache defines cache operations for API key service // APIKeyCache defines cache operations for API key service
type ApiKeyCache interface { type APIKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error DeleteCreateAttemptCount(ctx context.Context, userID int64) error
...@@ -55,40 +55,40 @@ type ApiKeyCache interface { ...@@ -55,40 +55,40 @@ type ApiKeyCache interface {
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
} }
// CreateApiKeyRequest 创建API Key请求 // CreateAPIKeyRequest 创建API Key请求
type CreateApiKeyRequest struct { type CreateAPIKeyRequest struct {
Name string `json:"name"` Name string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key CustomKey *string `json:"custom_key"` // 可选的自定义key
} }
// UpdateApiKeyRequest 更新API Key请求 // UpdateAPIKeyRequest 更新API Key请求
type UpdateApiKeyRequest struct { type UpdateAPIKeyRequest struct {
Name *string `json:"name"` Name *string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
Status *string `json:"status"` Status *string `json:"status"`
} }
// ApiKeyService API Key服务 // APIKeyService API Key服务
type ApiKeyService struct { type APIKeyService struct {
apiKeyRepo ApiKeyRepository apiKeyRepo APIKeyRepository
userRepo UserRepository userRepo UserRepository
groupRepo GroupRepository groupRepo GroupRepository
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
cache ApiKeyCache cache APIKeyCache
cfg *config.Config cfg *config.Config
} }
// NewApiKeyService 创建API Key服务实例 // NewAPIKeyService 创建API Key服务实例
func NewApiKeyService( func NewAPIKeyService(
apiKeyRepo ApiKeyRepository, apiKeyRepo APIKeyRepository,
userRepo UserRepository, userRepo UserRepository,
groupRepo GroupRepository, groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
cache ApiKeyCache, cache APIKeyCache,
cfg *config.Config, cfg *config.Config,
) *ApiKeyService { ) *APIKeyService {
return &ApiKeyService{ return &APIKeyService{
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
userRepo: userRepo, userRepo: userRepo,
groupRepo: groupRepo, groupRepo: groupRepo,
...@@ -99,7 +99,7 @@ func NewApiKeyService( ...@@ -99,7 +99,7 @@ func NewApiKeyService(
} }
// GenerateKey 生成随机API Key // GenerateKey 生成随机API Key
func (s *ApiKeyService) GenerateKey() (string, error) { func (s *APIKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据 // 生成32字节随机数据
bytes := make([]byte, 32) bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil { if _, err := rand.Read(bytes); err != nil {
...@@ -107,7 +107,7 @@ func (s *ApiKeyService) GenerateKey() (string, error) { ...@@ -107,7 +107,7 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
} }
// 转换为十六进制字符串并添加前缀 // 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.ApiKeyPrefix prefix := s.cfg.Default.APIKeyPrefix
if prefix == "" { if prefix == "" {
prefix = "sk-" prefix = "sk-"
} }
...@@ -117,10 +117,10 @@ func (s *ApiKeyService) GenerateKey() (string, error) { ...@@ -117,10 +117,10 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
} }
// ValidateCustomKey 验证自定义API Key格式 // ValidateCustomKey 验证自定义API Key格式
func (s *ApiKeyService) ValidateCustomKey(key string) error { func (s *APIKeyService) ValidateCustomKey(key string) error {
// 检查长度 // 检查长度
if len(key) < 16 { if len(key) < 16 {
return ErrApiKeyTooShort return ErrAPIKeyTooShort
} }
// 检查字符:只允许字母、数字、下划线、连字符 // 检查字符:只允许字母、数字、下划线、连字符
...@@ -131,14 +131,14 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error { ...@@ -131,14 +131,14 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
c == '_' || c == '-' { c == '_' || c == '-' {
continue continue
} }
return ErrApiKeyInvalidChars return ErrAPIKeyInvalidChars
} }
return nil return nil
} }
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限 // checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error { func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil { if s.cache == nil {
return nil return nil
} }
...@@ -150,14 +150,14 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) ...@@ -150,14 +150,14 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
} }
if count >= apiKeyMaxErrorsPerHour { if count >= apiKeyMaxErrorsPerHour {
return ErrApiKeyRateLimited return ErrAPIKeyRateLimited
} }
return nil return nil
} }
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数 // incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) { func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
if s.cache == nil { if s.cache == nil {
return return
} }
...@@ -168,7 +168,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in ...@@ -168,7 +168,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
// canUserBindGroup 检查用户是否可以绑定指定分组 // canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅 // 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑 // 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool { func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅 // 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID) _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
...@@ -179,7 +179,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group ...@@ -179,7 +179,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group
} }
// Create 创建API Key // Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) { func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
// 验证用户存在 // 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
...@@ -204,7 +204,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -204,7 +204,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
// 判断是否使用自定义Key // 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" { if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流(仅对自定义key进行限流) // 检查限流(仅对自定义key进行限流)
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil { if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
return nil, err return nil, err
} }
...@@ -220,8 +220,8 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -220,8 +220,8 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
} }
if exists { if exists {
// Key已存在,增加错误计数 // Key已存在,增加错误计数
s.incrementApiKeyErrorCount(ctx, userID) s.incrementAPIKeyErrorCount(ctx, userID)
return nil, ErrApiKeyExists return nil, ErrAPIKeyExists
} }
key = *req.CustomKey key = *req.CustomKey
...@@ -235,7 +235,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -235,7 +235,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
} }
// 创建API Key记录 // 创建API Key记录
apiKey := &ApiKey{ apiKey := &APIKey{
UserID: userID, UserID: userID,
Key: key, Key: key,
Name: req.Name, Name: req.Name,
...@@ -251,7 +251,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -251,7 +251,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
} }
// List 获取用户的API Key列表 // List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) { func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err) return nil, nil, fmt.Errorf("list api keys: %w", err)
...@@ -259,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio ...@@ -259,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
return keys, pagination, nil return keys, pagination, nil
} }
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 { if len(apiKeyIDs) == 0 {
return []int64{}, nil return []int64{}, nil
} }
...@@ -272,7 +272,7 @@ func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe ...@@ -272,7 +272,7 @@ func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
} }
// GetByID 根据ID获取API Key // GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) { func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
...@@ -281,7 +281,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) ...@@ -281,7 +281,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error)
} }
// GetByKey 根据Key字符串获取API Key(用于认证) // GetByKey 根据Key字符串获取API Key(用于认证)
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) { func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
// 尝试从Redis缓存获取 // 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key) cacheKey := fmt.Sprintf("apikey:%s", key)
...@@ -301,7 +301,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, erro ...@@ -301,7 +301,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, erro
} }
// Update 更新API Key // Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) { func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
...@@ -353,8 +353,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req ...@@ -353,8 +353,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
// Delete 删除API Key // Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证, // 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 ApiKey 对象及其关联数据(User、Group),提升删除操作的性能 // 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error { func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象 // 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id) ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
if err != nil { if err != nil {
...@@ -379,7 +379,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro ...@@ -379,7 +379,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
} }
// ValidateKey 验证API Key是否有效(用于认证中间件) // ValidateKey 验证API Key是否有效(用于认证中间件)
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) { func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
// 获取API Key // 获取API Key
apiKey, err := s.GetByKey(ctx, key) apiKey, err := s.GetByKey(ctx, key)
if err != nil { if err != nil {
...@@ -406,7 +406,7 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, * ...@@ -406,7 +406,7 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *
} }
// IncrementUsage 增加API Key使用次数(可选:用于统计) // IncrementUsage 增加API Key使用次数(可选:用于统计)
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error { func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器 // 使用Redis计数器
if s.cache != nil { if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02")) cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
...@@ -423,7 +423,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error { ...@@ -423,7 +423,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 返回用户可以选择的分组: // 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的 // - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的 // - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) { func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息 // 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
...@@ -460,7 +460,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([ ...@@ -460,7 +460,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
} }
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据) // canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool { func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅 // 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID] return subscribedGroupIDs[group.ID]
...@@ -469,8 +469,8 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subsc ...@@ -469,8 +469,8 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
return user.CanBindGroup(group.ID, group.IsExclusive) return user.CanBindGroup(group.ID, group.IsExclusive)
} }
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) { func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit) keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
if err != nil { if err != nil {
return nil, fmt.Errorf("search api keys: %w", err) return nil, fmt.Errorf("search api keys: %w", err)
} }
......
//go:build unit //go:build unit
// API Key 服务删除方法的单元测试 // API Key 服务删除方法的单元测试
// 测试 ApiKeyService.Delete 方法在各种场景下的行为, // 测试 APIKeyService.Delete 方法在各种场景下的行为,
// 包括权限验证、缓存清理和错误处理 // 包括权限验证、缓存清理和错误处理
package service package service
...@@ -16,12 +16,12 @@ import ( ...@@ -16,12 +16,12 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// apiKeyRepoStub 是 ApiKeyRepository 接口的测试桩实现。 // apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
// 用于隔离测试 ApiKeyService.Delete 方法,避免依赖真实数据库。 // 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
// //
// 设计说明: // 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID // - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrApiKeyNotFound) // - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound)
// - deleteErr: 模拟 Delete 返回的错误 // - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证 // - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
type apiKeyRepoStub struct { type apiKeyRepoStub struct {
...@@ -33,11 +33,11 @@ type apiKeyRepoStub struct { ...@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题 // 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error { func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
panic("unexpected Create call") panic("unexpected Create call")
} }
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) { func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
panic("unexpected GetByID call") panic("unexpected GetByID call")
} }
...@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error ...@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
return s.ownerID, s.ownerErr return s.ownerID, s.ownerErr
} }
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) { func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call") panic("unexpected GetByKey call")
} }
func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error { func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call") panic("unexpected Update call")
} }
...@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error { ...@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
// 以下是接口要求实现但本测试不关心的方法 // 以下是接口要求实现但本测试不关心的方法
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) { func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call") panic("unexpected ListByUserID call")
} }
...@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err ...@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
panic("unexpected ExistsByKey call") panic("unexpected ExistsByKey call")
} }
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) { func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call") panic("unexpected ListByGroupID call")
} }
func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) { func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchApiKeys call") panic("unexpected SearchAPIKeys call")
} }
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
...@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int ...@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call") panic("unexpected CountByGroupID call")
} }
// apiKeyCacheStub 是 ApiKeyCache 接口的测试桩实现。 // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。
// //
// 设计说明: // 设计说明:
...@@ -142,7 +142,7 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string ...@@ -142,7 +142,7 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) { func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1} repo := &apiKeyRepoStub{ownerID: 1}
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache} svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2 err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
require.ErrorIs(t, err, ErrInsufficientPerms) require.ErrorIs(t, err, ErrInsufficientPerms)
...@@ -160,7 +160,7 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) { ...@@ -160,7 +160,7 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
func TestApiKeyService_Delete_Success(t *testing.T) { func TestApiKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7} repo := &apiKeyRepoStub{ownerID: 7}
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache} svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7 err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
require.NoError(t, err) require.NoError(t, err)
...@@ -170,17 +170,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) { ...@@ -170,17 +170,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。 // TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为: // 预期行为:
// - GetOwnerID 返回 ErrApiKeyNotFound 错误 // - GetOwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrApiKeyNotFound 错误(被 fmt.Errorf 包装) // - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用 // - Delete 方法不被调用
// - 缓存不被清除 // - 缓存不被清除
func TestApiKeyService_Delete_NotFound(t *testing.T) { func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound} repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache} svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 99, 1) err := svc.Delete(context.Background(), 99, 1)
require.ErrorIs(t, err, ErrApiKeyNotFound) require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Empty(t, repo.deletedIDs) require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated) require.Empty(t, cache.invalidated)
} }
...@@ -198,7 +198,7 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) { ...@@ -198,7 +198,7 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
deleteErr: errors.New("delete failed"), deleteErr: errors.New("delete failed"),
} }
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache} svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3 err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
require.Error(t, err) require.Error(t, err)
......
...@@ -448,7 +448,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID ...@@ -448,7 +448,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// CheckBillingEligibility 检查用户是否有资格发起请求 // CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0 // 余额模式:检查缓存余额 > 0
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入) // 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error { func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
// 简易模式:跳过所有计费检查 // 简易模式:跳过所有计费检查
if s.cfg.RunMode == config.RunModeSimple { if s.cfg.RunMode == config.RunModeSimple {
return nil return nil
......
...@@ -439,7 +439,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -439,7 +439,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
Type: AccountTypeApiKey, Type: AccountTypeAPIKey,
Credentials: credentials, Credentials: credentials,
Extra: extra, Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
...@@ -464,7 +464,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -464,7 +464,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformAnthropic existing.Platform = PlatformAnthropic
existing.Type = AccountTypeApiKey existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
...@@ -683,7 +683,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -683,7 +683,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI, Platform: PlatformOpenAI,
Type: AccountTypeApiKey, Type: AccountTypeAPIKey,
Credentials: credentials, Credentials: credentials,
Extra: extra, Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
...@@ -708,7 +708,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -708,7 +708,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformOpenAI existing.Platform = PlatformOpenAI
existing.Type = AccountTypeApiKey existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
...@@ -902,7 +902,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -902,7 +902,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini, Platform: PlatformGemini,
Type: AccountTypeApiKey, Type: AccountTypeAPIKey,
Credentials: credentials, Credentials: credentials,
Extra: extra, Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
...@@ -927,7 +927,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -927,7 +927,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformGemini existing.Platform = PlatformGemini
existing.Type = AccountTypeApiKey existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
......
...@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi ...@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil return stats, nil
} }
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) { func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit) trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key usage trend: %w", err) return nil, fmt.Errorf("get api key usage trend: %w", err)
} }
...@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [ ...@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [
return stats, nil return stats, nil
} }
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) { func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs) stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err) return nil, fmt.Errorf("get batch api key usage stats: %w", err)
} }
......
...@@ -28,7 +28,7 @@ const ( ...@@ -28,7 +28,7 @@ const (
const ( const (
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
AccountTypeApiKey = "apikey" // API Key类型账号 AccountTypeAPIKey = "apikey" // API Key类型账号
) )
// Redeem type constants // Redeem type constants
...@@ -64,13 +64,13 @@ const ( ...@@ -64,13 +64,13 @@ const (
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置 // 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口 SettingKeySMTPPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名 SettingKeySMTPUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储) SettingKeySMTPPassword = "smtp_password" // SMTP密码(加密存储)
SettingKeySmtpFrom = "smtp_from" // 发件人地址 SettingKeySMTPFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称 SettingKeySMTPFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS SettingKeySMTPUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置 // Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证 SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
...@@ -81,20 +81,27 @@ const ( ...@@ -81,20 +81,27 @@ const (
SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64) SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入) SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
SettingKeyContactInfo = "contact_info" // 客服联系方式 SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocUrl = "doc_url" // 文档链接 SettingKeyDocURL = "doc_url" // 文档链接
// 默认配置 // 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key // 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
// Gemini 配额策略(JSON) // Gemini 配额策略(JSON)
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy" SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
// Model fallback settings
SettingKeyEnableModelFallback = "enable_model_fallback"
SettingKeyFallbackModelAnthropic = "fallback_model_anthropic"
SettingKeyFallbackModelOpenAI = "fallback_model_openai"
SettingKeyFallbackModelGemini = "fallback_model_gemini"
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
) )
// Admin API Key prefix (distinct from user "sk-" keys) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
const AdminApiKeyPrefix = "admin-" const AdminAPIKeyPrefix = "admin-"
...@@ -40,8 +40,8 @@ const ( ...@@ -40,8 +40,8 @@ const (
maxVerifyCodeAttempts = 5 maxVerifyCodeAttempts = 5
) )
// SmtpConfig SMTP配置 // SMTPConfig SMTP配置
type SmtpConfig struct { type SMTPConfig struct {
Host string Host string
Port int Port int
Username string Username string
...@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ ...@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
} }
} }
// GetSmtpConfig 从数据库获取SMTP配置 // GetSMTPConfig 从数据库获取SMTP配置
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) { func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
keys := []string{ keys := []string{
SettingKeySmtpHost, SettingKeySMTPHost,
SettingKeySmtpPort, SettingKeySMTPPort,
SettingKeySmtpUsername, SettingKeySMTPUsername,
SettingKeySmtpPassword, SettingKeySMTPPassword,
SettingKeySmtpFrom, SettingKeySMTPFrom,
SettingKeySmtpFromName, SettingKeySMTPFromName,
SettingKeySmtpUseTLS, SettingKeySMTPUseTLS,
} }
settings, err := s.settingRepo.GetMultiple(ctx, keys) settings, err := s.settingRepo.GetMultiple(ctx, keys)
...@@ -82,34 +82,34 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) { ...@@ -82,34 +82,34 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err) return nil, fmt.Errorf("get smtp settings: %w", err)
} }
host := settings[SettingKeySmtpHost] host := settings[SettingKeySMTPHost]
if host == "" { if host == "" {
return nil, ErrEmailNotConfigured return nil, ErrEmailNotConfigured
} }
port := 587 // 默认端口 port := 587 // 默认端口
if portStr := settings[SettingKeySmtpPort]; portStr != "" { if portStr := settings[SettingKeySMTPPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil { if p, err := strconv.Atoi(portStr); err == nil {
port = p port = p
} }
} }
useTLS := settings[SettingKeySmtpUseTLS] == "true" useTLS := settings[SettingKeySMTPUseTLS] == "true"
return &SmtpConfig{ return &SMTPConfig{
Host: host, Host: host,
Port: port, Port: port,
Username: settings[SettingKeySmtpUsername], Username: settings[SettingKeySMTPUsername],
Password: settings[SettingKeySmtpPassword], Password: settings[SettingKeySMTPPassword],
From: settings[SettingKeySmtpFrom], From: settings[SettingKeySMTPFrom],
FromName: settings[SettingKeySmtpFromName], FromName: settings[SettingKeySMTPFromName],
UseTLS: useTLS, UseTLS: useTLS,
}, nil }, nil
} }
// SendEmail 发送邮件(使用数据库中保存的配置) // SendEmail 发送邮件(使用数据库中保存的配置)
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error { func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
config, err := s.GetSmtpConfig(ctx) config, err := s.GetSMTPConfig(ctx)
if err != nil { if err != nil {
return err return err
} }
...@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) ...@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
} }
// SendEmailWithConfig 使用指定配置发送邮件 // SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error { func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
from := config.From from := config.From
if config.FromName != "" { if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From) from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
...@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string { ...@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
`, siteName, code) `, siteName, code)
} }
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接 // TestSMTPConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error { func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
addr := fmt.Sprintf("%s:%d", config.Host, config.Port) addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
if config.UseTLS { if config.UseTLS {
......
...@@ -136,6 +136,12 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6 ...@@ -136,6 +136,12 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil return nil
} }
func (m *mockAccountRepoForPlatform) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error { func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
return nil return nil
} }
...@@ -276,7 +282,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference( ...@@ -276,7 +282,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(
repo := &mockAccountRepoForPlatform{ repo := &mockAccountRepoForPlatform{
accounts: []Account{ accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey}, {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
}, },
accountsByID: map[int64]*Account{}, accountsByID: map[int64]*Account{},
...@@ -617,7 +623,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { ...@@ -617,7 +623,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) { t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{ repo := &mockAccountRepoForPlatform{
accounts: []Account{ accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey}, {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
}, },
accountsByID: map[int64]*Account{}, accountsByID: map[int64]*Account{},
......
package service package service
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
) )
...@@ -70,3 +71,224 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { ...@@ -70,3 +71,224 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
return parsed, nil return parsed, nil
} }
// FilterThinkingBlocks removes thinking blocks from request body
// Returns filtered body or original body if filtering fails (fail-safe)
// This prevents 400 errors from invalid thinking block signatures
//
// Strategy:
// - When thinking.type != "enabled": Remove all thinking blocks
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
// (blocks with missing/empty/dummy signatures that would cause 400 errors)
func FilterThinkingBlocks(body []byte) []byte {
return filterThinkingBlocksInternal(body, false)
}
// FilterThinkingBlocksForRetry removes thinking blocks from HISTORICAL messages for retry scenarios.
// This is used when upstream returns signature-related 400 errors.
//
// Key insight:
// - User's thinking.type = "enabled" should be PRESERVED (user's intent)
// - Only HISTORICAL assistant messages have thinking blocks with signatures
// - These signatures may be invalid when switching accounts/platforms
// - New responses will generate fresh thinking blocks without signature issues
//
// Strategy:
// - Keep thinking.type = "enabled" (preserve user intent)
// - Remove thinking/redacted_thinking blocks from historical assistant messages
// - Ensure no message has empty content after filtering
func FilterThinkingBlocksForRetry(body []byte) []byte {
// Fast path: check for presence of thinking-related keys in messages
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) {
return body
}
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body
}
// DO NOT modify thinking.type - preserve user's intent to use thinking mode
// The issue is with historical message signatures, not the thinking mode itself
messages, ok := req["messages"].([]any)
if !ok {
return body
}
modified := false
newMessages := make([]any, 0, len(messages))
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
newMessages = append(newMessages, msg)
continue
}
role, _ := msgMap["role"].(string)
content, ok := msgMap["content"].([]any)
if !ok {
// String content or other format - keep as is
newMessages = append(newMessages, msg)
continue
}
newContent := make([]any, 0, len(content))
modifiedThisMsg := false
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
newContent = append(newContent, block)
continue
}
blockType, _ := blockMap["type"].(string)
// Remove thinking/redacted_thinking blocks from historical messages
// These have signatures that may be invalid across different accounts
if blockType == "thinking" || blockType == "redacted_thinking" {
modifiedThisMsg = true
continue
}
newContent = append(newContent, block)
}
if modifiedThisMsg {
modified = true
// Handle empty content after filtering
if len(newContent) == 0 {
// For assistant messages, skip entirely (remove from conversation)
// For user messages, add placeholder to avoid empty content error
if role == "user" {
newContent = append(newContent, map[string]any{
"type": "text",
"text": "(content removed)",
})
msgMap["content"] = newContent
newMessages = append(newMessages, msgMap)
}
// Skip assistant messages with empty content (don't append)
continue
}
msgMap["content"] = newContent
}
newMessages = append(newMessages, msgMap)
}
if modified {
req["messages"] = newMessages
}
newBody, err := json.Marshal(req)
if err != nil {
return body
}
return newBody
}
// filterThinkingBlocksInternal removes invalid thinking blocks from request
// Strategy:
// - When thinking.type != "enabled": Remove all thinking blocks
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
// Fast path: if body doesn't contain "thinking", skip parsing
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"thinking":`)) &&
!bytes.Contains(body, []byte(`"thinking" :`)) {
return body
}
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body
}
// Check if thinking is enabled
thinkingEnabled := false
if thinking, ok := req["thinking"].(map[string]any); ok {
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
thinkingEnabled = true
}
}
messages, ok := req["messages"].([]any)
if !ok {
return body
}
filtered := false
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
role, _ := msgMap["role"].(string)
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
newContent := make([]any, 0, len(content))
filteredThisMessage := false
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
newContent = append(newContent, block)
continue
}
blockType, _ := blockMap["type"].(string)
if blockType == "thinking" || blockType == "redacted_thinking" {
// When thinking is enabled and this is an assistant message,
// only keep thinking blocks with valid signatures
if thinkingEnabled && role == "assistant" {
signature, _ := blockMap["signature"].(string)
if signature != "" && signature != "skip_thought_signature_validator" {
newContent = append(newContent, block)
continue
}
}
filtered = true
filteredThisMessage = true
continue
}
// Handle blocks without type discriminator but with "thinking" key
if blockType == "" {
if _, hasThinking := blockMap["thinking"]; hasThinking {
filtered = true
filteredThisMessage = true
continue
}
}
newContent = append(newContent, block)
}
if filteredThisMessage {
msgMap["content"] = newContent
}
}
if !filtered {
return body
}
newBody, err := json.Marshal(req)
if err != nil {
return body
}
return newBody
}
package service package service
import ( import (
"encoding/json"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -38,3 +39,115 @@ func TestParseGatewayRequest_InvalidStreamType(t *testing.T) { ...@@ -38,3 +39,115 @@ func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
_, err := ParseGatewayRequest(body) _, err := ParseGatewayRequest(body)
require.Error(t, err) require.Error(t, err)
} }
func TestFilterThinkingBlocks(t *testing.T) {
containsThinkingBlock := func(body []byte) bool {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return false
}
messages, ok := req["messages"].([]any)
if !ok {
return false
}
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
continue
}
blockType, _ := blockMap["type"].(string)
if blockType == "thinking" {
return true
}
if blockType == "" {
if _, hasThinking := blockMap["thinking"]; hasThinking {
return true
}
}
}
}
return false
}
tests := []struct {
name string
input string
shouldFilter bool
expectError bool
}{
{
name: "filters thinking blocks",
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"},{"type":"thinking","thinking":"internal","signature":"invalid"},{"type":"text","text":"World"}]}]}`,
shouldFilter: true,
},
{
name: "handles no thinking blocks",
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`,
shouldFilter: false,
},
{
name: "handles invalid JSON gracefully",
input: `{invalid json`,
shouldFilter: false,
expectError: true,
},
{
name: "handles multiple messages with thinking blocks",
input: `{"messages":[{"role":"user","content":[{"type":"text","text":"A"}]},{"role":"assistant","content":[{"type":"thinking","thinking":"think"},{"type":"text","text":"B"}]}]}`,
shouldFilter: true,
},
{
name: "filters thinking blocks without type discriminator",
input: `{"messages":[{"role":"assistant","content":[{"thinking":{"text":"internal"}},{"type":"text","text":"B"}]}]}`,
shouldFilter: true,
},
{
name: "does not filter tool_use input fields named thinking",
input: `{"messages":[{"role":"user","content":[{"type":"tool_use","id":"t1","name":"foo","input":{"thinking":"keepme","x":1}},{"type":"text","text":"Hello"}]}]}`,
shouldFilter: false,
},
{
name: "handles empty messages array",
input: `{"messages":[]}`,
shouldFilter: false,
},
{
name: "handles missing messages field",
input: `{"model":"claude-3"}`,
shouldFilter: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := FilterThinkingBlocks([]byte(tt.input))
if tt.expectError {
// For invalid JSON, should return original
require.Equal(t, tt.input, string(result))
return
}
if tt.shouldFilter {
require.False(t, containsThinkingBlock(result))
} else {
// Ensure we don't rewrite JSON when no filtering is needed.
require.Equal(t, tt.input, string(result))
}
// Verify valid JSON returned (unless input was invalid)
var parsed map[string]any
err := json.Unmarshal(result, &parsed)
require.NoError(t, err)
})
}
}
...@@ -547,7 +547,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -547,7 +547,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, item := range available { for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
if sessionHash != "" { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
} }
return &AccountSelectionResult{ return &AccountSelectionResult{
...@@ -583,7 +583,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates ...@@ -583,7 +583,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
for _, acc := range ordered { for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
if sessionHash != "" { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
} }
return &AccountSelectionResult{ return &AccountSelectionResult{
...@@ -714,7 +714,7 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { ...@@ -714,7 +714,7 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini preferOAuth := platform == PlatformGemini
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" { if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
...@@ -787,7 +787,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -787,7 +787,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
} }
// 4. 建立粘性绑定 // 4. 建立粘性绑定
if sessionHash != "" { if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
} }
...@@ -803,7 +803,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -803,7 +803,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
preferOAuth := nativePlatform == PlatformGemini preferOAuth := nativePlatform == PlatformGemini
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" { if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded { if _, excluded := excludedIDs[accountID]; !excluded {
...@@ -879,7 +879,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -879,7 +879,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
} }
// 4. 建立粘性绑定 // 4. 建立粘性绑定
if sessionHash != "" { if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
} }
...@@ -911,7 +911,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) ( ...@@ -911,7 +911,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
case AccountTypeOAuth, AccountTypeSetupToken: case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow // Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account) return s.getOAuthToken(ctx, account)
case AccountTypeApiKey: case AccountTypeAPIKey:
apiKey := account.GetCredential("api_key") apiKey := account.GetCredential("api_key")
if apiKey == "" { if apiKey == "" {
return "", "", errors.New("api_key not found in credentials") return "", "", errors.New("api_key not found in credentials")
...@@ -1049,7 +1049,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1049,7 +1049,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 应用模型映射(仅对apikey类型账号) // 应用模型映射(仅对apikey类型账号)
originalModel := reqModel originalModel := reqModel
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel) mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel { if mappedModel != reqModel {
// 替换请求体中的模型名 // 替换请求体中的模型名
...@@ -1086,8 +1086,45 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1086,8 +1086,45 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("upstream request failed: %w", err) return nil, fmt.Errorf("upstream request failed: %w", err)
} }
// 检查是否需要重试 // 优先检测thinking block签名错误(400)并重试一次
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { if resp.StatusCode == 400 {
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr == nil {
_ = resp.Body.Close()
if s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
// 过滤thinking blocks并重试(使用更激进的过滤)
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// 使用重试后的响应,继续后续处理
if retryResp.StatusCode < 400 {
log.Printf("Account %d: signature error retry succeeded", account.ID)
} else {
log.Printf("Account %d: signature error retry returned status %d", account.ID, retryResp.StatusCode)
}
resp = retryResp
break
}
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
} else {
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
}
// 重试失败,恢复原始响应体继续处理
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
// 不是thinking签名错误,恢复响应体
resp.Body = io.NopCloser(bytes.NewReader(respBody))
}
}
// 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了)
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetries { if attempt < maxRetries {
log.Printf("Account %d: upstream error %d, retry %d/%d after %v", log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay) account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
...@@ -1100,6 +1137,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1100,6 +1137,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
} }
// 不需要重试(成功或不可重试的错误),跳出循环 // 不需要重试(成功或不可重试的错误),跳出循环
// DEBUG: 输出响应 headers(用于检测 rate limit 信息)
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
log.Printf("[DEBUG] Gemini API Response Headers for account %d:", account.ID)
for k, v := range resp.Header {
log.Printf("[DEBUG] %s: %v", k, v)
}
}
break break
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
...@@ -1123,7 +1167,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1123,7 +1167,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover(默认关闭以保持语义) // 可选:对部分 400 触发 failover(默认关闭以保持语义)
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body) respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr != nil { if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream // ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account)
...@@ -1183,7 +1227,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1183,7 +1227,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL // 确定目标URL
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
if baseURL != "" { if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL) validatedURL, err := s.validateUpstreamBaseURL(baseURL)
...@@ -1253,10 +1297,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -1253,10 +1297,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta header(OAuth账号需要特殊处理) // 处理anthropic-beta header(OAuth账号需要特殊处理)
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) { if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" { if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta) req.Header.Set("anthropic-beta", beta)
} }
} }
...@@ -1323,12 +1367,12 @@ func requestNeedsBetaFeatures(body []byte) bool { ...@@ -1323,12 +1367,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false return false
} }
func defaultApiKeyBetaHeader(body []byte) string { func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String() modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") { if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader return claude.APIKeyHaikuBetaHeader
} }
return claude.ApiKeyBetaHeader return claude.APIKeyBetaHeader
} }
func truncateForLog(b []byte, maxBytes int) string { func truncateForLog(b []byte, maxBytes int) string {
...@@ -1345,6 +1389,41 @@ func truncateForLog(b []byte, maxBytes int) string { ...@@ -1345,6 +1389,41 @@ func truncateForLog(b []byte, maxBytes int) string {
return s return s
} }
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// Log for debugging
log.Printf("[SignatureCheck] Checking error message: %s", msg)
// 检测signature相关的错误(更宽松的匹配)
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
if strings.Contains(msg, "signature") {
log.Printf("[SignatureCheck] Detected signature error")
return true
}
// 检测 thinking block 顺序/类型错误
// 例如: "Expected `thinking` or `redacted_thinking`, but found `text`"
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
log.Printf("[SignatureCheck] Detected thinking block type error")
return true
}
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
// 例如: "all messages must have non-empty content"
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") {
log.Printf("[SignatureCheck] Detected empty content error")
return true
}
return false
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。 // 默认保守:无法识别则不切换。
...@@ -1393,7 +1472,13 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res ...@@ -1393,7 +1472,13 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
// 处理上游错误,标记账号状态 // 处理上游错误,标记账号状态
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息) // 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string var errType, errMsg string
...@@ -1783,7 +1868,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo ...@@ -1783,7 +1868,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数 // RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct { type RecordUsageInput struct {
Result *ForwardResult Result *ForwardResult
ApiKey *ApiKey APIKey *APIKey
User *User User *User
Account *Account Account *Account
Subscription *UserSubscription // 可选:订阅信息 Subscription *UserSubscription // 可选:订阅信息
...@@ -1792,7 +1877,7 @@ type RecordUsageInput struct { ...@@ -1792,7 +1877,7 @@ type RecordUsageInput struct {
// RecordUsage 记录使用量并扣费(或更新订阅用量) // RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result result := input.Result
apiKey := input.ApiKey apiKey := input.APIKey
user := input.User user := input.User
account := input.Account account := input.Account
subscription := input.Subscription subscription := input.Subscription
...@@ -1829,7 +1914,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -1829,7 +1914,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: result.RequestID, RequestID: result.RequestID,
Model: result.Model, Model: result.Model,
...@@ -1859,7 +1944,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -1859,7 +1944,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog.SubscriptionID = &subscription.ID usageLog.SubscriptionID = &subscription.ID
} }
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil { inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if err != nil {
log.Printf("Create usage log failed: %v", err) log.Printf("Create usage log failed: %v", err)
} }
...@@ -1869,10 +1955,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -1869,10 +1955,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
return nil return nil
} }
shouldBill := inserted || err != nil
// 根据计费类型执行扣费 // 根据计费类型执行扣费
if isSubscriptionBilling { if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if cost.TotalCost > 0 { if shouldBill && cost.TotalCost > 0 {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
log.Printf("Increment subscription usage failed: %v", err) log.Printf("Increment subscription usage failed: %v", err)
} }
...@@ -1881,7 +1969,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -1881,7 +1969,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} }
} else { } else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if cost.ActualCost > 0 { if shouldBill && cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
log.Printf("Deduct balance failed: %v", err) log.Printf("Deduct balance failed: %v", err)
} }
...@@ -1914,7 +2002,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -1914,7 +2002,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// 应用模型映射(仅对 apikey 类型账号) // 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
if reqModel != "" { if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel) mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel { if mappedModel != reqModel {
...@@ -1951,17 +2039,35 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -1951,17 +2039,35 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
return fmt.Errorf("upstream request failed: %w", err) return fmt.Errorf("upstream request failed: %w", err)
} }
defer func() {
_ = resp.Body.Close()
}()
// 读取响应体 // 读取响应体
respBody, err := io.ReadAll(resp.Body) respBody, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil { if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err return err
} }
// 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks)
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocks(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
resp = retryResp
respBody, err = io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
}
}
}
// 处理错误响应 // 处理错误响应
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
// 标记账号状态(429/529等) // 标记账号状态(429/529等)
...@@ -2000,7 +2106,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -2000,7 +2106,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL // 确定目标 URL
targetURL := claudeAPICountTokensURL targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
if baseURL != "" { if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL) validatedURL, err := s.validateUpstreamBaseURL(baseURL)
...@@ -2065,10 +2171,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -2065,10 +2171,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header // OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:与 messages 同步的按需 beta 注入(默认关闭) // API-key:与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) { if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" { if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta) req.Header.Set("anthropic-beta", beta)
} }
} }
......
...@@ -291,7 +291,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont ...@@ -291,7 +291,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
return 999 return 999
} }
switch a.Type { switch a.Type {
case AccountTypeApiKey: case AccountTypeAPIKey:
if strings.TrimSpace(a.GetCredential("api_key")) != "" { if strings.TrimSpace(a.GetCredential("api_key")) != "" {
return 0 return 0
} }
...@@ -369,7 +369,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -369,7 +369,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model originalModel := req.Model
mappedModel := req.Model mappedModel := req.Model
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(req.Model) mappedModel = account.GetMappedModel(req.Model)
} }
...@@ -392,7 +392,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -392,7 +392,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
} }
switch account.Type { switch account.Type {
case AccountTypeApiKey: case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) { buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key") apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" { if strings.TrimSpace(apiKey) == "" {
...@@ -569,7 +569,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -569,7 +569,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if tempMatched {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
} }
...@@ -644,7 +651,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -644,7 +651,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
} }
mappedModel := originalModel mappedModel := originalModel
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel) mappedModel = account.GetMappedModel(originalModel)
} }
...@@ -666,7 +673,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -666,7 +673,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
var buildReq func(ctx context.Context) (*http.Request, string, error) var buildReq func(ctx context.Context) (*http.Request, string, error)
switch account.Type { switch account.Type {
case AccountTypeApiKey: case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) { buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key") apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" { if strings.TrimSpace(apiKey) == "" {
...@@ -867,6 +874,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -867,6 +874,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens. // Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
...@@ -884,6 +895,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -884,6 +895,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil }, nil
} }
if tempMatched {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
} }
...@@ -1656,6 +1670,15 @@ type UpstreamHTTPResult struct { ...@@ -1656,6 +1670,15 @@ type UpstreamHTTPResult struct {
} }
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) { func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
// Log response headers for debugging
log.Printf("[GeminiAPI] ========== Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
log.Printf("[GeminiAPI] %s: %v", key, values)
}
}
log.Printf("[GeminiAPI] ========================================")
respBody, err := io.ReadAll(resp.Body) respBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -1688,6 +1711,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co ...@@ -1688,6 +1711,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
} }
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) { func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
// Log response headers for debugging
log.Printf("[GeminiAPI] ========== Streaming Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
log.Printf("[GeminiAPI] %s: %v", key, values)
}
}
log.Printf("[GeminiAPI] ====================================================")
c.Status(resp.StatusCode) c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache") c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
...@@ -1806,7 +1838,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac ...@@ -1806,7 +1838,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
} }
switch account.Type { switch account.Type {
case AccountTypeApiKey: case AccountTypeAPIKey:
apiKey := strings.TrimSpace(account.GetCredential("api_key")) apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if apiKey == "" { if apiKey == "" {
return nil, errors.New("gemini api_key not configured") return nil, errors.New("gemini api_key not configured")
...@@ -2230,10 +2262,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str ...@@ -2230,10 +2262,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
parts := make([]any, 0) parts := make([]any, 0)
switch content := mm["content"].(type) { switch content := mm["content"].(type) {
case string: case string:
if strings.TrimSpace(content) != "" { // 字符串形式的 content,保留所有内容(包括空白)
parts = append(parts, map[string]any{"text": content}) parts = append(parts, map[string]any{"text": content})
}
case []any: case []any:
// 如果只有一个 block,不过滤空白(让上游 API 报错)
singleBlock := len(content) == 1
for _, block := range content { for _, block := range content {
bm, ok := block.(map[string]any) bm, ok := block.(map[string]any)
if !ok { if !ok {
...@@ -2242,8 +2276,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str ...@@ -2242,8 +2276,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
bt, _ := bm["type"].(string) bt, _ := bm["type"].(string)
switch bt { switch bt {
case "text": case "text":
if text, ok := bm["text"].(string); ok && strings.TrimSpace(text) != "" { if text, ok := bm["text"].(string); ok {
parts = append(parts, map[string]any{"text": text}) // 单个 block 时保留所有内容(包括空白)
// 多个 blocks 时过滤掉空白
if singleBlock || strings.TrimSpace(text) != "" {
parts = append(parts, map[string]any{"text": text})
}
} }
case "tool_use": case "tool_use":
id, _ := bm["id"].(string) id, _ := bm["id"].(string)
......
...@@ -121,6 +121,12 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, ...@@ -121,6 +121,12 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil return nil
} }
func (m *mockAccountRepoForGemini) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearTempUnschedulable(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil } func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil return nil
...@@ -275,7 +281,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr ...@@ -275,7 +281,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr
repo := &mockAccountRepoForGemini{ repo := &mockAccountRepoForGemini{
accounts: []Account{ accounts: []Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil}, {ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil}, {ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
}, },
accountsByID: map[int64]*Account{}, accountsByID: map[int64]*Account{},
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"regexp" "regexp"
"strconv" "strconv"
...@@ -18,12 +19,23 @@ import ( ...@@ -18,12 +19,23 @@ import (
) )
const ( const (
TierAIPremium = "AI_PREMIUM" // Canonical tier IDs used by sub2api (2026-aligned).
TierGoogleOneStandard = "GOOGLE_ONE_STANDARD" GeminiTierGoogleOneFree = "google_one_free"
TierGoogleOneBasic = "GOOGLE_ONE_BASIC" GeminiTierGoogleAIPro = "google_ai_pro"
TierFree = "FREE" GeminiTierGoogleAIUltra = "google_ai_ultra"
TierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN" GeminiTierGCPStandard = "gcp_standard"
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED" GeminiTierGCPEnterprise = "gcp_enterprise"
GeminiTierAIStudioFree = "aistudio_free"
GeminiTierAIStudioPaid = "aistudio_paid"
GeminiTierGoogleOneUnknown = "google_one_unknown"
// Legacy/compat tier IDs that may exist in historical data or upstream responses.
legacyTierAIPremium = "AI_PREMIUM"
legacyTierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
legacyTierGoogleOneBasic = "GOOGLE_ONE_BASIC"
legacyTierFree = "FREE"
legacyTierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
legacyTierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
) )
const ( const (
...@@ -84,7 +96,7 @@ type GeminiAuthURLResult struct { ...@@ -84,7 +96,7 @@ type GeminiAuthURLResult struct {
State string `json:"state"` State string `json:"state"`
} }
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) { func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType, tierID string) (*GeminiAuthURLResult, error) {
state, err := geminicli.GenerateState() state, err := geminicli.GenerateState()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err) return nil, fmt.Errorf("failed to generate state: %w", err)
...@@ -109,14 +121,14 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 ...@@ -109,14 +121,14 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// OAuth client selection: // OAuth client selection:
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret. // - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
// - google_one: same as code_assist, uses built-in client for personal Google accounts. // - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client.
// - ai_studio: requires a user-provided OAuth client. // - ai_studio: requires a user-provided OAuth client.
oauthCfg := geminicli.OAuthConfig{ oauthCfg := geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID, ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret, ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes, Scopes: s.cfg.Gemini.OAuth.Scopes,
} }
if oauthType == "code_assist" || oauthType == "google_one" { if oauthType == "code_assist" {
oauthCfg.ClientID = "" oauthCfg.ClientID = ""
oauthCfg.ClientSecret = "" oauthCfg.ClientSecret = ""
} }
...@@ -127,6 +139,7 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 ...@@ -127,6 +139,7 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
ProxyURL: proxyURL, ProxyURL: proxyURL,
RedirectURI: redirectURI, RedirectURI: redirectURI,
ProjectID: strings.TrimSpace(projectID), ProjectID: strings.TrimSpace(projectID),
TierID: canonicalGeminiTierIDForOAuthType(oauthType, tierID),
OAuthType: oauthType, OAuthType: oauthType,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
...@@ -146,9 +159,9 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 ...@@ -146,9 +159,9 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
} }
// Redirect URI strategy: // Redirect URI strategy:
// - code_assist: use Gemini CLI redirect URI (codeassist.google.com/authcode) // - built-in Gemini CLI OAuth client: use upstream redirect URI (codeassist.google.com/authcode)
// - ai_studio: use localhost callback for manual copy/paste flow // - custom OAuth client: use localhost callback for manual copy/paste flow
if oauthType == "code_assist" { if isBuiltinClient {
redirectURI = geminicli.GeminiCLIRedirectURI redirectURI = geminicli.GeminiCLIRedirectURI
} else { } else {
redirectURI = geminicli.AIStudioOAuthRedirectURI redirectURI = geminicli.AIStudioOAuthRedirectURI
...@@ -174,6 +187,9 @@ type GeminiExchangeCodeInput struct { ...@@ -174,6 +187,9 @@ type GeminiExchangeCodeInput struct {
Code string Code string
ProxyID *int64 ProxyID *int64
OAuthType string // "code_assist" 或 "ai_studio" OAuthType string // "code_assist" 或 "ai_studio"
// TierID is a user-selected tier to be used when auto detection is unavailable or fails.
// If empty, the service will fall back to the tier stored in the OAuth session (if any).
TierID string
} }
type GeminiTokenInfo struct { type GeminiTokenInfo struct {
...@@ -185,7 +201,7 @@ type GeminiTokenInfo struct { ...@@ -185,7 +201,7 @@ type GeminiTokenInfo struct {
Scope string `json:"scope,omitempty"` Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"` ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA TierID string `json:"tier_id,omitempty"` // Canonical tier id (e.g. google_one_free, gcp_standard, aistudio_free)
Extra map[string]any `json:"extra,omitempty"` // Drive metadata Extra map[string]any `json:"extra,omitempty"` // Drive metadata
} }
...@@ -204,6 +220,90 @@ func validateTierID(tierID string) error { ...@@ -204,6 +220,90 @@ func validateTierID(tierID string) error {
return nil return nil
} }
func canonicalGeminiTierID(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
lower := strings.ToLower(raw)
switch lower {
case GeminiTierGoogleOneFree,
GeminiTierGoogleAIPro,
GeminiTierGoogleAIUltra,
GeminiTierGCPStandard,
GeminiTierGCPEnterprise,
GeminiTierAIStudioFree,
GeminiTierAIStudioPaid,
GeminiTierGoogleOneUnknown:
return lower
}
upper := strings.ToUpper(raw)
switch upper {
// Google One legacy tiers
case legacyTierAIPremium:
return GeminiTierGoogleAIPro
case legacyTierGoogleOneUnlimited:
return GeminiTierGoogleAIUltra
case legacyTierFree, legacyTierGoogleOneBasic, legacyTierGoogleOneStandard:
return GeminiTierGoogleOneFree
case legacyTierGoogleOneUnknown:
return GeminiTierGoogleOneUnknown
// Code Assist legacy tiers
case "STANDARD", "PRO", "LEGACY":
return GeminiTierGCPStandard
case "ENTERPRISE", "ULTRA":
return GeminiTierGCPEnterprise
}
// Some Code Assist responses use kebab-case tier identifiers.
switch lower {
case "standard-tier", "pro-tier":
return GeminiTierGCPStandard
case "ultra-tier":
return GeminiTierGCPEnterprise
}
return ""
}
func canonicalGeminiTierIDForOAuthType(oauthType, tierID string) string {
oauthType = strings.ToLower(strings.TrimSpace(oauthType))
canonical := canonicalGeminiTierID(tierID)
if canonical == "" {
return ""
}
switch oauthType {
case "google_one":
switch canonical {
case GeminiTierGoogleOneFree, GeminiTierGoogleAIPro, GeminiTierGoogleAIUltra:
return canonical
default:
return ""
}
case "code_assist":
switch canonical {
case GeminiTierGCPStandard, GeminiTierGCPEnterprise:
return canonical
default:
return ""
}
case "ai_studio":
switch canonical {
case GeminiTierAIStudioFree, GeminiTierAIStudioPaid:
return canonical
default:
return ""
}
default:
// Unknown oauth type: accept canonical tier.
return canonical
}
}
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response // extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
// Prioritizes IsDefault tier, falls back to first non-empty tier // Prioritizes IsDefault tier, falls back to first non-empty tier
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string { func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
...@@ -229,45 +329,61 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string ...@@ -229,45 +329,61 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string
// inferGoogleOneTier infers Google One tier from Drive storage limit // inferGoogleOneTier infers Google One tier from Drive storage limit
func inferGoogleOneTier(storageBytes int64) string { func inferGoogleOneTier(storageBytes int64) string {
log.Printf("[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB))
if storageBytes <= 0 { if storageBytes <= 0 {
return TierGoogleOneUnknown log.Printf("[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN")
return GeminiTierGoogleOneUnknown
} }
if storageBytes > StorageTierUnlimited { if storageBytes > StorageTierUnlimited {
return TierGoogleOneUnlimited log.Printf("[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited)
return GeminiTierGoogleAIUltra
} }
if storageBytes >= StorageTierAIPremium { if storageBytes >= StorageTierAIPremium {
return TierAIPremium log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium)
} return GeminiTierGoogleAIPro
if storageBytes >= StorageTierStandard {
return TierGoogleOneStandard
}
if storageBytes >= StorageTierBasic {
return TierGoogleOneBasic
} }
if storageBytes >= StorageTierFree { if storageBytes >= StorageTierFree {
return TierFree log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree)
return GeminiTierGoogleOneFree
} }
return TierGoogleOneUnknown
log.Printf("[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree)
return GeminiTierGoogleOneUnknown
} }
// fetchGoogleOneTier fetches Google One tier from Drive API // FetchGoogleOneTier fetches Google One tier from Drive API.
// Note: LoadCodeAssist API is NOT called for Google One accounts because:
// 1. It's designed for GCP IAM (enterprise), not personal Google accounts
// 2. Personal accounts will get 403/404 from cloudaicompanion.googleapis.com
// 3. Google consumer (Google One) and enterprise (GCP) systems are physically isolated
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) { func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
log.Printf("[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)")
// Use Drive API to infer tier from storage quota (requires drive.readonly scope)
log.Printf("[GeminiOAuth] Calling Drive API for storage quota...")
driveClient := geminicli.NewDriveClient() driveClient := geminicli.NewDriveClient()
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL) storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
if err != nil { if err != nil {
// Check if it's a 403 (scope not granted) // Check if it's a 403 (scope not granted)
if strings.Contains(err.Error(), "status 403") { if strings.Contains(err.Error(), "status 403") {
fmt.Printf("[GeminiOAuth] Drive API scope not available: %v\n", err) log.Printf("[GeminiOAuth] Drive API scope not available (403): %v", err)
return TierGoogleOneUnknown, nil, err return GeminiTierGoogleOneUnknown, nil, err
} }
// Other errors // Other errors
fmt.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v\n", err) log.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v", err)
return TierGoogleOneUnknown, nil, err return GeminiTierGoogleOneUnknown, nil, err
} }
log.Printf("[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
tierID := inferGoogleOneTier(storageInfo.Limit) tierID := inferGoogleOneTier(storageInfo.Limit)
log.Printf("[GeminiOAuth] Inferred tier from storage: %s", tierID)
return tierID, storageInfo, nil return tierID, storageInfo, nil
} }
...@@ -326,11 +442,16 @@ func (s *GeminiOAuthService) RefreshAccountGoogleOneTier( ...@@ -326,11 +442,16 @@ func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
} }
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
log.Printf("[GeminiOAuth] ========== ExchangeCode START ==========")
log.Printf("[GeminiOAuth] SessionID: %s", input.SessionID)
session, ok := s.sessionStore.Get(input.SessionID) session, ok := s.sessionStore.Get(input.SessionID)
if !ok { if !ok {
log.Printf("[GeminiOAuth] ERROR: Session not found or expired")
return nil, fmt.Errorf("session not found or expired") return nil, fmt.Errorf("session not found or expired")
} }
if strings.TrimSpace(input.State) == "" || input.State != session.State { if strings.TrimSpace(input.State) == "" || input.State != session.State {
log.Printf("[GeminiOAuth] ERROR: Invalid state")
return nil, fmt.Errorf("invalid state") return nil, fmt.Errorf("invalid state")
} }
...@@ -341,6 +462,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ...@@ -341,6 +462,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
proxyURL = proxy.URL() proxyURL = proxy.URL()
} }
} }
log.Printf("[GeminiOAuth] ProxyURL: %s", proxyURL)
redirectURI := session.RedirectURI redirectURI := session.RedirectURI
...@@ -349,6 +471,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ...@@ -349,6 +471,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
if oauthType == "" { if oauthType == "" {
oauthType = "code_assist" oauthType = "code_assist"
} }
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
log.Printf("[GeminiOAuth] Project ID from session: %s", session.ProjectID)
// If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured. // If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured.
if oauthType == "ai_studio" { if oauthType == "ai_studio" {
...@@ -374,8 +498,13 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ...@@ -374,8 +498,13 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL) tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL)
if err != nil { if err != nil {
log.Printf("[GeminiOAuth] ERROR: Failed to exchange code: %v", err)
return nil, fmt.Errorf("failed to exchange code: %w", err) return nil, fmt.Errorf("failed to exchange code: %w", err)
} }
log.Printf("[GeminiOAuth] Token exchange successful")
log.Printf("[GeminiOAuth] Token scope: %s", tokenResp.Scope)
log.Printf("[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn)
sessionProjectID := strings.TrimSpace(session.ProjectID) sessionProjectID := strings.TrimSpace(session.ProjectID)
s.sessionStore.Delete(input.SessionID) s.sessionStore.Delete(input.SessionID)
...@@ -391,43 +520,91 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ...@@ -391,43 +520,91 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
projectID := sessionProjectID projectID := sessionProjectID
var tierID string var tierID string
fallbackTierID := canonicalGeminiTierIDForOAuthType(oauthType, input.TierID)
if fallbackTierID == "" {
fallbackTierID = canonicalGeminiTierIDForOAuthType(oauthType, session.TierID)
}
log.Printf("[GeminiOAuth] ========== Account Type Detection START ==========")
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
// 对于 code_assist 模式,project_id 是必需的,需要调用 Code Assist API // 对于 code_assist 模式,project_id 是必需的,需要调用 Code Assist API
// 对于 google_one 模式,使用个人 Google 账号,不需要 project_id,配额由 Google 网关自动识别 // 对于 google_one 模式,使用个人 Google 账号,不需要 project_id,配额由 Google 网关自动识别
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API) // 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
switch oauthType { switch oauthType {
case "code_assist": case "code_assist":
log.Printf("[GeminiOAuth] Processing code_assist OAuth type")
if projectID == "" { if projectID == "" {
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
var err error var err error
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil { if err != nil {
// 记录警告但不阻断流程,允许后续补充 project_id // 记录警告但不阻断流程,允许后续补充 project_id
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
log.Printf("[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err)
} else {
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID)
} }
} else { } else {
log.Printf("[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID)
// 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID // 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) _, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil { if err != nil {
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err) fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
log.Printf("[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err)
} else { } else {
tierID = fetchedTierID tierID = fetchedTierID
log.Printf("[GeminiOAuth] Successfully fetched tier_id: %s", tierID)
} }
} }
if strings.TrimSpace(projectID) == "" { if strings.TrimSpace(projectID) == "" {
log.Printf("[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth")
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project") return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
} }
// tierID 缺失时使用默认值 // Prefer auto-detected tier; fall back to user-selected tier.
tierID = canonicalGeminiTierIDForOAuthType(oauthType, tierID)
if tierID == "" { if tierID == "" {
tierID = "LEGACY" if fallbackTierID != "" {
tierID = fallbackTierID
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
} else {
tierID = GeminiTierGCPStandard
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
}
} }
log.Printf("[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID)
case "google_one": case "google_one":
log.Printf("[GeminiOAuth] Processing google_one OAuth type")
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
// Attempt to fetch Drive storage tier // Attempt to fetch Drive storage tier
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL) var storageInfo *geminicli.DriveStorageInfo
var err error
tierID, storageInfo, err = s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
if err != nil { if err != nil {
// Log warning but don't block - use fallback // Log warning but don't block - use fallback
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err) fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
tierID = TierGoogleOneUnknown log.Printf("[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err)
tierID = ""
} else {
log.Printf("[GeminiOAuth] Successfully fetched Drive tier: %s", tierID)
if storageInfo != nil {
log.Printf("[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
}
} }
tierID = canonicalGeminiTierIDForOAuthType(oauthType, tierID)
if tierID == "" || tierID == GeminiTierGoogleOneUnknown {
if fallbackTierID != "" {
tierID = fallbackTierID
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
} else {
tierID = GeminiTierGoogleOneFree
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
}
}
fmt.Printf("[GeminiOAuth] Google One tierID after normalization: %s\n", tierID)
// Store Drive info in extra field for caching // Store Drive info in extra field for caching
if storageInfo != nil { if storageInfo != nil {
...@@ -447,12 +624,25 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ...@@ -447,12 +624,25 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
"drive_tier_updated_at": time.Now().Format(time.RFC3339), "drive_tier_updated_at": time.Now().Format(time.RFC3339),
}, },
} }
log.Printf("[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========")
return tokenInfo, nil return tokenInfo, nil
} }
case "ai_studio":
// No automatic tier detection for AI Studio OAuth; rely on user selection.
if fallbackTierID != "" {
tierID = fallbackTierID
} else {
tierID = GeminiTierAIStudioFree
}
default:
log.Printf("[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType)
} }
// ai_studio 模式不设置 tierID,保持为空
return &GeminiTokenInfo{ log.Printf("[GeminiOAuth] ========== Account Type Detection END ==========")
result := &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken, AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken, RefreshToken: tokenResp.RefreshToken,
TokenType: tokenResp.TokenType, TokenType: tokenResp.TokenType,
...@@ -462,7 +652,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ...@@ -462,7 +652,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
ProjectID: projectID, ProjectID: projectID,
TierID: tierID, TierID: tierID,
OAuthType: oauthType, OAuthType: oauthType,
}, nil }
log.Printf("[GeminiOAuth] Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID)
log.Printf("[GeminiOAuth] ========== ExchangeCode END ==========")
return result, nil
} }
func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*GeminiTokenInfo, error) { func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*GeminiTokenInfo, error) {
...@@ -558,6 +751,17 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A ...@@ -558,6 +751,17 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
err = nil err = nil
} }
} }
// Backward compatibility for google_one:
// - New behavior: when a custom OAuth client is configured, google_one will use it.
// - Old behavior: google_one always used the built-in Gemini CLI OAuth client.
// If an existing account was authorized with the built-in client, refreshing with the custom client
// will fail with "unauthorized_client". Retry with the built-in client (code_assist path forces it).
if err != nil && oauthType == "google_one" && strings.Contains(err.Error(), "unauthorized_client") && s.GetOAuthConfig().AIStudioOAuthEnabled {
if alt, altErr := s.RefreshToken(ctx, "code_assist", refreshToken, proxyURL); altErr == nil {
tokenInfo = alt
err = nil
}
}
if err != nil { if err != nil {
// Provide a more actionable error for common OAuth client mismatch issues. // Provide a more actionable error for common OAuth client mismatch issues.
if strings.Contains(err.Error(), "unauthorized_client") { if strings.Contains(err.Error(), "unauthorized_client") {
...@@ -583,13 +787,14 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A ...@@ -583,13 +787,14 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
case "code_assist": case "code_assist":
// 先设置默认值或保留旧值,确保 tier_id 始终有值 // 先设置默认值或保留旧值,确保 tier_id 始终有值
if existingTierID != "" { if existingTierID != "" {
tokenInfo.TierID = existingTierID tokenInfo.TierID = canonicalGeminiTierIDForOAuthType(oauthType, existingTierID)
} else { }
tokenInfo.TierID = "LEGACY" // 默认值 if tokenInfo.TierID == "" {
tokenInfo.TierID = GeminiTierGCPStandard
} }
// 尝试自动探测 project_id 和 tier_id // 尝试自动探测 project_id 和 tier_id
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == "" needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || tokenInfo.TierID == ""
if needDetect { if needDetect {
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil { if err != nil {
...@@ -598,9 +803,10 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A ...@@ -598,9 +803,10 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" { if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
tokenInfo.ProjectID = projectID tokenInfo.ProjectID = projectID
} }
// 只有当原来没有 tier_id 且探测成功时才更新 if tierID != "" {
if existingTierID == "" && tierID != "" { if canonical := canonicalGeminiTierIDForOAuthType(oauthType, tierID); canonical != "" {
tokenInfo.TierID = tierID tokenInfo.TierID = canonical
}
} }
} }
} }
...@@ -609,6 +815,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A ...@@ -609,6 +815,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
return nil, fmt.Errorf("failed to auto-detect project_id: empty result") return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
} }
case "google_one": case "google_one":
canonicalExistingTier := canonicalGeminiTierIDForOAuthType(oauthType, existingTierID)
// Check if tier cache is stale (> 24 hours) // Check if tier cache is stale (> 24 hours)
needsRefresh := true needsRefresh := true
if account.Extra != nil { if account.Extra != nil {
...@@ -617,32 +824,39 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A ...@@ -617,32 +824,39 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
if time.Since(updatedAt) <= 24*time.Hour { if time.Since(updatedAt) <= 24*time.Hour {
needsRefresh = false needsRefresh = false
// Use cached tier // Use cached tier
if existingTierID != "" { tokenInfo.TierID = canonicalExistingTier
tokenInfo.TierID = existingTierID
}
} }
} }
} }
} }
if tokenInfo.TierID == "" {
tokenInfo.TierID = canonicalExistingTier
}
if needsRefresh { if needsRefresh {
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL) tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL)
if err == nil && storageInfo != nil { if err == nil {
tokenInfo.TierID = tierID if canonical := canonicalGeminiTierIDForOAuthType(oauthType, tierID); canonical != "" && canonical != GeminiTierGoogleOneUnknown {
tokenInfo.Extra = map[string]any{ tokenInfo.TierID = canonical
"drive_storage_limit": storageInfo.Limit,
"drive_storage_usage": storageInfo.Usage,
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
} }
} else { if storageInfo != nil {
// Fallback to cached or unknown tokenInfo.Extra = map[string]any{
if existingTierID != "" { "drive_storage_limit": storageInfo.Limit,
tokenInfo.TierID = existingTierID "drive_storage_usage": storageInfo.Usage,
} else { "drive_tier_updated_at": time.Now().Format(time.RFC3339),
tokenInfo.TierID = TierGoogleOneUnknown }
} }
} }
} }
if tokenInfo.TierID == "" || tokenInfo.TierID == GeminiTierGoogleOneUnknown {
if canonicalExistingTier != "" {
tokenInfo.TierID = canonicalExistingTier
} else {
tokenInfo.TierID = GeminiTierGoogleOneFree
}
}
} }
return tokenInfo, nil return tokenInfo, nil
...@@ -669,6 +883,9 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) ...@@ -669,6 +883,9 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
// Validate tier_id before storing // Validate tier_id before storing
if err := validateTierID(tokenInfo.TierID); err == nil { if err := validateTierID(tokenInfo.TierID); err == nil {
creds["tier_id"] = tokenInfo.TierID creds["tier_id"] = tokenInfo.TierID
fmt.Printf("[GeminiOAuth] Storing tier_id: %s\n", tokenInfo.TierID)
} else {
fmt.Printf("[GeminiOAuth] Invalid tier_id %s: %v\n", tokenInfo.TierID, err)
} }
// Silently skip invalid tier_id (don't block account creation) // Silently skip invalid tier_id (don't block account creation)
} }
...@@ -698,7 +915,13 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr ...@@ -698,7 +915,13 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
// Extract tierID from response (works whether CloudAICompanionProject is set or not) // Extract tierID from response (works whether CloudAICompanionProject is set or not)
tierID := "LEGACY" tierID := "LEGACY"
if loadResp != nil { if loadResp != nil {
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers) // First try to get tier from currentTier/paidTier fields
if tier := loadResp.GetTier(); tier != "" {
tierID = tier
} else {
// Fallback to extracting from allowedTiers
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
}
} }
// If LoadCodeAssist returned a project, use it // If LoadCodeAssist returned a project, use it
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment