Commit 1de18b89 authored by Wang Lvyuan's avatar Wang Lvyuan
Browse files

merge: sync upstream/main before PR

parents 882518c1 9f6ab6b8
...@@ -113,15 +113,18 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) ...@@ -113,15 +113,18 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error)
return normalized, nil return normalized, nil
} }
// generateSessionString generates a Claude Code style session string // generateSessionString generates a Claude Code style session string.
// The output format is determined by the UA version in claude.DefaultHeaders,
// ensuring consistency between the user_id format and the UA sent to upstream.
func generateSessionString() (string, error) { func generateSessionString() (string, error) {
bytes := make([]byte, 32) b := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil { if _, err := rand.Read(b); err != nil {
return "", err return "", err
} }
hex64 := hex.EncodeToString(bytes) hex64 := hex.EncodeToString(b)
sessionUUID := uuid.New().String() sessionUUID := uuid.New().String()
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil uaVersion := ExtractCLIVersion(claude.DefaultHeaders["User-Agent"])
return FormatMetadataUserID(hex64, "", sessionUUID, uaVersion), nil
} }
// createTestPayload creates a Claude Code style test request payload // createTestPayload creates a Claude Code style test request payload
......
...@@ -49,6 +49,7 @@ type UsageLogRepository interface { ...@@ -49,6 +49,7 @@ type UsageLogRepository interface {
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error)
GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
......
...@@ -194,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri ...@@ -194,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) { func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected") panic("unexpected")
} }
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) { func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) {
panic("unexpected") panic("unexpected")
} }
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
......
...@@ -160,7 +160,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er ...@@ -160,7 +160,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }
......
...@@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool, ...@@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool,
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) { func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }
...@@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string ...@@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) { func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }
...@@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, ...@@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context,
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) { func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }
......
...@@ -57,16 +57,16 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { ...@@ -57,16 +57,16 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "claude-opus-4-6-thinking", expected: "claude-opus-4-6-thinking",
}, },
{ {
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5", name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-6",
requestedModel: "claude-haiku-4-5", requestedModel: "claude-haiku-4-5",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5", expected: "claude-sonnet-4-6",
}, },
{ {
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5", name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-6",
requestedModel: "claude-haiku-4-5-20251001", requestedModel: "claude-haiku-4-5-20251001",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5", expected: "claude-sonnet-4-6",
}, },
{ {
name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5", name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5",
......
...@@ -21,9 +21,6 @@ var ( ...@@ -21,9 +21,6 @@ var (
// 带捕获组的版本提取正则 // 带捕获组的版本提取正则
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`) claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致) // System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
systemPromptThreshold = 0.5 systemPromptThreshold = 0.5
) )
...@@ -124,7 +121,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo ...@@ -124,7 +121,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
return false return false
} }
if !userIDPattern.MatchString(userID) { if ParseMetadataUserID(userID) == nil {
return false return false
} }
...@@ -278,11 +275,7 @@ func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context ...@@ -278,11 +275,7 @@ func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context
// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号 // ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串 // 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string { func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua) return ExtractCLIVersion(ua)
if len(matches) >= 2 {
return matches[1]
}
return ""
} }
// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中 // SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中
......
...@@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi ...@@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil return stats, nil
} }
func (s *DashboardService) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, modelSource string) ([]usagestats.ModelStat, error) {
normalizedSource := usagestats.NormalizeModelSource(modelSource)
if normalizedSource == usagestats.ModelSourceRequested {
return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
}
type modelStatsBySourceRepo interface {
GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) ([]usagestats.ModelStat, error)
}
if sourceRepo, ok := s.usageRepo.(modelStatsBySourceRepo); ok {
stats, err := sourceRepo.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, normalizedSource)
if err != nil {
return nil, fmt.Errorf("get model stats with filters by source: %w", err)
}
return stats, nil
}
return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
}
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil { if err != nil {
...@@ -148,6 +169,15 @@ func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTi ...@@ -148,6 +169,15 @@ func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTi
return stats, nil return stats, nil
} }
// GetGroupUsageSummary returns today's and cumulative cost for all groups.
func (s *DashboardService) GetGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
results, err := s.usageRepo.GetAllGroupUsageSummary(ctx, todayStart)
if err != nil {
return nil, fmt.Errorf("get group usage summary: %w", err)
}
return results, nil
}
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) { func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
data, err := s.cache.GetDashboardStats(ctx) data, err := s.cache.GetDashboardStats(ctx)
if err != nil { if err != nil {
......
...@@ -170,6 +170,13 @@ const ( ...@@ -170,6 +170,13 @@ const (
// SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings. // SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings.
SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config" SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config"
// =========================
// Overload Cooldown (529)
// =========================
// SettingKeyOverloadCooldownSettings stores JSON config for 529 overload cooldown handling.
SettingKeyOverloadCooldownSettings = "overload_cooldown_settings"
// ========================= // =========================
// Stream Timeout Handling // Stream Timeout Handling
// ========================= // =========================
......
...@@ -788,7 +788,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuc ...@@ -788,7 +788,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuc
rateLimitService: &RateLimitService{}, rateLimitService: &RateLimitService{},
} }
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now())
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, 12, result.Usage.InputTokens) require.Equal(t, 12, result.Usage.InputTokens)
...@@ -815,7 +815,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenTyp ...@@ -815,7 +815,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenTyp
} }
svc := &GatewayService{} svc := &GatewayService{}
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now())
require.Nil(t, result) require.Nil(t, result)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "requires apikey token") require.Contains(t, err.Error(), "requires apikey token")
...@@ -840,7 +840,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest ...@@ -840,7 +840,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest
} }
account := newAnthropicAPIKeyAccountForTest() account := newAnthropicAPIKeyAccountForTest()
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", "x", false, time.Now())
require.Nil(t, result) require.Nil(t, result)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "upstream request failed") require.Contains(t, err.Error(), "upstream request failed")
...@@ -873,7 +873,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBo ...@@ -873,7 +873,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBo
httpUpstream: upstream, httpUpstream: upstream,
} }
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", "x", false, time.Now())
require.Nil(t, result) require.Nil(t, result)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "empty response") require.Contains(t, err.Error(), "empty response")
......
...@@ -278,8 +278,8 @@ func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, plat ...@@ -278,8 +278,8 @@ func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, plat
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) { func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil return false, nil
} }
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, nil return 0, 0, nil
} }
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil return 0, nil
......
...@@ -28,6 +28,12 @@ var ( ...@@ -28,6 +28,12 @@ var (
patternEmptyContentSpaced = []byte(`"content": []`) patternEmptyContentSpaced = []byte(`"content": []`)
patternEmptyContentSp1 = []byte(`"content" : []`) patternEmptyContentSp1 = []byte(`"content" : []`)
patternEmptyContentSp2 = []byte(`"content" :[]`) patternEmptyContentSp2 = []byte(`"content" :[]`)
// Fast-path patterns for empty text blocks: {"type":"text","text":""}
patternEmptyText = []byte(`"text":""`)
patternEmptyTextSpaced = []byte(`"text": ""`)
patternEmptyTextSp1 = []byte(`"text" : ""`)
patternEmptyTextSp2 = []byte(`"text" :""`)
) )
// SessionContext 粘性会话上下文,用于区分不同来源的请求。 // SessionContext 粘性会话上下文,用于区分不同来源的请求。
...@@ -233,15 +239,22 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { ...@@ -233,15 +239,22 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
bytes.Contains(body, patternThinkingField) || bytes.Contains(body, patternThinkingField) ||
bytes.Contains(body, patternThinkingFieldSpaced) bytes.Contains(body, patternThinkingFieldSpaced)
// Also check for empty content arrays that need fixing. // Also check for empty content arrays and empty text blocks that need fixing.
// Note: This is a heuristic check; the actual empty content handling is done below. // Note: This is a heuristic check; the actual empty content handling is done below.
hasEmptyContent := bytes.Contains(body, patternEmptyContent) || hasEmptyContent := bytes.Contains(body, patternEmptyContent) ||
bytes.Contains(body, patternEmptyContentSpaced) || bytes.Contains(body, patternEmptyContentSpaced) ||
bytes.Contains(body, patternEmptyContentSp1) || bytes.Contains(body, patternEmptyContentSp1) ||
bytes.Contains(body, patternEmptyContentSp2) bytes.Contains(body, patternEmptyContentSp2)
// Check for empty text blocks: {"type":"text","text":""}
// These cause upstream 400: "text content blocks must be non-empty"
hasEmptyTextBlock := bytes.Contains(body, patternEmptyText) ||
bytes.Contains(body, patternEmptyTextSpaced) ||
bytes.Contains(body, patternEmptyTextSp1) ||
bytes.Contains(body, patternEmptyTextSp2)
// Fast path: nothing to process // Fast path: nothing to process
if !hasThinkingContent && !hasEmptyContent { if !hasThinkingContent && !hasEmptyContent && !hasEmptyTextBlock {
return body return body
} }
...@@ -260,7 +273,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { ...@@ -260,7 +273,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
bytes.Contains(body, patternTypeRedactedThinking) || bytes.Contains(body, patternTypeRedactedThinking) ||
bytes.Contains(body, patternTypeRedactedSpaced) || bytes.Contains(body, patternTypeRedactedSpaced) ||
bytes.Contains(body, patternThinkingFieldSpaced) bytes.Contains(body, patternThinkingFieldSpaced)
if !hasEmptyContent && !containsThinkingBlocks { if !hasEmptyContent && !hasEmptyTextBlock && !containsThinkingBlocks {
if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() { if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() {
if out, err := sjson.DeleteBytes(body, "thinking"); err == nil { if out, err := sjson.DeleteBytes(body, "thinking"); err == nil {
out = removeThinkingDependentContextStrategies(out) out = removeThinkingDependentContextStrategies(out)
...@@ -320,6 +333,16 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { ...@@ -320,6 +333,16 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
blockType, _ := blockMap["type"].(string) blockType, _ := blockMap["type"].(string)
// Strip empty text blocks: {"type":"text","text":""}
// Upstream rejects these with 400: "text content blocks must be non-empty"
if blockType == "text" {
if txt, _ := blockMap["text"].(string); txt == "" {
modifiedThisMsg = true
ensureNewContent(bi)
continue
}
}
// Convert thinking blocks to text (preserve content) and drop redacted_thinking. // Convert thinking blocks to text (preserve content) and drop redacted_thinking.
switch blockType { switch blockType {
case "thinking": case "thinking":
......
...@@ -404,6 +404,51 @@ func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T) ...@@ -404,6 +404,51 @@ func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T)
require.NotEmpty(t, content0["text"]) require.NotEmpty(t, content0["text"])
} }
func TestFilterThinkingBlocksForRetry_StripsEmptyTextBlocks(t *testing.T) {
// Empty text blocks cause upstream 400: "text content blocks must be non-empty"
input := []byte(`{
"messages":[
{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":""}]},
{"role":"assistant","content":[{"type":"text","text":""}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs, ok := req["messages"].([]any)
require.True(t, ok)
// First message: empty text block stripped, "hello" preserved
msg0 := msgs[0].(map[string]any)
content0 := msg0["content"].([]any)
require.Len(t, content0, 1)
require.Equal(t, "hello", content0[0].(map[string]any)["text"])
// Second message: only had empty text block → gets placeholder
msg1 := msgs[1].(map[string]any)
content1 := msg1["content"].([]any)
require.Len(t, content1, 1)
block1 := content1[0].(map[string]any)
require.Equal(t, "text", block1["type"])
require.NotEmpty(t, block1["text"])
}
func TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks(t *testing.T) {
// Non-empty text blocks should pass through unchanged
input := []byte(`{
"messages":[
{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
// Fast path: no thinking content, no empty content, no empty text blocks → unchanged
require.Equal(t, input, out)
}
func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) { func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
input := []byte(`{ input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024}, "thinking":{"type":"enabled","budget_tokens":1024},
......
...@@ -326,7 +326,6 @@ func isClaudeCodeCredentialScopeError(msg string) bool { ...@@ -326,7 +326,6 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
// Some upstream APIs return non-standard "data:" without space (should be "data: "). // Some upstream APIs return non-standard "data:" without space (should be "data: ").
var ( var (
sseDataRe = regexp.MustCompile(`^data:\s*`) sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
...@@ -491,6 +490,7 @@ type ForwardResult struct { ...@@ -491,6 +490,7 @@ type ForwardResult struct {
RequestID string RequestID string
Usage ClaudeUsage Usage ClaudeUsage
Model string Model string
UpstreamModel string // Actual upstream model after mapping (empty = no mapping)
Stream bool Stream bool
Duration time.Duration Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求) FirstTokenMs *int // 首字时间(流式请求)
...@@ -644,8 +644,8 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { ...@@ -644,8 +644,8 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx // 1. 最高优先级:从 metadata.user_id 提取 session_xxx
if parsed.MetadataUserID != "" { if parsed.MetadataUserID != "" {
if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 { if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" {
return match[1] return uid.SessionID
} }
} }
...@@ -1026,13 +1026,13 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account ...@@ -1026,13 +1026,13 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
sessionID = generateSessionUUID(seed) sessionID = generateSessionUUID(seed)
} }
// Prefer the newer format that includes account_uuid (if present), // 根据指纹 UA 版本选择输出格式
// otherwise fall back to the legacy Claude Code format. var uaVersion string
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) if fp != nil {
if accountUUID != "" { uaVersion = ExtractCLIVersion(fp.UserAgent)
return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID)
} }
return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID) accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
} }
// GenerateSessionUUID creates a deterministic UUID4 from a seed string. // GenerateSessionUUID creates a deterministic UUID4 from a seed string.
...@@ -3989,7 +3989,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -3989,7 +3989,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
passthroughModel = mappedModel passthroughModel = mappedModel
} }
} }
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime) return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{
Body: passthroughBody,
RequestModel: passthroughModel,
OriginalModel: parsed.Model,
RequestStream: parsed.Stream,
StartTime: startTime,
})
} }
if account != nil && account.IsBedrock() { if account != nil && account.IsBedrock() {
...@@ -4513,6 +4519,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4513,6 +4519,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志 Model: originalModel, // 使用原始模型用于计费和日志
UpstreamModel: mappedModel,
Stream: reqStream, Stream: reqStream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
...@@ -4520,14 +4527,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4520,14 +4527,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}, nil }, nil
} }
type anthropicPassthroughForwardInput struct {
Body []byte
RequestModel string
OriginalModel string
RequestStream bool
StartTime time.Time
}
func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
ctx context.Context, ctx context.Context,
c *gin.Context, c *gin.Context,
account *Account, account *Account,
body []byte, body []byte,
reqModel string, reqModel string,
originalModel string,
reqStream bool, reqStream bool,
startTime time.Time, startTime time.Time,
) (*ForwardResult, error) {
return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{
Body: body,
RequestModel: reqModel,
OriginalModel: originalModel,
RequestStream: reqStream,
StartTime: startTime,
})
}
func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
ctx context.Context,
c *gin.Context,
account *Account,
input anthropicPassthroughForwardInput,
) (*ForwardResult, error) { ) (*ForwardResult, error) {
token, tokenType, err := s.GetAccessToken(ctx, account) token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil { if err != nil {
...@@ -4543,19 +4574,19 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( ...@@ -4543,19 +4574,19 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
} }
logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v", logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v",
account.ID, account.Name, reqModel, reqStream) account.ID, account.Name, input.RequestModel, input.RequestStream)
if c != nil { if c != nil {
c.Set("anthropic_passthrough", true) c.Set("anthropic_passthrough", true)
} }
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
setOpsUpstreamRequestBody(c, body) setOpsUpstreamRequestBody(c, input.Body)
var resp *http.Response var resp *http.Response
retryStart := time.Now() retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ { for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, input.RequestStream)
upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token) upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, input.Body, token)
releaseUpstreamCtx() releaseUpstreamCtx()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -4713,8 +4744,8 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( ...@@ -4713,8 +4744,8 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
var usage *ClaudeUsage var usage *ClaudeUsage
var firstTokenMs *int var firstTokenMs *int
var clientDisconnect bool var clientDisconnect bool
if reqStream { if input.RequestStream {
streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime, reqModel) streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, input.StartTime, input.RequestModel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -4734,9 +4765,10 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( ...@@ -4734,9 +4765,10 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
return &ForwardResult{ return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: reqModel, Model: input.OriginalModel,
Stream: reqStream, UpstreamModel: input.RequestModel,
Duration: time.Since(startTime), Stream: input.RequestStream,
Duration: time.Since(input.StartTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect, ClientDisconnect: clientDisconnect,
}, nil }, nil
...@@ -5241,6 +5273,7 @@ func (s *GatewayService) forwardBedrock( ...@@ -5241,6 +5273,7 @@ func (s *GatewayService) forwardBedrock(
RequestID: resp.Header.Get("x-amzn-requestid"), RequestID: resp.Header.Get("x-amzn-requestid"),
Usage: *usage, Usage: *usage,
Model: reqModel, Model: reqModel,
UpstreamModel: mappedModel,
Stream: reqStream, Stream: reqStream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
...@@ -5533,7 +5566,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -5533,7 +5566,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
accountUUID := account.GetExtraString("account_uuid") accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" { if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
body = newBody body = newBody
} }
} }
...@@ -6068,9 +6101,11 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { ...@@ -6068,9 +6101,11 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
return true return true
} }
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的) // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的,或客户端发送了空 text block
// 例如: "all messages must have non-empty content" // 例如: "all messages must have non-empty content"
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") { // "messages: text content blocks must be non-empty"
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") ||
strings.Contains(msg, "content blocks must be non-empty") {
logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected empty content error") logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected empty content error")
return true return true
} }
...@@ -7529,6 +7564,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7529,6 +7564,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
...@@ -7710,6 +7746,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -7710,6 +7746,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
...@@ -8161,7 +8198,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -8161,7 +8198,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if err == nil { if err == nil {
accountUUID := account.GetExtraString("account_uuid") accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" { if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
body = newBody body = newBody
} }
} }
......
...@@ -230,8 +230,8 @@ func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platf ...@@ -230,8 +230,8 @@ func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platf
func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) { func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil return false, nil
} }
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, nil return 0, 0, nil
} }
func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil return 0, nil
......
...@@ -24,7 +24,7 @@ func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) { ...@@ -24,7 +24,7 @@ func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) {
svc := &GatewayService{} svc := &GatewayService{}
parsed := &ParsedRequest{ parsed := &ParsedRequest{
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
System: "You are a helpful assistant.", System: "You are a helpful assistant.",
HasSystem: true, HasSystem: true,
Messages: []any{ Messages: []any{
...@@ -196,7 +196,7 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) { ...@@ -196,7 +196,7 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
svc := &GatewayService{} svc := &GatewayService{}
parsed := &ParsedRequest{ parsed := &ParsedRequest{
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
Messages: []any{ Messages: []any{
map[string]any{"role": "user", "content": "hello"}, map[string]any{"role": "user", "content": "hello"},
}, },
...@@ -212,6 +212,22 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) { ...@@ -212,6 +212,22 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
"metadata session_id should take priority over SessionContext") "metadata session_id should take priority over SessionContext")
} }
func TestGenerateSessionHash_MetadataJSON_HasHighestPriority(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
MetadataUserID: `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`,
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
hash := svc.GenerateSessionHash(parsed)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", hash, "JSON format metadata session_id should have highest priority")
}
func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) { func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) {
svc := &GatewayService{} svc := &GatewayService{}
......
...@@ -66,6 +66,8 @@ type Group struct { ...@@ -66,6 +66,8 @@ type Group struct {
AccountGroups []AccountGroup AccountGroups []AccountGroup
AccountCount int64 AccountCount int64
ActiveAccountCount int64
RateLimitedAccountCount int64
} }
func (g *Group) IsActive() bool { func (g *Group) IsActive() bool {
......
package service
import (
"context"
"time"
)
// GroupCapacitySummary holds aggregated capacity for a single group.
type GroupCapacitySummary struct {
GroupID int64 `json:"group_id"`
ConcurrencyUsed int `json:"concurrency_used"`
ConcurrencyMax int `json:"concurrency_max"`
SessionsUsed int `json:"sessions_used"`
SessionsMax int `json:"sessions_max"`
RPMUsed int `json:"rpm_used"`
RPMMax int `json:"rpm_max"`
}
// GroupCapacityService aggregates per-group capacity from runtime data.
type GroupCapacityService struct {
accountRepo AccountRepository
groupRepo GroupRepository
concurrencyService *ConcurrencyService
sessionLimitCache SessionLimitCache
rpmCache RPMCache
}
// NewGroupCapacityService creates a new GroupCapacityService.
func NewGroupCapacityService(
accountRepo AccountRepository,
groupRepo GroupRepository,
concurrencyService *ConcurrencyService,
sessionLimitCache SessionLimitCache,
rpmCache RPMCache,
) *GroupCapacityService {
return &GroupCapacityService{
accountRepo: accountRepo,
groupRepo: groupRepo,
concurrencyService: concurrencyService,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
}
}
// GetAllGroupCapacity returns capacity summary for all active groups.
func (s *GroupCapacityService) GetAllGroupCapacity(ctx context.Context) ([]GroupCapacitySummary, error) {
groups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, err
}
results := make([]GroupCapacitySummary, 0, len(groups))
for i := range groups {
cap, err := s.getGroupCapacity(ctx, groups[i].ID)
if err != nil {
// Skip groups with errors, return partial results
continue
}
cap.GroupID = groups[i].ID
results = append(results, cap)
}
return results, nil
}
func (s *GroupCapacityService) getGroupCapacity(ctx context.Context, groupID int64) (GroupCapacitySummary, error) {
accounts, err := s.accountRepo.ListSchedulableByGroupID(ctx, groupID)
if err != nil {
return GroupCapacitySummary{}, err
}
if len(accounts) == 0 {
return GroupCapacitySummary{}, nil
}
// Collect account IDs and config values
accountIDs := make([]int64, 0, len(accounts))
sessionTimeouts := make(map[int64]time.Duration)
var concurrencyMax, sessionsMax, rpmMax int
for i := range accounts {
acc := &accounts[i]
accountIDs = append(accountIDs, acc.ID)
concurrencyMax += acc.Concurrency
if ms := acc.GetMaxSessions(); ms > 0 {
sessionsMax += ms
timeout := time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
if timeout <= 0 {
timeout = 5 * time.Minute
}
sessionTimeouts[acc.ID] = timeout
}
if rpm := acc.GetBaseRPM(); rpm > 0 {
rpmMax += rpm
}
}
// Batch query runtime data from Redis
concurrencyMap, _ := s.concurrencyService.GetAccountConcurrencyBatch(ctx, accountIDs)
var sessionsMap map[int64]int
if sessionsMax > 0 && s.sessionLimitCache != nil {
sessionsMap, _ = s.sessionLimitCache.GetActiveSessionCountBatch(ctx, accountIDs, sessionTimeouts)
}
var rpmMap map[int64]int
if rpmMax > 0 && s.rpmCache != nil {
rpmMap, _ = s.rpmCache.GetRPMBatch(ctx, accountIDs)
}
// Aggregate
var concurrencyUsed, sessionsUsed, rpmUsed int
for _, id := range accountIDs {
concurrencyUsed += concurrencyMap[id]
if sessionsMap != nil {
sessionsUsed += sessionsMap[id]
}
if rpmMap != nil {
rpmUsed += rpmMap[id]
}
}
return GroupCapacitySummary{
ConcurrencyUsed: concurrencyUsed,
ConcurrencyMax: concurrencyMax,
SessionsUsed: sessionsUsed,
SessionsMax: sessionsMax,
RPMUsed: rpmUsed,
RPMMax: rpmMax,
}, nil
}
...@@ -27,7 +27,7 @@ type GroupRepository interface { ...@@ -27,7 +27,7 @@ type GroupRepository interface {
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
ExistsByName(ctx context.Context, name string) (bool, error) ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重) // GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
...@@ -202,7 +202,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, ...@@ -202,7 +202,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any,
} }
// 获取账号数量 // 获取账号数量
accountCount, err := s.groupRepo.GetAccountCount(ctx, id) accountCount, _, err := s.groupRepo.GetAccountCount(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account count: %w", err) return nil, fmt.Errorf("get account count: %w", err)
} }
......
...@@ -19,10 +19,6 @@ import ( ...@@ -19,10 +19,6 @@ import (
// 预编译正则表达式(避免每次调用重新编译) // 预编译正则表达式(避免每次调用重新编译)
var ( var (
// 匹配 user_id 格式:
// 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID)
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID)
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)
// 匹配 User-Agent 版本号: xxx/x.y.z // 匹配 User-Agent 版本号: xxx/x.y.z
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`) userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
) )
...@@ -209,12 +205,12 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) { ...@@ -209,12 +205,12 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
} }
// RewriteUserID 重写body中的metadata.user_id // RewriteUserID 重写body中的metadata.user_id
// 输入格式:user_{clientId}_account__session_{sessionUUID} // 支持旧拼接格式和新 JSON 格式的 user_id 解析,
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash} // 根据 fingerprintUA 版本选择输出格式。
// //
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, // 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。 // 避免重新序列化导致 thinking 块等内容被修改。
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) { func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
if len(body) == 0 || accountUUID == "" || cachedClientID == "" { if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
return body, nil return body, nil
} }
...@@ -241,24 +237,21 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI ...@@ -241,24 +237,21 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
return body, nil return body, nil
} }
// 匹配格式: // 解析 user_id(兼容旧拼接格式和新 JSON 格式)
// 旧格式: user_{64位hex}_account__session_{uuid} parsed := ParseMetadataUserID(userID)
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} if parsed == nil {
matches := userIDRegex.FindStringSubmatch(userID)
if matches == nil {
return body, nil return body, nil
} }
// matches[1] = account UUID (可能为空), matches[2] = session UUID sessionTail := parsed.SessionID // 原始session UUID
sessionTail := matches[2] // 原始session UUID
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式 // 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
seed := fmt.Sprintf("%d::%s", accountID, sessionTail) seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
newSessionHash := generateUUIDFromSeed(seed) newSessionHash := generateUUIDFromSeed(seed)
// 构建新的user_id // 根据客户端版本选择输出格式
// 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash} version := ExtractCLIVersion(fingerprintUA)
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash) newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version)
metadata["user_id"] = newUserID metadata["user_id"] = newUserID
...@@ -278,9 +271,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI ...@@ -278,9 +271,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
// //
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, // 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。 // 避免重新序列化导致 thinking 块等内容被修改。
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) { func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
// 先执行常规的 RewriteUserID 逻辑 // 先执行常规的 RewriteUserID 逻辑
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID) newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID, fingerprintUA)
if err != nil { if err != nil {
return newBody, err return newBody, err
} }
...@@ -312,10 +305,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b ...@@ -312,10 +305,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
return newBody, nil return newBody, nil
} }
// 查找 _session_ 的位置,替换其后的内容 // 解析已重写的 user_id
const sessionMarker = "_session_" uidParsed := ParseMetadataUserID(userID)
idx := strings.LastIndex(userID, sessionMarker) if uidParsed == nil {
if idx == -1 {
return newBody, nil return newBody, nil
} }
...@@ -337,8 +329,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b ...@@ -337,8 +329,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err) logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err)
} }
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容 // 用 FormatMetadataUserID 重建(保持与 RewriteUserID 相同的格式)
newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID version := ExtractCLIVersion(fingerprintUA)
newUserID := FormatMetadataUserID(uidParsed.DeviceID, uidParsed.AccountUUID, maskedSessionID, version)
slog.Debug("session_id_masking_applied", slog.Debug("session_id_masking_applied",
"account_id", account.ID, "account_id", account.ID,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment