Commit 1de18b89 authored by Wang Lvyuan's avatar Wang Lvyuan
Browse files

merge: sync upstream/main before PR

parents 882518c1 9f6ab6b8
...@@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) { ...@@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
} }
} }
status := c.Query("status") status := c.Query("status")
platform := c.Query("platform")
// Parse sorting parameters // Parse sorting parameters
sortBy := c.DefaultQuery("sort_by", "created_at") sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc") sortOrder := c.DefaultQuery("sort_order", "desc")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder) subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -135,14 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { ...@@ -135,14 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
return nil return nil
} }
out := &AdminGroup{ out := &AdminGroup{
Group: groupFromServiceBase(g), Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting, ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled, ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject, MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel, DefaultMappedModel: g.DefaultMappedModel,
SupportedModelScopes: g.SupportedModelScopes, SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount, AccountCount: g.AccountCount,
SortOrder: g.SortOrder, ActiveAccountCount: g.ActiveAccountCount,
RateLimitedAccountCount: g.RateLimitedAccountCount,
SortOrder: g.SortOrder,
} }
if len(g.AccountGroups) > 0 { if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
...@@ -521,6 +523,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { ...@@ -521,6 +523,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
AccountID: l.AccountID, AccountID: l.AccountID,
RequestID: l.RequestID, RequestID: l.RequestID,
Model: l.Model, Model: l.Model,
UpstreamModel: l.UpstreamModel,
ServiceTier: l.ServiceTier, ServiceTier: l.ServiceTier,
ReasoningEffort: l.ReasoningEffort, ReasoningEffort: l.ReasoningEffort,
InboundEndpoint: l.InboundEndpoint, InboundEndpoint: l.InboundEndpoint,
......
...@@ -157,6 +157,12 @@ type ListSoraS3ProfilesResponse struct { ...@@ -157,6 +157,12 @@ type ListSoraS3ProfilesResponse struct {
Items []SoraS3Profile `json:"items"` Items []SoraS3Profile `json:"items"`
} }
// OverloadCooldownSettings 529过载冷却配置 DTO
type OverloadCooldownSettings struct {
Enabled bool `json:"enabled"`
CooldownMinutes int `json:"cooldown_minutes"`
}
// StreamTimeoutSettings 流超时处理配置 DTO // StreamTimeoutSettings 流超时处理配置 DTO
type StreamTimeoutSettings struct { type StreamTimeoutSettings struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
......
...@@ -122,9 +122,11 @@ type AdminGroup struct { ...@@ -122,9 +122,11 @@ type AdminGroup struct {
DefaultMappedModel string `json:"default_mapped_model"` DefaultMappedModel string `json:"default_mapped_model"`
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"` SupportedModelScopes []string `json:"supported_model_scopes"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"` AccountCount int64 `json:"account_count,omitempty"`
ActiveAccountCount int64 `json:"active_account_count,omitempty"`
RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"`
// 分组排序 // 分组排序
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
...@@ -332,6 +334,9 @@ type UsageLog struct { ...@@ -332,6 +334,9 @@ type UsageLog struct {
AccountID int64 `json:"account_id"` AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"` RequestID string `json:"request_id"`
Model string `json:"model"` Model string `json:"model"`
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Omitted when no mapping was applied (requested model was used as-is).
UpstreamModel *string `json:"upstream_model,omitempty"`
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string `json:"service_tier,omitempty"` ServiceTier *string `json:"service_tier,omitempty"`
// ReasoningEffort is the request's reasoning effort level. // ReasoningEffort is the request's reasoning effort level.
......
...@@ -76,7 +76,7 @@ func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service ...@@ -76,7 +76,7 @@ func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service
return nil, nil return nil, nil
} }
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil } func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil } func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { return 0, 0, nil }
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil return 0, nil
} }
......
...@@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte { ...@@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte {
return []byte(`{ return []byte(`{
"model":"claude-3-5-sonnet-20241022", "model":"claude-3-5-sonnet-20241022",
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}], "system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"} "metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}
}`) }`)
} }
...@@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing ...@@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
System: []any{ System: []any{
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
}, },
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123", MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
} }
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。 // body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
...@@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing ...@@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
"system": []any{ "system": []any{
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
}, },
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}, "metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
}) })
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil) SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
......
...@@ -273,8 +273,8 @@ func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin ...@@ -273,8 +273,8 @@ func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) { func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil return false, nil
} }
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, nil return 0, 0, nil
} }
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil return 0, nil
...@@ -348,6 +348,9 @@ func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTi ...@@ -348,6 +348,9 @@ func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTi
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) { func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
return nil, nil return nil, nil
} }
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, nil return nil, nil
} }
......
...@@ -3,6 +3,28 @@ package usagestats ...@@ -3,6 +3,28 @@ package usagestats
import "time" import "time"
const (
ModelSourceRequested = "requested"
ModelSourceUpstream = "upstream"
ModelSourceMapping = "mapping"
)
func IsValidModelSource(source string) bool {
switch source {
case ModelSourceRequested, ModelSourceUpstream, ModelSourceMapping:
return true
default:
return false
}
}
func NormalizeModelSource(source string) string {
if IsValidModelSource(source) {
return source
}
return ModelSourceRequested
}
// DashboardStats 仪表盘统计 // DashboardStats 仪表盘统计
type DashboardStats struct { type DashboardStats struct {
// 用户统计 // 用户统计
...@@ -90,6 +112,13 @@ type EndpointStat struct { ...@@ -90,6 +112,13 @@ type EndpointStat struct {
ActualCost float64 `json:"actual_cost"` // 实际扣除 ActualCost float64 `json:"actual_cost"` // 实际扣除
} }
// GroupUsageSummary represents today's and cumulative cost for a single group.
type GroupUsageSummary struct {
GroupID int64 `json:"group_id"`
TodayCost float64 `json:"today_cost"`
TotalCost float64 `json:"total_cost"`
}
// GroupStat represents usage statistics for a single group // GroupStat represents usage statistics for a single group
type GroupStat struct { type GroupStat struct {
GroupID int64 `json:"group_id"` GroupID int64 `json:"group_id"`
...@@ -143,6 +172,7 @@ type UserBreakdownItem struct { ...@@ -143,6 +172,7 @@ type UserBreakdownItem struct {
type UserBreakdownDimension struct { type UserBreakdownDimension struct {
GroupID int64 // filter by group_id (>0 to enable) GroupID int64 // filter by group_id (>0 to enable)
Model string // filter by model name (non-empty to enable) Model string // filter by model name (non-empty to enable)
ModelType string // "requested", "upstream", or "mapping"
Endpoint string // filter by endpoint value (non-empty to enable) Endpoint string // filter by endpoint value (non-empty to enable)
EndpointType string // "inbound", "upstream", or "path" EndpointType string // "inbound", "upstream", or "path"
} }
......
package usagestats
import "testing"
func TestIsValidModelSource(t *testing.T) {
tests := []struct {
name string
source string
want bool
}{
{name: "requested", source: ModelSourceRequested, want: true},
{name: "upstream", source: ModelSourceUpstream, want: true},
{name: "mapping", source: ModelSourceMapping, want: true},
{name: "invalid", source: "foobar", want: false},
{name: "empty", source: "", want: false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := IsValidModelSource(tc.source); got != tc.want {
t.Fatalf("IsValidModelSource(%q)=%v want %v", tc.source, got, tc.want)
}
})
}
}
func TestNormalizeModelSource(t *testing.T) {
tests := []struct {
name string
source string
want string
}{
{name: "requested", source: ModelSourceRequested, want: ModelSourceRequested},
{name: "upstream", source: ModelSourceUpstream, want: ModelSourceUpstream},
{name: "mapping", source: ModelSourceMapping, want: ModelSourceMapping},
{name: "invalid falls back", source: "foobar", want: ModelSourceRequested},
{name: "empty falls back", source: "", want: ModelSourceRequested},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := NormalizeModelSource(tc.source); got != tc.want {
t.Fatalf("NormalizeModelSource(%q)=%q want %q", tc.source, got, tc.want)
}
})
}
}
...@@ -88,8 +88,9 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group ...@@ -88,8 +88,9 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
if err != nil { if err != nil {
return nil, err return nil, err
} }
count, _ := r.GetAccountCount(ctx, out.ID) total, active, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count out.AccountCount = total
out.ActiveAccountCount = active
return out, nil return out, nil
} }
...@@ -256,7 +257,10 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination ...@@ -256,7 +257,10 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
counts, err := r.loadAccountCounts(ctx, groupIDs) counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil { if err == nil {
for i := range outGroups { for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID] c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
} }
} }
...@@ -283,7 +287,10 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro ...@@ -283,7 +287,10 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro
counts, err := r.loadAccountCounts(ctx, groupIDs) counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil { if err == nil {
for i := range outGroups { for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID] c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
} }
} }
...@@ -310,7 +317,10 @@ func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform str ...@@ -310,7 +317,10 @@ func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform str
counts, err := r.loadAccountCounts(ctx, groupIDs) counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil { if err == nil {
for i := range outGroups { for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID] c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
} }
} }
...@@ -369,12 +379,20 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int ...@@ -369,12 +379,20 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int
return result, nil return result, nil
} }
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) {
var count int64 var rateLimited int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil { err = scanSingleRow(ctx, r.sql,
return 0, err `SELECT COUNT(*),
} COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true),
return count, nil COUNT(*) FILTER (WHERE a.status = 'active' AND (
a.rate_limit_reset_at > NOW() OR
a.overload_until > NOW() OR
a.temp_unschedulable_until > NOW()
))
FROM account_groups ag JOIN accounts a ON a.id = ag.account_id
WHERE ag.group_id = $1`,
[]any{groupID}, &total, &active, &rateLimited)
return
} }
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
...@@ -500,15 +518,32 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, ...@@ -500,15 +518,32 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return affectedUserIDs, nil return affectedUserIDs, nil
} }
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) { type groupAccountCounts struct {
counts = make(map[int64]int64, len(groupIDs)) Total int64
Active int64
RateLimited int64
}
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) {
counts = make(map[int64]groupAccountCounts, len(groupIDs))
if len(groupIDs) == 0 { if len(groupIDs) == 0 {
return counts, nil return counts, nil
} }
rows, err := r.sql.QueryContext( rows, err := r.sql.QueryContext(
ctx, ctx,
"SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id", `SELECT ag.group_id,
COUNT(*) AS total,
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active,
COUNT(*) FILTER (WHERE a.status = 'active' AND (
a.rate_limit_reset_at > NOW() OR
a.overload_until > NOW() OR
a.temp_unschedulable_until > NOW()
)) AS rate_limited
FROM account_groups ag
JOIN accounts a ON a.id = ag.account_id
WHERE ag.group_id = ANY($1)
GROUP BY ag.group_id`,
pq.Array(groupIDs), pq.Array(groupIDs),
) )
if err != nil { if err != nil {
...@@ -523,11 +558,11 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6 ...@@ -523,11 +558,11 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
for rows.Next() { for rows.Next() {
var groupID int64 var groupID int64
var count int64 var c groupAccountCounts
if err = rows.Scan(&groupID, &count); err != nil { if err = rows.Scan(&groupID, &c.Total, &c.Active, &c.RateLimited); err != nil {
return nil, err return nil, err
} }
counts[groupID] = count counts[groupID] = c
} }
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, err return nil, err
......
...@@ -603,7 +603,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() { ...@@ -603,7 +603,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2) _, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
s.Require().NoError(err) s.Require().NoError(err)
count, err := s.repo.GetAccountCount(s.ctx, group.ID) count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err, "GetAccountCount") s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(2), count) s.Require().Equal(int64(2), count)
} }
...@@ -619,7 +619,7 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() { ...@@ -619,7 +619,7 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
} }
s.Require().NoError(s.repo.Create(s.ctx, group)) s.Require().NoError(s.repo.Create(s.ctx, group))
count, err := s.repo.GetAccountCount(s.ctx, group.ID) count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Zero(count) s.Require().Zero(count)
} }
...@@ -651,7 +651,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { ...@@ -651,7 +651,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
s.Require().NoError(err, "DeleteAccountGroupsByGroupID") s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row") s.Require().Equal(int64(1), affected, "expected 1 affected row")
count, err := s.repo.GetAccountCount(s.ctx, g.ID) count, _, err := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().NoError(err, "GetAccountCount") s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(0), count, "expected 0 account groups") s.Require().Equal(int64(0), count, "expected 0 account groups")
} }
...@@ -692,7 +692,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { ...@@ -692,7 +692,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(int64(3), affected) s.Require().Equal(int64(3), affected)
count, _ := s.repo.GetAccountCount(s.ctx, g.ID) count, _, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count) s.Require().Zero(count)
} }
......
...@@ -28,7 +28,7 @@ import ( ...@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
var usageLogInsertArgTypes = [...]string{ var usageLogInsertArgTypes = [...]string{
"bigint", "bigint",
...@@ -36,6 +36,7 @@ var usageLogInsertArgTypes = [...]string{ ...@@ -36,6 +36,7 @@ var usageLogInsertArgTypes = [...]string{
"bigint", "bigint",
"text", "text",
"text", "text",
"text",
"bigint", "bigint",
"bigint", "bigint",
"integer", "integer",
...@@ -277,6 +278,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -277,6 +278,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
...@@ -311,12 +313,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -311,12 +313,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $1, $2, $3, $4, $5, $6,
$6, $7, $7, $8,
$8, $9, $10, $11, $9, $10, $11, $12,
$12, $13, $13, $14,
$14, $15, $16, $17, $18, $19, $15, $16, $17, $18, $19, $20,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -707,6 +709,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -707,6 +709,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
...@@ -742,7 +745,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -742,7 +745,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(keys)*38) args := make([]any, 0, len(keys)*39)
argPos := 1 argPos := 1
for idx, key := range keys { for idx, key := range keys {
if idx > 0 { if idx > 0 {
...@@ -776,6 +779,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -776,6 +779,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
...@@ -816,6 +820,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -816,6 +820,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
...@@ -896,6 +901,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -896,6 +901,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
...@@ -931,7 +937,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -931,7 +937,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(preparedList)*38) args := make([]any, 0, len(preparedList)*39)
argPos := 1 argPos := 1
for idx, prepared := range preparedList { for idx, prepared := range preparedList {
if idx > 0 { if idx > 0 {
...@@ -962,6 +968,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -962,6 +968,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
...@@ -1002,6 +1009,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1002,6 +1009,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
...@@ -1050,6 +1058,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1050,6 +1058,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
...@@ -1084,12 +1093,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1084,12 +1093,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $1, $2, $3, $4, $5, $6,
$6, $7, $7, $8,
$8, $9, $10, $11, $9, $10, $11, $12,
$12, $13, $13, $14,
$14, $15, $16, $17, $18, $19, $15, $16, $17, $18, $19, $20,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...) `, prepared.args...)
...@@ -1121,6 +1130,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1121,6 +1130,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
reasoningEffort := nullString(log.ReasoningEffort) reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint) inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint)
upstreamModel := nullString(log.UpstreamModel)
var requestIDArg any var requestIDArg any
if requestID != "" { if requestID != "" {
...@@ -1138,6 +1148,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1138,6 +1148,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.AccountID, log.AccountID,
requestIDArg, requestIDArg,
log.Model, log.Model,
upstreamModel,
groupID, groupID,
subscriptionID, subscriptionID,
log.InputTokens, log.InputTokens,
...@@ -2864,15 +2875,26 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st ...@@ -2864,15 +2875,26 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
// GetModelStatsWithFilters returns model statistics with optional filters // GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, usagestats.ModelSourceRequested)
}
// GetModelStatsWithFiltersBySource returns model statistics with optional filters and model source dimension.
// source: requested | upstream | mapping.
func (r *usageLogRepository) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) {
return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, source)
}
func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。 // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if accountID > 0 && userID == 0 && apiKeyID == 0 { if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
} }
modelExpr := resolveModelDimensionExpression(source)
query := fmt.Sprintf(` query := fmt.Sprintf(`
SELECT SELECT
model, %s as model,
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens,
...@@ -2883,7 +2905,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start ...@@ -2883,7 +2905,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
%s %s
FROM usage_logs FROM usage_logs
WHERE created_at >= $1 AND created_at < $2 WHERE created_at >= $1 AND created_at < $2
`, actualCostExpr) `, modelExpr, actualCostExpr)
args := []any{startTime, endTime} args := []any{startTime, endTime}
if userID > 0 { if userID > 0 {
...@@ -2907,7 +2929,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start ...@@ -2907,7 +2929,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType)) args = append(args, int16(*billingType))
} }
query += " GROUP BY model ORDER BY total_tokens DESC" query += fmt.Sprintf(" GROUP BY %s ORDER BY total_tokens DESC", modelExpr)
rows, err := r.sql.QueryContext(ctx, query, args...) rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
...@@ -3021,7 +3043,7 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim ...@@ -3021,7 +3043,7 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
args = append(args, dim.GroupID) args = append(args, dim.GroupID)
} }
if dim.Model != "" { if dim.Model != "" {
query += fmt.Sprintf(" AND ul.model = $%d", len(args)+1) query += fmt.Sprintf(" AND %s = $%d", resolveModelDimensionExpression(dim.ModelType), len(args)+1)
args = append(args, dim.Model) args = append(args, dim.Model)
} }
if dim.Endpoint != "" { if dim.Endpoint != "" {
...@@ -3067,6 +3089,53 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim ...@@ -3067,6 +3089,53 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
return results, nil return results, nil
} }
// GetAllGroupUsageSummary returns today's and cumulative actual_cost for every group.
// todayStart is the start-of-day in the caller's timezone (UTC-based).
// TODO(perf): This query scans ALL usage_logs rows for total_cost aggregation.
// When usage_logs exceeds ~1M rows, consider adding a short-lived cache (30s)
// or a materialized view / pre-aggregation table for cumulative costs.
func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
query := `
SELECT
g.id AS group_id,
COALESCE(SUM(ul.actual_cost), 0) AS total_cost,
COALESCE(SUM(CASE WHEN ul.created_at >= $1 THEN ul.actual_cost ELSE 0 END), 0) AS today_cost
FROM groups g
LEFT JOIN usage_logs ul ON ul.group_id = g.id
GROUP BY g.id
`
rows, err := r.sql.QueryContext(ctx, query, todayStart)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var results []usagestats.GroupUsageSummary
for rows.Next() {
var row usagestats.GroupUsageSummary
if err := rows.Scan(&row.GroupID, &row.TotalCost, &row.TodayCost); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}
// resolveModelDimensionExpression maps model source type to a safe SQL expression.
func resolveModelDimensionExpression(modelType string) string {
switch usagestats.NormalizeModelSource(modelType) {
case usagestats.ModelSourceUpstream:
return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"
case usagestats.ModelSourceMapping:
return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"
default:
return "model"
}
}
// resolveEndpointColumn maps endpoint type to the corresponding DB column name. // resolveEndpointColumn maps endpoint type to the corresponding DB column name.
func resolveEndpointColumn(endpointType string) string { func resolveEndpointColumn(endpointType string) string {
switch endpointType { switch endpointType {
...@@ -3819,6 +3888,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -3819,6 +3888,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
accountID int64 accountID int64
requestID sql.NullString requestID sql.NullString
model string model string
upstreamModel sql.NullString
groupID sql.NullInt64 groupID sql.NullInt64
subscriptionID sql.NullInt64 subscriptionID sql.NullInt64
inputTokens int inputTokens int
...@@ -3861,6 +3931,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -3861,6 +3931,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&accountID, &accountID,
&requestID, &requestID,
&model, &model,
&upstreamModel,
&groupID, &groupID,
&subscriptionID, &subscriptionID,
&inputTokens, &inputTokens,
...@@ -3973,6 +4044,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -3973,6 +4044,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if upstreamEndpoint.Valid { if upstreamEndpoint.Valid {
log.UpstreamEndpoint = &upstreamEndpoint.String log.UpstreamEndpoint = &upstreamEndpoint.String
} }
if upstreamModel.Valid {
log.UpstreamModel = &upstreamModel.String
}
return log, nil return log, nil
} }
......
...@@ -5,6 +5,7 @@ package repository ...@@ -5,6 +5,7 @@ package repository
import ( import (
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -16,8 +17,8 @@ func TestResolveEndpointColumn(t *testing.T) { ...@@ -16,8 +17,8 @@ func TestResolveEndpointColumn(t *testing.T) {
{"inbound", "ul.inbound_endpoint"}, {"inbound", "ul.inbound_endpoint"},
{"upstream", "ul.upstream_endpoint"}, {"upstream", "ul.upstream_endpoint"},
{"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"}, {"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"},
{"", "ul.inbound_endpoint"}, // default {"", "ul.inbound_endpoint"}, // default
{"unknown", "ul.inbound_endpoint"}, // fallback {"unknown", "ul.inbound_endpoint"}, // fallback
} }
for _, tc := range tests { for _, tc := range tests {
...@@ -27,3 +28,23 @@ func TestResolveEndpointColumn(t *testing.T) { ...@@ -27,3 +28,23 @@ func TestResolveEndpointColumn(t *testing.T) {
}) })
} }
} }
func TestResolveModelDimensionExpression(t *testing.T) {
tests := []struct {
modelType string
want string
}{
{usagestats.ModelSourceRequested, "model"},
{usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"},
{usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"},
{"", "model"},
{"invalid", "model"},
}
for _, tc := range tests {
t.Run(tc.modelType, func(t *testing.T) {
got := resolveModelDimensionExpression(tc.modelType)
require.Equal(t, tc.want, got)
})
}
}
...@@ -44,6 +44,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { ...@@ -44,6 +44,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.AccountID, log.AccountID,
log.RequestID, log.RequestID,
log.Model, log.Model,
sqlmock.AnyArg(), // upstream_model
sqlmock.AnyArg(), // group_id sqlmock.AnyArg(), // group_id
sqlmock.AnyArg(), // subscription_id sqlmock.AnyArg(), // subscription_id
log.InputTokens, log.InputTokens,
...@@ -116,6 +117,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { ...@@ -116,6 +117,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.Model, log.Model,
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.InputTokens, log.InputTokens,
log.OutputTokens, log.OutputTokens,
log.CacheCreationTokens, log.CacheCreationTokens,
...@@ -353,6 +355,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -353,6 +355,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(30), // account_id int64(30), // account_id
sql.NullString{Valid: true, String: "req-1"}, sql.NullString{Valid: true, String: "req-1"},
"gpt-5", // model "gpt-5", // model
sql.NullString{}, // upstream_model
sql.NullInt64{}, // group_id sql.NullInt64{}, // group_id
sql.NullInt64{}, // subscription_id sql.NullInt64{}, // subscription_id
1, // input_tokens 1, // input_tokens
...@@ -404,6 +407,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -404,6 +407,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(31), int64(31),
sql.NullString{Valid: true, String: "req-2"}, sql.NullString{Valid: true, String: "req-2"},
"gpt-5", "gpt-5",
sql.NullString{},
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
...@@ -445,6 +449,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -445,6 +449,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(32), int64(32),
sql.NullString{Valid: true, String: "req-3"}, sql.NullString{Valid: true, String: "req-3"},
"gpt-5.4", "gpt-5.4",
sql.NullString{},
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -190,7 +191,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID ...@@ -190,7 +191,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
} }
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query() q := client.UserSubscription.Query()
if userID != nil { if userID != nil {
...@@ -199,6 +200,9 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination ...@@ -199,6 +200,9 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if groupID != nil { if groupID != nil {
q = q.Where(usersubscription.GroupIDEQ(*groupID)) q = q.Where(usersubscription.GroupIDEQ(*groupID))
} }
if platform != "" {
q = q.Where(usersubscription.HasGroupWith(group.PlatformEQ(platform)))
}
// Status filtering with real-time expiration check // Status filtering with real-time expiration check
now := time.Now() now := time.Now()
......
...@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { ...@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group := s.mustCreateGroup("g-list") group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil) s.mustCreateSubscription(user.ID, group.ID, nil)
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "") subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "", "")
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
...@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { ...@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s.mustCreateSubscription(user1.ID, group.ID, nil) s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil) s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "") subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID) s.Require().Equal(user1.ID, subs[0].UserID)
...@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { ...@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s.mustCreateSubscription(user.ID, g1.ID, nil) s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil) s.mustCreateSubscription(user.ID, g2.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "") subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID) s.Require().Equal(g1.ID, subs[0].GroupID)
...@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() { ...@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour)) c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
}) })
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "") subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status) s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
......
...@@ -924,8 +924,8 @@ func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error ...@@ -924,8 +924,8 @@ func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error
return false, errors.New("not implemented") return false, errors.New("not implemented")
} }
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, errors.New("not implemented") return 0, 0, errors.New("not implemented")
} }
func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
...@@ -1289,7 +1289,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI ...@@ -1289,7 +1289,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
...@@ -1786,6 +1786,9 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i ...@@ -1786,6 +1786,9 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) { func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
return nil, errors.New("not implemented")
}
type stubSettingRepo struct { type stubSettingRepo struct {
all map[string]string all map[string]string
......
...@@ -135,7 +135,7 @@ func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, user ...@@ -135,7 +135,7 @@ func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, user
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
......
...@@ -646,7 +646,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in ...@@ -646,7 +646,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
......
...@@ -227,6 +227,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -227,6 +227,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{ {
groups.GET("", h.Admin.Group.List) groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll) groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/usage-summary", h.Admin.Group.GetUsageSummary)
groups.GET("/capacity-summary", h.Admin.Group.GetCapacitySummary)
groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder) groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder)
groups.GET("/:id", h.Admin.Group.GetByID) groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create) groups.POST("", h.Admin.Group.Create)
...@@ -400,6 +402,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -400,6 +402,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey) adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey) adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey) adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
// 529过载冷却配置
adminSettings.GET("/overload-cooldown", h.Admin.Setting.GetOverloadCooldownSettings)
adminSettings.PUT("/overload-cooldown", h.Admin.Setting.UpdateOverloadCooldownSettings)
// 流超时处理配置 // 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings) adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings) adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment