Commit cfaac12a authored by Ethan0x0000's avatar Ethan0x0000
Browse files

Merge upstream/main into pr/upstream-model-tracking

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent

)
Co-authored-by: default avatarSisyphus <clio-agent@sisyphuslabs.ai>
parents bd9d2671 21f349c0
...@@ -49,6 +49,7 @@ type UsageLogRepository interface { ...@@ -49,6 +49,7 @@ type UsageLogRepository interface {
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error)
GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
......
...@@ -194,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri ...@@ -194,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) { func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected") panic("unexpected")
} }
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) { func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) {
panic("unexpected") panic("unexpected")
} }
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
......
...@@ -160,7 +160,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er ...@@ -160,7 +160,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }
......
...@@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool, ...@@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool,
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) { func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }
...@@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string ...@@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) { func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }
...@@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, ...@@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context,
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) { func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }
......
...@@ -21,9 +21,6 @@ var ( ...@@ -21,9 +21,6 @@ var (
// 带捕获组的版本提取正则 // 带捕获组的版本提取正则
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`) claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致) // System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
systemPromptThreshold = 0.5 systemPromptThreshold = 0.5
) )
...@@ -124,7 +121,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo ...@@ -124,7 +121,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
return false return false
} }
if !userIDPattern.MatchString(userID) { if ParseMetadataUserID(userID) == nil {
return false return false
} }
...@@ -278,11 +275,7 @@ func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context ...@@ -278,11 +275,7 @@ func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context
// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号 // ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串 // 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string { func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua) return ExtractCLIVersion(ua)
if len(matches) >= 2 {
return matches[1]
}
return ""
} }
// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中 // SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中
......
...@@ -169,6 +169,15 @@ func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTi ...@@ -169,6 +169,15 @@ func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTi
return stats, nil return stats, nil
} }
// GetGroupUsageSummary returns today's and cumulative cost for all groups.
func (s *DashboardService) GetGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
results, err := s.usageRepo.GetAllGroupUsageSummary(ctx, todayStart)
if err != nil {
return nil, fmt.Errorf("get group usage summary: %w", err)
}
return results, nil
}
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) { func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
data, err := s.cache.GetDashboardStats(ctx) data, err := s.cache.GetDashboardStats(ctx)
if err != nil { if err != nil {
......
...@@ -278,8 +278,8 @@ func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, plat ...@@ -278,8 +278,8 @@ func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, plat
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) { func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil return false, nil
} }
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, nil return 0, 0, nil
} }
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil return 0, nil
......
...@@ -326,7 +326,6 @@ func isClaudeCodeCredentialScopeError(msg string) bool { ...@@ -326,7 +326,6 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
// Some upstream APIs return non-standard "data:" without space (should be "data: "). // Some upstream APIs return non-standard "data:" without space (should be "data: ").
var ( var (
sseDataRe = regexp.MustCompile(`^data:\s*`) sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
...@@ -645,8 +644,8 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { ...@@ -645,8 +644,8 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx // 1. 最高优先级:从 metadata.user_id 提取 session_xxx
if parsed.MetadataUserID != "" { if parsed.MetadataUserID != "" {
if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 { if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" {
return match[1] return uid.SessionID
} }
} }
...@@ -1027,13 +1026,13 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account ...@@ -1027,13 +1026,13 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
sessionID = generateSessionUUID(seed) sessionID = generateSessionUUID(seed)
} }
// Prefer the newer format that includes account_uuid (if present), // 根据指纹 UA 版本选择输出格式
// otherwise fall back to the legacy Claude Code format. var uaVersion string
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) if fp != nil {
if accountUUID != "" { uaVersion = ExtractCLIVersion(fp.UserAgent)
return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID)
} }
return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID) accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
} }
// GenerateSessionUUID creates a deterministic UUID4 from a seed string. // GenerateSessionUUID creates a deterministic UUID4 from a seed string.
...@@ -5567,7 +5566,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -5567,7 +5566,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
accountUUID := account.GetExtraString("account_uuid") accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" { if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
body = newBody body = newBody
} }
} }
...@@ -8197,7 +8196,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -8197,7 +8196,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if err == nil { if err == nil {
accountUUID := account.GetExtraString("account_uuid") accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" { if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
body = newBody body = newBody
} }
} }
......
...@@ -230,8 +230,8 @@ func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platf ...@@ -230,8 +230,8 @@ func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platf
func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) { func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil return false, nil
} }
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, nil return 0, 0, nil
} }
func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil return 0, nil
......
...@@ -24,7 +24,7 @@ func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) { ...@@ -24,7 +24,7 @@ func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) {
svc := &GatewayService{} svc := &GatewayService{}
parsed := &ParsedRequest{ parsed := &ParsedRequest{
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
System: "You are a helpful assistant.", System: "You are a helpful assistant.",
HasSystem: true, HasSystem: true,
Messages: []any{ Messages: []any{
...@@ -196,7 +196,7 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) { ...@@ -196,7 +196,7 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
svc := &GatewayService{} svc := &GatewayService{}
parsed := &ParsedRequest{ parsed := &ParsedRequest{
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
Messages: []any{ Messages: []any{
map[string]any{"role": "user", "content": "hello"}, map[string]any{"role": "user", "content": "hello"},
}, },
...@@ -212,6 +212,22 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) { ...@@ -212,6 +212,22 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
"metadata session_id should take priority over SessionContext") "metadata session_id should take priority over SessionContext")
} }
func TestGenerateSessionHash_MetadataJSON_HasHighestPriority(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
MetadataUserID: `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`,
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
hash := svc.GenerateSessionHash(parsed)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", hash, "JSON format metadata session_id should have highest priority")
}
func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) { func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) {
svc := &GatewayService{} svc := &GatewayService{}
......
...@@ -66,6 +66,8 @@ type Group struct { ...@@ -66,6 +66,8 @@ type Group struct {
AccountGroups []AccountGroup AccountGroups []AccountGroup
AccountCount int64 AccountCount int64
ActiveAccountCount int64
RateLimitedAccountCount int64
} }
func (g *Group) IsActive() bool { func (g *Group) IsActive() bool {
......
package service
import (
"context"
"time"
)
// GroupCapacitySummary holds aggregated capacity for a single group.
type GroupCapacitySummary struct {
GroupID int64 `json:"group_id"`
ConcurrencyUsed int `json:"concurrency_used"`
ConcurrencyMax int `json:"concurrency_max"`
SessionsUsed int `json:"sessions_used"`
SessionsMax int `json:"sessions_max"`
RPMUsed int `json:"rpm_used"`
RPMMax int `json:"rpm_max"`
}
// GroupCapacityService aggregates per-group capacity from runtime data.
type GroupCapacityService struct {
accountRepo AccountRepository
groupRepo GroupRepository
concurrencyService *ConcurrencyService
sessionLimitCache SessionLimitCache
rpmCache RPMCache
}
// NewGroupCapacityService creates a new GroupCapacityService.
func NewGroupCapacityService(
accountRepo AccountRepository,
groupRepo GroupRepository,
concurrencyService *ConcurrencyService,
sessionLimitCache SessionLimitCache,
rpmCache RPMCache,
) *GroupCapacityService {
return &GroupCapacityService{
accountRepo: accountRepo,
groupRepo: groupRepo,
concurrencyService: concurrencyService,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
}
}
// GetAllGroupCapacity returns capacity summary for all active groups.
func (s *GroupCapacityService) GetAllGroupCapacity(ctx context.Context) ([]GroupCapacitySummary, error) {
groups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, err
}
results := make([]GroupCapacitySummary, 0, len(groups))
for i := range groups {
cap, err := s.getGroupCapacity(ctx, groups[i].ID)
if err != nil {
// Skip groups with errors, return partial results
continue
}
cap.GroupID = groups[i].ID
results = append(results, cap)
}
return results, nil
}
func (s *GroupCapacityService) getGroupCapacity(ctx context.Context, groupID int64) (GroupCapacitySummary, error) {
accounts, err := s.accountRepo.ListSchedulableByGroupID(ctx, groupID)
if err != nil {
return GroupCapacitySummary{}, err
}
if len(accounts) == 0 {
return GroupCapacitySummary{}, nil
}
// Collect account IDs and config values
accountIDs := make([]int64, 0, len(accounts))
sessionTimeouts := make(map[int64]time.Duration)
var concurrencyMax, sessionsMax, rpmMax int
for i := range accounts {
acc := &accounts[i]
accountIDs = append(accountIDs, acc.ID)
concurrencyMax += acc.Concurrency
if ms := acc.GetMaxSessions(); ms > 0 {
sessionsMax += ms
timeout := time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
if timeout <= 0 {
timeout = 5 * time.Minute
}
sessionTimeouts[acc.ID] = timeout
}
if rpm := acc.GetBaseRPM(); rpm > 0 {
rpmMax += rpm
}
}
// Batch query runtime data from Redis
concurrencyMap, _ := s.concurrencyService.GetAccountConcurrencyBatch(ctx, accountIDs)
var sessionsMap map[int64]int
if sessionsMax > 0 && s.sessionLimitCache != nil {
sessionsMap, _ = s.sessionLimitCache.GetActiveSessionCountBatch(ctx, accountIDs, sessionTimeouts)
}
var rpmMap map[int64]int
if rpmMax > 0 && s.rpmCache != nil {
rpmMap, _ = s.rpmCache.GetRPMBatch(ctx, accountIDs)
}
// Aggregate
var concurrencyUsed, sessionsUsed, rpmUsed int
for _, id := range accountIDs {
concurrencyUsed += concurrencyMap[id]
if sessionsMap != nil {
sessionsUsed += sessionsMap[id]
}
if rpmMap != nil {
rpmUsed += rpmMap[id]
}
}
return GroupCapacitySummary{
ConcurrencyUsed: concurrencyUsed,
ConcurrencyMax: concurrencyMax,
SessionsUsed: sessionsUsed,
SessionsMax: sessionsMax,
RPMUsed: rpmUsed,
RPMMax: rpmMax,
}, nil
}
...@@ -27,7 +27,7 @@ type GroupRepository interface { ...@@ -27,7 +27,7 @@ type GroupRepository interface {
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
ExistsByName(ctx context.Context, name string) (bool, error) ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重) // GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
...@@ -202,7 +202,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, ...@@ -202,7 +202,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any,
} }
// 获取账号数量 // 获取账号数量
accountCount, err := s.groupRepo.GetAccountCount(ctx, id) accountCount, _, err := s.groupRepo.GetAccountCount(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account count: %w", err) return nil, fmt.Errorf("get account count: %w", err)
} }
......
...@@ -19,10 +19,6 @@ import ( ...@@ -19,10 +19,6 @@ import (
// 预编译正则表达式(避免每次调用重新编译) // 预编译正则表达式(避免每次调用重新编译)
var ( var (
// 匹配 user_id 格式:
// 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID)
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID)
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)
// 匹配 User-Agent 版本号: xxx/x.y.z // 匹配 User-Agent 版本号: xxx/x.y.z
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`) userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
) )
...@@ -209,12 +205,12 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) { ...@@ -209,12 +205,12 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
} }
// RewriteUserID 重写body中的metadata.user_id // RewriteUserID 重写body中的metadata.user_id
// 输入格式:user_{clientId}_account__session_{sessionUUID} // 支持旧拼接格式和新 JSON 格式的 user_id 解析,
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash} // 根据 fingerprintUA 版本选择输出格式。
// //
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, // 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。 // 避免重新序列化导致 thinking 块等内容被修改。
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) { func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
if len(body) == 0 || accountUUID == "" || cachedClientID == "" { if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
return body, nil return body, nil
} }
...@@ -241,24 +237,21 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI ...@@ -241,24 +237,21 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
return body, nil return body, nil
} }
// 匹配格式: // 解析 user_id(兼容旧拼接格式和新 JSON 格式)
// 旧格式: user_{64位hex}_account__session_{uuid} parsed := ParseMetadataUserID(userID)
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} if parsed == nil {
matches := userIDRegex.FindStringSubmatch(userID)
if matches == nil {
return body, nil return body, nil
} }
// matches[1] = account UUID (可能为空), matches[2] = session UUID sessionTail := parsed.SessionID // 原始session UUID
sessionTail := matches[2] // 原始session UUID
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式 // 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
seed := fmt.Sprintf("%d::%s", accountID, sessionTail) seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
newSessionHash := generateUUIDFromSeed(seed) newSessionHash := generateUUIDFromSeed(seed)
// 构建新的user_id // 根据客户端版本选择输出格式
// 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash} version := ExtractCLIVersion(fingerprintUA)
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash) newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version)
metadata["user_id"] = newUserID metadata["user_id"] = newUserID
...@@ -278,9 +271,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI ...@@ -278,9 +271,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
// //
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, // 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。 // 避免重新序列化导致 thinking 块等内容被修改。
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) { func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
// 先执行常规的 RewriteUserID 逻辑 // 先执行常规的 RewriteUserID 逻辑
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID) newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID, fingerprintUA)
if err != nil { if err != nil {
return newBody, err return newBody, err
} }
...@@ -312,10 +305,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b ...@@ -312,10 +305,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
return newBody, nil return newBody, nil
} }
// 查找 _session_ 的位置,替换其后的内容 // 解析已重写的 user_id
const sessionMarker = "_session_" uidParsed := ParseMetadataUserID(userID)
idx := strings.LastIndex(userID, sessionMarker) if uidParsed == nil {
if idx == -1 {
return newBody, nil return newBody, nil
} }
...@@ -337,8 +329,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b ...@@ -337,8 +329,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err) logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err)
} }
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容 // 用 FormatMetadataUserID 重建(保持与 RewriteUserID 相同的格式)
newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID version := ExtractCLIVersion(fingerprintUA)
newUserID := FormatMetadataUserID(uidParsed.DeviceID, uidParsed.AccountUUID, maskedSessionID, version)
slog.Debug("session_id_masking_applied", slog.Debug("session_id_masking_applied",
"account_id", account.ID, "account_id", account.ID,
......
package service
import (
"encoding/json"
"regexp"
"strings"
)
// NewMetadataFormatMinVersion is the minimum Claude Code version that uses
// JSON-formatted metadata.user_id instead of the legacy concatenated string.
const NewMetadataFormatMinVersion = "2.1.78"
// ParsedUserID represents the components extracted from a metadata.user_id value.
type ParsedUserID struct {
DeviceID string // 64-char hex (or arbitrary client id)
AccountUUID string // may be empty
SessionID string // UUID
IsNewFormat bool // true if the original was JSON format
}
// legacyUserIDRegex matches the legacy user_id format:
//
// user_{64hex}_account_{optional_uuid}_session_{uuid}
var legacyUserIDRegex = regexp.MustCompile(`^user_([a-fA-F0-9]{64})_account_([a-fA-F0-9-]*)_session_([a-fA-F0-9-]{36})$`)
// jsonUserID is the JSON structure for the new metadata.user_id format.
type jsonUserID struct {
DeviceID string `json:"device_id"`
AccountUUID string `json:"account_uuid"`
SessionID string `json:"session_id"`
}
// ParseMetadataUserID parses a metadata.user_id string in either format.
// Returns nil if the input cannot be parsed.
func ParseMetadataUserID(raw string) *ParsedUserID {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
// Try JSON format first (starts with '{')
if raw[0] == '{' {
var j jsonUserID
if err := json.Unmarshal([]byte(raw), &j); err != nil {
return nil
}
if j.DeviceID == "" || j.SessionID == "" {
return nil
}
return &ParsedUserID{
DeviceID: j.DeviceID,
AccountUUID: j.AccountUUID,
SessionID: j.SessionID,
IsNewFormat: true,
}
}
// Try legacy format
matches := legacyUserIDRegex.FindStringSubmatch(raw)
if matches == nil {
return nil
}
return &ParsedUserID{
DeviceID: matches[1],
AccountUUID: matches[2],
SessionID: matches[3],
IsNewFormat: false,
}
}
// FormatMetadataUserID builds a metadata.user_id string in the format
// appropriate for the given CLI version. Components are the rewritten values
// (not necessarily the originals).
func FormatMetadataUserID(deviceID, accountUUID, sessionID, uaVersion string) string {
if IsNewMetadataFormatVersion(uaVersion) {
b, _ := json.Marshal(jsonUserID{
DeviceID: deviceID,
AccountUUID: accountUUID,
SessionID: sessionID,
})
return string(b)
}
// Legacy format
return "user_" + deviceID + "_account_" + accountUUID + "_session_" + sessionID
}
// IsNewMetadataFormatVersion returns true if the given CLI version uses the
// new JSON metadata.user_id format (>= 2.1.78).
func IsNewMetadataFormatVersion(version string) bool {
if version == "" {
return false
}
return CompareVersions(version, NewMetadataFormatMinVersion) >= 0
}
// ExtractCLIVersion extracts the Claude Code version from a User-Agent string.
// Returns "" if the UA doesn't match the expected pattern.
func ExtractCLIVersion(ua string) string {
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
if len(matches) >= 2 {
return matches[1]
}
return ""
}
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// ============ ParseMetadataUserID Tests ============
func TestParseMetadataUserID_LegacyFormat_WithoutAccountUUID(t *testing.T) {
raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_LegacyFormat_WithAccountUUID(t *testing.T) {
raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID)
require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID)
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_JSONFormat_WithoutAccountUUID(t *testing.T) {
raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_JSONFormat_WithAccountUUID(t *testing.T) {
raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID)
require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_InvalidInputs(t *testing.T) {
tests := []struct {
name string
raw string
}{
{"empty string", ""},
{"whitespace only", " "},
{"random text", "not-a-valid-user-id"},
{"partial legacy format", "session_123e4567-e89b-12d3-a456-426614174000"},
{"invalid JSON", `{"device_id":}`},
{"JSON missing device_id", `{"account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`},
{"JSON missing session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":""}`},
{"JSON empty device_id", `{"device_id":"","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`},
{"JSON empty session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":""}`},
{"legacy format short hex", "user_a1b2c3d4_account__session_123e4567-e89b-12d3-a456-426614174000"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Nil(t, ParseMetadataUserID(tt.raw), "should return nil for: %s", tt.raw)
})
}
}
func TestParseMetadataUserID_HexCaseInsensitive(t *testing.T) {
// Legacy format should accept both upper and lower case hex
rawUpper := "user_A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2_account__session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(rawUpper)
require.NotNil(t, parsed, "legacy format should accept uppercase hex")
require.Equal(t, "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2", parsed.DeviceID)
}
// ============ FormatMetadataUserID Tests ============
func TestFormatMetadataUserID_LegacyVersion(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.77")
require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account_acc-uuid_session_sess-uuid", result)
}
func TestFormatMetadataUserID_NewVersion(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.78")
require.Equal(t, `{"device_id":"deadbeef00112233445566778899aabbccddeeff0011223344556677","account_uuid":"acc-uuid","session_id":"sess-uuid"}`, result)
}
func TestFormatMetadataUserID_EmptyVersion_Legacy(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "")
require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account__session_sess-uuid", result)
}
func TestFormatMetadataUserID_EmptyAccountUUID(t *testing.T) {
// Legacy format with empty account UUID → double underscore
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.22")
require.Contains(t, result, "_account__session_")
// New format with empty account UUID → empty string in JSON
result = FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.78")
require.Contains(t, result, `"account_uuid":""`)
}
// ============ IsNewMetadataFormatVersion Tests ============
func TestIsNewMetadataFormatVersion(t *testing.T) {
tests := []struct {
version string
want bool
}{
{"", false},
{"2.1.77", false},
{"2.1.78", true},
{"2.1.79", true},
{"2.2.0", true},
{"3.0.0", true},
{"2.0.100", false},
{"1.9.99", false},
}
for _, tt := range tests {
t.Run(tt.version, func(t *testing.T) {
require.Equal(t, tt.want, IsNewMetadataFormatVersion(tt.version))
})
}
}
// ============ Round-trip Tests ============
func TestParseFormat_RoundTrip_Legacy(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
accountUUID := "550e8400-e29b-41d4-a716-446655440000"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.22")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, accountUUID, parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseFormat_RoundTrip_JSON(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
accountUUID := "550e8400-e29b-41d4-a716-446655440000"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.78")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, accountUUID, parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseFormat_RoundTrip_EmptyAccountUUID(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
// Legacy round-trip with empty account UUID
formatted := FormatMetadataUserID(deviceID, "", sessionID, "2.1.22")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
// JSON round-trip with empty account UUID
formatted = FormatMetadataUserID(deviceID, "", sessionID, "2.1.78")
parsed = ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
}
...@@ -52,8 +52,8 @@ func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([ ...@@ -52,8 +52,8 @@ func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([
func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) { func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) {
return false, nil return false, nil
} }
func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, error) { func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, int64, error) {
return 0, nil return 0, 0, nil
} }
func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil return 0, nil
......
...@@ -40,7 +40,7 @@ func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, err ...@@ -40,7 +40,7 @@ func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, err
func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) { func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, error) { func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }
func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
...@@ -92,7 +92,7 @@ func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscri ...@@ -92,7 +92,7 @@ func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscri
func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) { func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call") panic("unexpected ListByGroupID call")
} }
func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) { func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) {
panic("unexpected List call") panic("unexpected List call")
} }
func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) { func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) {
......
...@@ -634,9 +634,9 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI ...@@ -634,9 +634,9 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
} }
// List 获取所有订阅(分页,支持筛选和排序) // List 获取所有订阅(分页,支持筛选和排序)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) { func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, sortBy, sortOrder) subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, platform, sortBy, sortOrder)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
......
...@@ -18,7 +18,7 @@ type UserSubscriptionRepository interface { ...@@ -18,7 +18,7 @@ type UserSubscriptionRepository interface {
ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error)
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment