Commit cfaac12a authored by Ethan0x0000's avatar Ethan0x0000
Browse files

Merge upstream/main into pr/upstream-model-tracking

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent

)
Co-authored-by: default avatarSisyphus <clio-agent@sisyphuslabs.ai>
parents bd9d2671 21f349c0
......@@ -49,6 +49,7 @@ type UsageLogRepository interface {
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error)
GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
......
......@@ -194,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
......
......@@ -160,7 +160,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er
panic("unexpected ExistsByName call")
}
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}
......
......@@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool,
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) {
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}
......@@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}
......@@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context,
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}
......
......@@ -21,9 +21,6 @@ var (
// 带捕获组的版本提取正则
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
systemPromptThreshold = 0.5
)
......@@ -124,7 +121,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
return false
}
if !userIDPattern.MatchString(userID) {
if ParseMetadataUserID(userID) == nil {
return false
}
......@@ -278,11 +275,7 @@ func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context
// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
if len(matches) >= 2 {
return matches[1]
}
return ""
return ExtractCLIVersion(ua)
}
// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中
......
......@@ -169,6 +169,15 @@ func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
// GetGroupUsageSummary returns today's and cumulative cost for all groups.
func (s *DashboardService) GetGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
results, err := s.usageRepo.GetAllGroupUsageSummary(ctx, todayStart)
if err != nil {
return nil, fmt.Errorf("get group usage summary: %w", err)
}
return results, nil
}
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
data, err := s.cache.GetDashboardStats(ctx)
if err != nil {
......
......@@ -278,8 +278,8 @@ func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, plat
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, nil
}
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
......
......@@ -326,7 +326,6 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
......@@ -645,8 +644,8 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
if parsed.MetadataUserID != "" {
if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 {
return match[1]
if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" {
return uid.SessionID
}
}
......@@ -1027,13 +1026,13 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
sessionID = generateSessionUUID(seed)
}
// Prefer the newer format that includes account_uuid (if present),
// otherwise fall back to the legacy Claude Code format.
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
if accountUUID != "" {
return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID)
// 根据指纹 UA 版本选择输出格式
var uaVersion string
if fp != nil {
uaVersion = ExtractCLIVersion(fp.UserAgent)
}
return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID)
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
}
// GenerateSessionUUID creates a deterministic UUID4 from a seed string.
......@@ -5567,7 +5566,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
body = newBody
}
}
......@@ -8197,7 +8196,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if err == nil {
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
body = newBody
}
}
......
......@@ -230,8 +230,8 @@ func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platf
func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, nil
}
func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
......
......@@ -24,7 +24,7 @@ func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
......@@ -196,7 +196,7 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
......@@ -212,6 +212,22 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
"metadata session_id should take priority over SessionContext")
}
func TestGenerateSessionHash_MetadataJSON_HasHighestPriority(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
MetadataUserID: `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`,
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
hash := svc.GenerateSessionHash(parsed)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", hash, "JSON format metadata session_id should have highest priority")
}
func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) {
svc := &GatewayService{}
......
......@@ -64,8 +64,10 @@ type Group struct {
CreatedAt time.Time
UpdatedAt time.Time
AccountGroups []AccountGroup
AccountCount int64
AccountGroups []AccountGroup
AccountCount int64
ActiveAccountCount int64
RateLimitedAccountCount int64
}
func (g *Group) IsActive() bool {
......
package service
import (
"context"
"time"
)
// GroupCapacitySummary holds aggregated capacity for a single group.
type GroupCapacitySummary struct {
GroupID int64 `json:"group_id"`
ConcurrencyUsed int `json:"concurrency_used"`
ConcurrencyMax int `json:"concurrency_max"`
SessionsUsed int `json:"sessions_used"`
SessionsMax int `json:"sessions_max"`
RPMUsed int `json:"rpm_used"`
RPMMax int `json:"rpm_max"`
}
// GroupCapacityService aggregates per-group capacity from runtime data.
type GroupCapacityService struct {
accountRepo AccountRepository
groupRepo GroupRepository
concurrencyService *ConcurrencyService
sessionLimitCache SessionLimitCache
rpmCache RPMCache
}
// NewGroupCapacityService creates a new GroupCapacityService.
func NewGroupCapacityService(
accountRepo AccountRepository,
groupRepo GroupRepository,
concurrencyService *ConcurrencyService,
sessionLimitCache SessionLimitCache,
rpmCache RPMCache,
) *GroupCapacityService {
return &GroupCapacityService{
accountRepo: accountRepo,
groupRepo: groupRepo,
concurrencyService: concurrencyService,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
}
}
// GetAllGroupCapacity returns capacity summary for all active groups.
func (s *GroupCapacityService) GetAllGroupCapacity(ctx context.Context) ([]GroupCapacitySummary, error) {
groups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, err
}
results := make([]GroupCapacitySummary, 0, len(groups))
for i := range groups {
cap, err := s.getGroupCapacity(ctx, groups[i].ID)
if err != nil {
// Skip groups with errors, return partial results
continue
}
cap.GroupID = groups[i].ID
results = append(results, cap)
}
return results, nil
}
func (s *GroupCapacityService) getGroupCapacity(ctx context.Context, groupID int64) (GroupCapacitySummary, error) {
accounts, err := s.accountRepo.ListSchedulableByGroupID(ctx, groupID)
if err != nil {
return GroupCapacitySummary{}, err
}
if len(accounts) == 0 {
return GroupCapacitySummary{}, nil
}
// Collect account IDs and config values
accountIDs := make([]int64, 0, len(accounts))
sessionTimeouts := make(map[int64]time.Duration)
var concurrencyMax, sessionsMax, rpmMax int
for i := range accounts {
acc := &accounts[i]
accountIDs = append(accountIDs, acc.ID)
concurrencyMax += acc.Concurrency
if ms := acc.GetMaxSessions(); ms > 0 {
sessionsMax += ms
timeout := time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
if timeout <= 0 {
timeout = 5 * time.Minute
}
sessionTimeouts[acc.ID] = timeout
}
if rpm := acc.GetBaseRPM(); rpm > 0 {
rpmMax += rpm
}
}
// Batch query runtime data from Redis
concurrencyMap, _ := s.concurrencyService.GetAccountConcurrencyBatch(ctx, accountIDs)
var sessionsMap map[int64]int
if sessionsMax > 0 && s.sessionLimitCache != nil {
sessionsMap, _ = s.sessionLimitCache.GetActiveSessionCountBatch(ctx, accountIDs, sessionTimeouts)
}
var rpmMap map[int64]int
if rpmMax > 0 && s.rpmCache != nil {
rpmMap, _ = s.rpmCache.GetRPMBatch(ctx, accountIDs)
}
// Aggregate
var concurrencyUsed, sessionsUsed, rpmUsed int
for _, id := range accountIDs {
concurrencyUsed += concurrencyMap[id]
if sessionsMap != nil {
sessionsUsed += sessionsMap[id]
}
if rpmMap != nil {
rpmUsed += rpmMap[id]
}
}
return GroupCapacitySummary{
ConcurrencyUsed: concurrencyUsed,
ConcurrencyMax: concurrencyMax,
SessionsUsed: sessionsUsed,
SessionsMax: sessionsMax,
RPMUsed: rpmUsed,
RPMMax: rpmMax,
}, nil
}
......@@ -27,7 +27,7 @@ type GroupRepository interface {
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
......@@ -202,7 +202,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any,
}
// 获取账号数量
accountCount, err := s.groupRepo.GetAccountCount(ctx, id)
accountCount, _, err := s.groupRepo.GetAccountCount(ctx, id)
if err != nil {
return nil, fmt.Errorf("get account count: %w", err)
}
......
......@@ -19,10 +19,6 @@ import (
// 预编译正则表达式(避免每次调用重新编译)
var (
// 匹配 user_id 格式:
// 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID)
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID)
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)
// 匹配 User-Agent 版本号: xxx/x.y.z
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
)
......@@ -209,12 +205,12 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
}
// RewriteUserID 重写body中的metadata.user_id
// 输入格式:user_{clientId}_account__session_{sessionUUID}
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash}
// 支持旧拼接格式和新 JSON 格式的 user_id 解析,
// 根据 fingerprintUA 版本选择输出格式。
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) {
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
return body, nil
}
......@@ -241,24 +237,21 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
return body, nil
}
// 匹配格式:
// 旧格式: user_{64位hex}_account__session_{uuid}
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid}
matches := userIDRegex.FindStringSubmatch(userID)
if matches == nil {
// 解析 user_id(兼容旧拼接格式和新 JSON 格式)
parsed := ParseMetadataUserID(userID)
if parsed == nil {
return body, nil
}
// matches[1] = account UUID (可能为空), matches[2] = session UUID
sessionTail := matches[2] // 原始session UUID
sessionTail := parsed.SessionID // 原始session UUID
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
newSessionHash := generateUUIDFromSeed(seed)
// 构建新的user_id
// 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash}
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash)
// 根据客户端版本选择输出格式
version := ExtractCLIVersion(fingerprintUA)
newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version)
metadata["user_id"] = newUserID
......@@ -278,9 +271,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
// 先执行常规的 RewriteUserID 逻辑
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID, fingerprintUA)
if err != nil {
return newBody, err
}
......@@ -312,10 +305,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
return newBody, nil
}
// 查找 _session_ 的位置,替换其后的内容
const sessionMarker = "_session_"
idx := strings.LastIndex(userID, sessionMarker)
if idx == -1 {
// 解析已重写的 user_id
uidParsed := ParseMetadataUserID(userID)
if uidParsed == nil {
return newBody, nil
}
......@@ -337,8 +329,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err)
}
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID
// 用 FormatMetadataUserID 重建(保持与 RewriteUserID 相同的格式)
version := ExtractCLIVersion(fingerprintUA)
newUserID := FormatMetadataUserID(uidParsed.DeviceID, uidParsed.AccountUUID, maskedSessionID, version)
slog.Debug("session_id_masking_applied",
"account_id", account.ID,
......
package service
import (
"encoding/json"
"regexp"
"strings"
)
// NewMetadataFormatMinVersion is the minimum Claude Code version that uses
// JSON-formatted metadata.user_id instead of the legacy concatenated string.
const NewMetadataFormatMinVersion = "2.1.78"
// ParsedUserID represents the components extracted from a metadata.user_id value.
type ParsedUserID struct {
DeviceID string // 64-char hex (or arbitrary client id)
AccountUUID string // may be empty
SessionID string // UUID
IsNewFormat bool // true if the original was JSON format
}
// legacyUserIDRegex matches the legacy user_id format:
//
// user_{64hex}_account_{optional_uuid}_session_{uuid}
var legacyUserIDRegex = regexp.MustCompile(`^user_([a-fA-F0-9]{64})_account_([a-fA-F0-9-]*)_session_([a-fA-F0-9-]{36})$`)
// jsonUserID is the JSON structure for the new metadata.user_id format.
type jsonUserID struct {
DeviceID string `json:"device_id"`
AccountUUID string `json:"account_uuid"`
SessionID string `json:"session_id"`
}
// ParseMetadataUserID parses a metadata.user_id string in either format.
// Returns nil if the input cannot be parsed.
func ParseMetadataUserID(raw string) *ParsedUserID {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
// Try JSON format first (starts with '{')
if raw[0] == '{' {
var j jsonUserID
if err := json.Unmarshal([]byte(raw), &j); err != nil {
return nil
}
if j.DeviceID == "" || j.SessionID == "" {
return nil
}
return &ParsedUserID{
DeviceID: j.DeviceID,
AccountUUID: j.AccountUUID,
SessionID: j.SessionID,
IsNewFormat: true,
}
}
// Try legacy format
matches := legacyUserIDRegex.FindStringSubmatch(raw)
if matches == nil {
return nil
}
return &ParsedUserID{
DeviceID: matches[1],
AccountUUID: matches[2],
SessionID: matches[3],
IsNewFormat: false,
}
}
// FormatMetadataUserID builds a metadata.user_id string in the format
// appropriate for the given CLI version. Components are the rewritten values
// (not necessarily the originals).
func FormatMetadataUserID(deviceID, accountUUID, sessionID, uaVersion string) string {
if IsNewMetadataFormatVersion(uaVersion) {
b, _ := json.Marshal(jsonUserID{
DeviceID: deviceID,
AccountUUID: accountUUID,
SessionID: sessionID,
})
return string(b)
}
// Legacy format
return "user_" + deviceID + "_account_" + accountUUID + "_session_" + sessionID
}
// IsNewMetadataFormatVersion returns true if the given CLI version uses the
// new JSON metadata.user_id format (>= 2.1.78).
func IsNewMetadataFormatVersion(version string) bool {
if version == "" {
return false
}
return CompareVersions(version, NewMetadataFormatMinVersion) >= 0
}
// ExtractCLIVersion extracts the Claude Code version from a User-Agent string.
// Returns "" if the UA doesn't match the expected pattern.
func ExtractCLIVersion(ua string) string {
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
if len(matches) >= 2 {
return matches[1]
}
return ""
}
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// ============ ParseMetadataUserID Tests ============
func TestParseMetadataUserID_LegacyFormat_WithoutAccountUUID(t *testing.T) {
raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_LegacyFormat_WithAccountUUID(t *testing.T) {
raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID)
require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID)
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_JSONFormat_WithoutAccountUUID(t *testing.T) {
raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_JSONFormat_WithAccountUUID(t *testing.T) {
raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID)
require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_InvalidInputs(t *testing.T) {
tests := []struct {
name string
raw string
}{
{"empty string", ""},
{"whitespace only", " "},
{"random text", "not-a-valid-user-id"},
{"partial legacy format", "session_123e4567-e89b-12d3-a456-426614174000"},
{"invalid JSON", `{"device_id":}`},
{"JSON missing device_id", `{"account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`},
{"JSON missing session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":""}`},
{"JSON empty device_id", `{"device_id":"","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`},
{"JSON empty session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":""}`},
{"legacy format short hex", "user_a1b2c3d4_account__session_123e4567-e89b-12d3-a456-426614174000"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Nil(t, ParseMetadataUserID(tt.raw), "should return nil for: %s", tt.raw)
})
}
}
func TestParseMetadataUserID_HexCaseInsensitive(t *testing.T) {
// Legacy format should accept both upper and lower case hex
rawUpper := "user_A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2_account__session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(rawUpper)
require.NotNil(t, parsed, "legacy format should accept uppercase hex")
require.Equal(t, "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2", parsed.DeviceID)
}
// ============ FormatMetadataUserID Tests ============
func TestFormatMetadataUserID_LegacyVersion(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.77")
require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account_acc-uuid_session_sess-uuid", result)
}
func TestFormatMetadataUserID_NewVersion(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.78")
require.Equal(t, `{"device_id":"deadbeef00112233445566778899aabbccddeeff0011223344556677","account_uuid":"acc-uuid","session_id":"sess-uuid"}`, result)
}
func TestFormatMetadataUserID_EmptyVersion_Legacy(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "")
require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account__session_sess-uuid", result)
}
func TestFormatMetadataUserID_EmptyAccountUUID(t *testing.T) {
// Legacy format with empty account UUID → double underscore
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.22")
require.Contains(t, result, "_account__session_")
// New format with empty account UUID → empty string in JSON
result = FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.78")
require.Contains(t, result, `"account_uuid":""`)
}
// ============ IsNewMetadataFormatVersion Tests ============
func TestIsNewMetadataFormatVersion(t *testing.T) {
tests := []struct {
version string
want bool
}{
{"", false},
{"2.1.77", false},
{"2.1.78", true},
{"2.1.79", true},
{"2.2.0", true},
{"3.0.0", true},
{"2.0.100", false},
{"1.9.99", false},
}
for _, tt := range tests {
t.Run(tt.version, func(t *testing.T) {
require.Equal(t, tt.want, IsNewMetadataFormatVersion(tt.version))
})
}
}
// ============ Round-trip Tests ============
func TestParseFormat_RoundTrip_Legacy(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
accountUUID := "550e8400-e29b-41d4-a716-446655440000"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.22")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, accountUUID, parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseFormat_RoundTrip_JSON(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
accountUUID := "550e8400-e29b-41d4-a716-446655440000"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.78")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, accountUUID, parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseFormat_RoundTrip_EmptyAccountUUID(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
// Legacy round-trip with empty account UUID
formatted := FormatMetadataUserID(deviceID, "", sessionID, "2.1.22")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
// JSON round-trip with empty account UUID
formatted = FormatMetadataUserID(deviceID, "", sessionID, "2.1.78")
parsed = ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
}
......@@ -52,8 +52,8 @@ func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([
func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) {
return false, nil
}
func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, error) {
return 0, nil
func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, int64, error) {
return 0, 0, nil
}
func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil
......
......@@ -40,7 +40,7 @@ func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, err
func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, error) {
func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}
func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
......@@ -92,7 +92,7 @@ func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscri
func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) {
func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) {
......
......@@ -634,9 +634,9 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
}
// List 获取所有订阅(分页,支持筛选和排序)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) {
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, sortBy, sortOrder)
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, platform, sortBy, sortOrder)
if err != nil {
return nil, nil, err
}
......
......@@ -18,7 +18,7 @@ type UserSubscriptionRepository interface {
ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error)
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment