Commit 2a5ef6d3 authored by shaw's avatar shaw
Browse files

Merge PR #279: feat(计费): 账号计费倍率快照与账号口径费用统计

parents 99cbfa15 ec24a3c3
...@@ -80,6 +80,10 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account ...@@ -80,6 +80,10 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
SetSchedulable(account.Schedulable). SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired) SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.ProxyID != nil { if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID) builder.SetProxyID(*account.ProxyID)
} }
...@@ -291,6 +295,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account ...@@ -291,6 +295,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
SetSchedulable(account.Schedulable). SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired) SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.ProxyID != nil { if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID) builder.SetProxyID(*account.ProxyID)
} else { } else {
...@@ -999,6 +1007,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ...@@ -999,6 +1007,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args = append(args, *updates.Priority) args = append(args, *updates.Priority)
idx++ idx++
} }
if updates.RateMultiplier != nil {
setClauses = append(setClauses, "rate_multiplier = $"+itoa(idx))
args = append(args, *updates.RateMultiplier)
idx++
}
if updates.Status != nil { if updates.Status != nil {
setClauses = append(setClauses, "status = $"+itoa(idx)) setClauses = append(setClauses, "status = $"+itoa(idx))
args = append(args, *updates.Status) args = append(args, *updates.Status)
...@@ -1347,6 +1360,8 @@ func accountEntityToService(m *dbent.Account) *service.Account { ...@@ -1347,6 +1360,8 @@ func accountEntityToService(m *dbent.Account) *service.Account {
return nil return nil
} }
rateMultiplier := m.RateMultiplier
return &service.Account{ return &service.Account{
ID: m.ID, ID: m.ID,
Name: m.Name, Name: m.Name,
...@@ -1358,6 +1373,7 @@ func accountEntityToService(m *dbent.Account) *service.Account { ...@@ -1358,6 +1373,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
ProxyID: m.ProxyID, ProxyID: m.ProxyID,
Concurrency: m.Concurrency, Concurrency: m.Concurrency,
Priority: m.Priority, Priority: m.Priority,
RateMultiplier: &rateMultiplier,
Status: m.Status, Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage), ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt, LastUsedAt: m.LastUsedAt,
......
...@@ -71,7 +71,9 @@ usage_agg AS ( ...@@ -71,7 +71,9 @@ usage_agg AS (
error_base AS ( error_base AS (
SELECT SELECT
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start, date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
platform AS platform, -- platform is NULL for some early-phase errors (e.g. before routing); map to a sentinel
-- value so platform-level GROUPING SETS don't collide with the overall (platform=NULL) row.
COALESCE(platform, 'unknown') AS platform,
group_id AS group_id, group_id AS group_id,
is_business_limited AS is_business_limited, is_business_limited AS is_business_limited,
error_owner AS error_owner, error_owner AS error_owner,
......
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
type usageLogRepository struct { type usageLogRepository struct {
client *dbent.Client client *dbent.Client
...@@ -105,6 +105,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -105,6 +105,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
total_cost, total_cost,
actual_cost, actual_cost,
rate_multiplier, rate_multiplier,
account_rate_multiplier,
billing_type, billing_type,
stream, stream,
duration_ms, duration_ms,
...@@ -120,7 +121,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -120,7 +121,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11, $8, $9, $10, $11,
$12, $13, $12, $13,
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29 $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -160,6 +161,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ...@@ -160,6 +161,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.TotalCost, log.TotalCost,
log.ActualCost, log.ActualCost,
rateMultiplier, rateMultiplier,
log.AccountRateMultiplier,
log.BillingType, log.BillingType,
log.Stream, log.Stream,
duration, duration,
...@@ -835,7 +837,9 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID ...@@ -835,7 +837,9 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT SELECT
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 WHERE account_id = $1 AND created_at >= $2
` `
...@@ -849,6 +853,8 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID ...@@ -849,6 +853,8 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
&stats.Requests, &stats.Requests,
&stats.Tokens, &stats.Tokens,
&stats.Cost, &stats.Cost,
&stats.StandardCost,
&stats.UserCost,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
...@@ -861,7 +867,9 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI ...@@ -861,7 +867,9 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT SELECT
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 WHERE account_id = $1 AND created_at >= $2
` `
...@@ -875,6 +883,8 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI ...@@ -875,6 +883,8 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
&stats.Requests, &stats.Requests,
&stats.Tokens, &stats.Tokens,
&stats.Cost, &stats.Cost,
&stats.StandardCost,
&stats.UserCost,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
...@@ -1454,7 +1464,13 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start ...@@ -1454,7 +1464,13 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters // GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) { func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
query := ` actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
SELECT SELECT
model, model,
COUNT(*) as requests, COUNT(*) as requests,
...@@ -1462,10 +1478,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start ...@@ -1462,10 +1478,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
COALESCE(SUM(output_tokens), 0) as output_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost, COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost %s
FROM usage_logs FROM usage_logs
WHERE created_at >= $1 AND created_at < $2 WHERE created_at >= $1 AND created_at < $2
` `, actualCostExpr)
args := []any{startTime, endTime} args := []any{startTime, endTime}
if userID > 0 { if userID > 0 {
...@@ -1587,12 +1603,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us ...@@ -1587,12 +1603,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost, COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost, COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs FROM usage_logs
%s %s
`, buildWhere(conditions)) `, buildWhere(conditions))
stats := &UsageStats{} stats := &UsageStats{}
var totalAccountCost float64
if err := scanSingleRow( if err := scanSingleRow(
ctx, ctx,
r.sql, r.sql,
...@@ -1604,10 +1622,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us ...@@ -1604,10 +1622,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
&stats.TotalCacheTokens, &stats.TotalCacheTokens,
&stats.TotalCost, &stats.TotalCost,
&stats.TotalActualCost, &stats.TotalActualCost,
&totalAccountCost,
&stats.AverageDurationMs, &stats.AverageDurationMs,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
if filters.AccountID > 0 {
stats.TotalAccountCost = &totalAccountCost
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return stats, nil return stats, nil
} }
...@@ -1634,7 +1656,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1634,7 +1656,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost, COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY date GROUP BY date
...@@ -1661,7 +1684,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1661,7 +1684,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var tokens int64 var tokens int64
var cost float64 var cost float64
var actualCost float64 var actualCost float64
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil { var userCost float64
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost, &userCost); err != nil {
return nil, err return nil, err
} }
t, _ := time.Parse("2006-01-02", date) t, _ := time.Parse("2006-01-02", date)
...@@ -1672,19 +1696,21 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1672,19 +1696,21 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Tokens: tokens, Tokens: tokens,
Cost: cost, Cost: cost,
ActualCost: actualCost, ActualCost: actualCost,
UserCost: userCost,
}) })
} }
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, err return nil, err
} }
var totalActualCost, totalStandardCost float64 var totalAccountCost, totalUserCost, totalStandardCost float64
var totalRequests, totalTokens int64 var totalRequests, totalTokens int64
var highestCostDay, highestRequestDay *AccountUsageHistory var highestCostDay, highestRequestDay *AccountUsageHistory
for i := range history { for i := range history {
h := &history[i] h := &history[i]
totalActualCost += h.ActualCost totalAccountCost += h.ActualCost
totalUserCost += h.UserCost
totalStandardCost += h.Cost totalStandardCost += h.Cost
totalRequests += h.Requests totalRequests += h.Requests
totalTokens += h.Tokens totalTokens += h.Tokens
...@@ -1711,11 +1737,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1711,11 +1737,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary := AccountUsageSummary{ summary := AccountUsageSummary{
Days: daysCount, Days: daysCount,
ActualDaysUsed: actualDaysUsed, ActualDaysUsed: actualDaysUsed,
TotalCost: totalActualCost, TotalCost: totalAccountCost,
TotalUserCost: totalUserCost,
TotalStandardCost: totalStandardCost, TotalStandardCost: totalStandardCost,
TotalRequests: totalRequests, TotalRequests: totalRequests,
TotalTokens: totalTokens, TotalTokens: totalTokens,
AvgDailyCost: totalActualCost / float64(actualDaysUsed), AvgDailyCost: totalAccountCost / float64(actualDaysUsed),
AvgDailyUserCost: totalUserCost / float64(actualDaysUsed),
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed), AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed), AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
AvgDurationMs: avgDuration, AvgDurationMs: avgDuration,
...@@ -1727,11 +1755,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1727,11 +1755,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary.Today = &struct { summary.Today = &struct {
Date string `json:"date"` Date string `json:"date"`
Cost float64 `json:"cost"` Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"` Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"` Tokens int64 `json:"tokens"`
}{ }{
Date: history[i].Date, Date: history[i].Date,
Cost: history[i].ActualCost, Cost: history[i].ActualCost,
UserCost: history[i].UserCost,
Requests: history[i].Requests, Requests: history[i].Requests,
Tokens: history[i].Tokens, Tokens: history[i].Tokens,
} }
...@@ -1744,11 +1774,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1744,11 +1774,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Date string `json:"date"` Date string `json:"date"`
Label string `json:"label"` Label string `json:"label"`
Cost float64 `json:"cost"` Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"` Requests int64 `json:"requests"`
}{ }{
Date: highestCostDay.Date, Date: highestCostDay.Date,
Label: highestCostDay.Label, Label: highestCostDay.Label,
Cost: highestCostDay.ActualCost, Cost: highestCostDay.ActualCost,
UserCost: highestCostDay.UserCost,
Requests: highestCostDay.Requests, Requests: highestCostDay.Requests,
} }
} }
...@@ -1759,11 +1791,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -1759,11 +1791,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Label string `json:"label"` Label string `json:"label"`
Requests int64 `json:"requests"` Requests int64 `json:"requests"`
Cost float64 `json:"cost"` Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
}{ }{
Date: highestRequestDay.Date, Date: highestRequestDay.Date,
Label: highestRequestDay.Label, Label: highestRequestDay.Label,
Requests: highestRequestDay.Requests, Requests: highestRequestDay.Requests,
Cost: highestRequestDay.ActualCost, Cost: highestRequestDay.ActualCost,
UserCost: highestRequestDay.UserCost,
} }
} }
...@@ -2015,6 +2049,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2015,6 +2049,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
totalCost float64 totalCost float64
actualCost float64 actualCost float64
rateMultiplier float64 rateMultiplier float64
accountRateMultiplier sql.NullFloat64
billingType int16 billingType int16
stream bool stream bool
durationMs sql.NullInt64 durationMs sql.NullInt64
...@@ -2048,6 +2083,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2048,6 +2083,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&totalCost, &totalCost,
&actualCost, &actualCost,
&rateMultiplier, &rateMultiplier,
&accountRateMultiplier,
&billingType, &billingType,
&stream, &stream,
&durationMs, &durationMs,
...@@ -2080,6 +2116,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -2080,6 +2116,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
TotalCost: totalCost, TotalCost: totalCost,
ActualCost: actualCost, ActualCost: actualCost,
RateMultiplier: rateMultiplier, RateMultiplier: rateMultiplier,
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
BillingType: int8(billingType), BillingType: int8(billingType),
Stream: stream, Stream: stream,
ImageCount: imageCount, ImageCount: imageCount,
...@@ -2186,6 +2223,14 @@ func nullInt(v *int) sql.NullInt64 { ...@@ -2186,6 +2223,14 @@ func nullInt(v *int) sql.NullInt64 {
return sql.NullInt64{Int64: int64(*v), Valid: true} return sql.NullInt64{Int64: int64(*v), Valid: true}
} }
func nullFloat64Ptr(v sql.NullFloat64) *float64 {
if !v.Valid {
return nil
}
out := v.Float64
return &out
}
func nullString(v *string) sql.NullString { func nullString(v *string) sql.NullString {
if v == nil || *v == "" { if v == nil || *v == "" {
return sql.NullString{} return sql.NullString{}
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
...@@ -95,6 +96,34 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() { ...@@ -95,6 +96,34 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
s.Require().Error(err, "expected error for non-existent ID") s.Require().Error(err, "expected error for non-existent ID")
} }
func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-mult@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-mult", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-mult"})
m := 0.5
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 2.0,
AccountRateMultiplier: &m,
CreatedAt: timezone.Today().Add(2 * time.Hour),
}
_, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err)
got, err := s.repo.GetByID(s.ctx, log.ID)
s.Require().NoError(err)
s.Require().NotNil(got.AccountRateMultiplier)
s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
}
// --- Delete --- // --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() { func (s *UsageLogRepoSuite) TestDelete() {
...@@ -403,12 +432,49 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { ...@@ -403,12 +432,49 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) createdAt := timezone.Today().Add(1 * time.Hour)
m1 := 1.5
m2 := 0.0
_, err := s.repo.Create(s.ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 2.0,
AccountRateMultiplier: &m1,
CreatedAt: createdAt,
})
s.Require().NoError(err)
_, err = s.repo.Create(s.ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 5,
OutputTokens: 5,
TotalCost: 0.5,
ActualCost: 1.0,
AccountRateMultiplier: &m2,
CreatedAt: createdAt,
})
s.Require().NoError(err)
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID) stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
s.Require().NoError(err, "GetAccountTodayStats") s.Require().NoError(err, "GetAccountTodayStats")
s.Require().Equal(int64(1), stats.Requests) s.Require().Equal(int64(2), stats.Requests)
s.Require().Equal(int64(30), stats.Tokens) s.Require().Equal(int64(40), stats.Tokens)
// account cost = SUM(total_cost * account_rate_multiplier)
s.Require().InEpsilon(1.5, stats.Cost, 0.0001)
// standard cost = SUM(total_cost)
s.Require().InEpsilon(1.5, stats.StandardCost, 0.0001)
// user cost = SUM(actual_cost)
s.Require().InEpsilon(3.0, stats.UserCost, 0.0001)
} }
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() { func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
......
...@@ -241,6 +241,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -241,6 +241,7 @@ func TestAPIContracts(t *testing.T) {
"total_cost": 0.5, "total_cost": 0.5,
"actual_cost": 0.5, "actual_cost": 0.5,
"rate_multiplier": 1, "rate_multiplier": 1,
"account_rate_multiplier": null,
"billing_type": 0, "billing_type": 0,
"stream": true, "stream": true,
"duration_ms": 100, "duration_ms": 100,
......
...@@ -19,6 +19,9 @@ type Account struct { ...@@ -19,6 +19,9 @@ type Account struct {
ProxyID *int64 ProxyID *int64
Concurrency int Concurrency int
Priority int Priority int
// RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
// 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
RateMultiplier *float64
Status string Status string
ErrorMessage string ErrorMessage string
LastUsedAt *time.Time LastUsedAt *time.Time
...@@ -57,6 +60,20 @@ func (a *Account) IsActive() bool { ...@@ -57,6 +60,20 @@ func (a *Account) IsActive() bool {
return a.Status == StatusActive return a.Status == StatusActive
} }
// BillingRateMultiplier 返回账号计费倍率。
// - nil 表示未配置/旧缓存缺字段,按 1.0 处理
// - 允许 0,表示该账号计费为 0
// - 负数属于非法数据,出于安全考虑按 1.0 处理
func (a *Account) BillingRateMultiplier() float64 {
if a == nil || a.RateMultiplier == nil {
return 1.0
}
if *a.RateMultiplier < 0 {
return 1.0
}
return *a.RateMultiplier
}
func (a *Account) IsSchedulable() bool { func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable { if !a.IsActive() || !a.Schedulable {
return false return false
......
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil(t *testing.T) {
var a Account
require.NoError(t, json.Unmarshal([]byte(`{"id":1,"name":"acc","status":"active"}`), &a))
require.Nil(t, a.RateMultiplier)
require.Equal(t, 1.0, a.BillingRateMultiplier())
}
func TestAccount_BillingRateMultiplier_AllowsZero(t *testing.T) {
v := 0.0
a := Account{RateMultiplier: &v}
require.Equal(t, 0.0, a.BillingRateMultiplier())
}
func TestAccount_BillingRateMultiplier_NegativeFallsBackToOne(t *testing.T) {
v := -1.0
a := Account{RateMultiplier: &v}
require.Equal(t, 1.0, a.BillingRateMultiplier())
}
...@@ -67,6 +67,7 @@ type AccountBulkUpdate struct { ...@@ -67,6 +67,7 @@ type AccountBulkUpdate struct {
ProxyID *int64 ProxyID *int64
Concurrency *int Concurrency *int
Priority *int Priority *int
RateMultiplier *float64
Status *string Status *string
Schedulable *bool Schedulable *bool
Credentials map[string]any Credentials map[string]any
......
...@@ -96,10 +96,16 @@ func NewUsageCache() *UsageCache { ...@@ -96,10 +96,16 @@ func NewUsageCache() *UsageCache {
} }
// WindowStats 窗口期统计 // WindowStats 窗口期统计
//
// cost: 账号口径费用(total_cost * account_rate_multiplier)
// standard_cost: 标准费用(total_cost,不含倍率)
// user_cost: 用户/API Key 口径费用(actual_cost,受分组倍率影响)
type WindowStats struct { type WindowStats struct {
Requests int64 `json:"requests"` Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"` Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` Cost float64 `json:"cost"`
StandardCost float64 `json:"standard_cost"`
UserCost float64 `json:"user_cost"`
} }
// UsageProgress 使用量进度 // UsageProgress 使用量进度
...@@ -380,6 +386,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou ...@@ -380,6 +386,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
Requests: stats.Requests, Requests: stats.Requests,
Tokens: stats.Tokens, Tokens: stats.Tokens,
Cost: stats.Cost, Cost: stats.Cost,
StandardCost: stats.StandardCost,
UserCost: stats.UserCost,
} }
// 缓存窗口统计(1 分钟) // 缓存窗口统计(1 分钟)
...@@ -406,6 +414,8 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64 ...@@ -406,6 +414,8 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
Requests: stats.Requests, Requests: stats.Requests,
Tokens: stats.Tokens, Tokens: stats.Tokens,
Cost: stats.Cost, Cost: stats.Cost,
StandardCost: stats.StandardCost,
UserCost: stats.UserCost,
}, nil }, nil
} }
......
...@@ -136,6 +136,7 @@ type CreateAccountInput struct { ...@@ -136,6 +136,7 @@ type CreateAccountInput struct {
ProxyID *int64 ProxyID *int64
Concurrency int Concurrency int
Priority int Priority int
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
GroupIDs []int64 GroupIDs []int64
ExpiresAt *int64 ExpiresAt *int64
AutoPauseOnExpired *bool AutoPauseOnExpired *bool
...@@ -153,6 +154,7 @@ type UpdateAccountInput struct { ...@@ -153,6 +154,7 @@ type UpdateAccountInput struct {
ProxyID *int64 ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0" Priority *int // 使用指针区分"未提供"和"设置为0"
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
Status string Status string
GroupIDs *[]int64 GroupIDs *[]int64
ExpiresAt *int64 ExpiresAt *int64
...@@ -167,6 +169,7 @@ type BulkUpdateAccountsInput struct { ...@@ -167,6 +169,7 @@ type BulkUpdateAccountsInput struct {
ProxyID *int64 ProxyID *int64
Concurrency *int Concurrency *int
Priority *int Priority *int
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
Status string Status string
Schedulable *bool Schedulable *bool
GroupIDs *[]int64 GroupIDs *[]int64
...@@ -817,6 +820,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou ...@@ -817,6 +820,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
} else { } else {
account.AutoPauseOnExpired = true account.AutoPauseOnExpired = true
} }
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
account.RateMultiplier = input.RateMultiplier
}
if err := s.accountRepo.Create(ctx, account); err != nil { if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err return nil, err
} }
...@@ -869,6 +878,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -869,6 +878,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Priority != nil { if input.Priority != nil {
account.Priority = *input.Priority account.Priority = *input.Priority
} }
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
account.RateMultiplier = input.RateMultiplier
}
if input.Status != "" { if input.Status != "" {
account.Status = input.Status account.Status = input.Status
} }
...@@ -942,6 +957,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -942,6 +957,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
} }
} }
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
}
// Prepare bulk updates for columns and JSONB fields. // Prepare bulk updates for columns and JSONB fields.
repoUpdates := AccountBulkUpdate{ repoUpdates := AccountBulkUpdate{
Credentials: input.Credentials, Credentials: input.Credentials,
...@@ -959,6 +980,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -959,6 +980,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.Priority != nil { if input.Priority != nil {
repoUpdates.Priority = input.Priority repoUpdates.Priority = input.Priority
} }
if input.RateMultiplier != nil {
repoUpdates.RateMultiplier = input.RateMultiplier
}
if input.Status != "" { if input.Status != "" {
repoUpdates.Status = &input.Status repoUpdates.Status = &input.Status
} }
......
...@@ -1211,6 +1211,72 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1211,6 +1211,72 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require.Nil(t, result) require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts") require.Contains(t, err.Error(), "no available accounts")
}) })
t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, RateLimitResetAt: &resetAt},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应跳过限流账号,选择可用账号")
})
t.Run("过滤不可调度账号-过载账号被跳过", func(t *testing.T) {
now := time.Now()
overloadUntil := now.Add(10 * time.Minute)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, OverloadUntil: &overloadUntil},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应跳过过载账号,选择可用账号")
})
} }
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) { func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
......
...@@ -511,6 +511,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -511,6 +511,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if isExcluded(acc.ID) { if isExcluded(acc.ID) {
continue continue
} }
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
// re-check schedulability here so recently rate-limited/overloaded accounts
// are not selected again before the bucket is rebuilt.
if !acc.IsSchedulable() {
continue
}
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue continue
} }
...@@ -893,6 +899,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -893,6 +899,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[acc.ID]; excluded { if _, excluded := excludedIDs[acc.ID]; excluded {
continue continue
} }
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) { if !acc.IsSchedulableForModel(requestedModel) {
continue continue
} }
...@@ -977,6 +988,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -977,6 +988,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[acc.ID]; excluded { if _, excluded := excludedIDs[acc.ID]; excluded {
continue continue
} }
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度 // 过滤:原生平台直接通过,antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue continue
...@@ -2618,6 +2634,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -2618,6 +2634,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.ImageSize != "" { if result.ImageSize != "" {
imageSize = &result.ImageSize imageSize = &result.ImageSize
} }
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
...@@ -2635,6 +2652,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -2635,6 +2652,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
TotalCost: cost.TotalCost, TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost, ActualCost: cost.ActualCost,
RateMultiplier: multiplier, RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType, BillingType: billingType,
Stream: result.Stream, Stream: result.Stream,
DurationMs: &durationMs, DurationMs: &durationMs,
......
...@@ -186,6 +186,11 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C ...@@ -186,6 +186,11 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
if _, excluded := excludedIDs[acc.ID]; excluded { if _, excluded := excludedIDs[acc.ID]; excluded {
continue continue
} }
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
// Check model support // Check model support
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue continue
...@@ -332,6 +337,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -332,6 +337,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if isExcluded(acc.ID) { if isExcluded(acc.ID) {
continue continue
} }
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
// re-check schedulability here so recently rate-limited/overloaded accounts
// are not selected again before the bucket is rebuilt.
if !acc.IsSchedulable() {
continue
}
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue continue
} }
...@@ -1432,6 +1443,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -1432,6 +1443,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log // Create usage log
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
...@@ -1449,6 +1461,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -1449,6 +1461,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
TotalCost: cost.TotalCost, TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost, ActualCost: cost.ActualCost,
RateMultiplier: multiplier, RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType, BillingType: billingType,
Stream: result.Stream, Stream: result.Stream,
DurationMs: &durationMs, DurationMs: &durationMs,
......
...@@ -3,6 +3,7 @@ package service ...@@ -3,6 +3,7 @@ package service
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"errors" "errors"
"io" "io"
"net/http" "net/http"
...@@ -15,6 +16,129 @@ import ( ...@@ -15,6 +16,129 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type stubOpenAIAccountRepo struct {
AccountRepository
accounts []Account
}
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
return append([]Account(nil), r.accounts...), nil
}
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
return append([]Account(nil), r.accounts...), nil
}
type stubConcurrencyCache struct {
ConcurrencyCache
}
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
}
func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
return nil
}
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
out := make(map[int64]*AccountLoadInfo, len(accounts))
for _, acc := range accounts {
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
}
return out, nil
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
groupID := int64(1)
rateLimited := Account{
ID: 1,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
RateLimitResetAt: &resetAt,
}
available := Account{
ID: 2,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 1,
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil {
t.Fatalf("expected selection with account")
}
if selection.Account.ID != available.ID {
t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
}
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
groupID := int64(1)
rateLimited := Account{
ID: 1,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
RateLimitResetAt: &resetAt,
}
available := Account{
ID: 2,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 1,
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
// concurrencyService is nil, forcing the non-load-batch selection path.
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil {
t.Fatalf("expected selection with account")
}
if selection.Account.ID != available.ID {
t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
}
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
}
func TestOpenAIStreamingTimeout(t *testing.T) { func TestOpenAIStreamingTimeout(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
cfg := &config.Config{ cfg := &config.Config{
......
...@@ -33,6 +33,8 @@ type UsageLog struct { ...@@ -33,6 +33,8 @@ type UsageLog struct {
TotalCost float64 TotalCost float64
ActualCost float64 ActualCost float64
RateMultiplier float64 RateMultiplier float64
// AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理)
AccountRateMultiplier *float64
BillingType int8 BillingType int8
Stream bool Stream bool
......
-- Add account billing rate multiplier and per-usage snapshot.
--
-- accounts.rate_multiplier: 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
-- usage_logs.account_rate_multiplier: 每条 usage log 的账号倍率快照,用于实现
-- “倍率调整仅影响之后请求”,并支持同一天分段倍率加权统计。
--
-- 注意:usage_logs.account_rate_multiplier 不做回填、不设置 NOT NULL。
-- 老数据为 NULL 时,统计口径按 1.0 处理(COALESCE)。
ALTER TABLE IF EXISTS accounts
ADD COLUMN IF NOT EXISTS rate_multiplier DECIMAL(10,4) NOT NULL DEFAULT 1.0;
ALTER TABLE IF EXISTS usage_logs
ADD COLUMN IF NOT EXISTS account_rate_multiplier DECIMAL(10,4);
...@@ -16,6 +16,7 @@ export interface AdminUsageStatsResponse { ...@@ -16,6 +16,7 @@ export interface AdminUsageStatsResponse {
total_tokens: number total_tokens: number
total_cost: number total_cost: number
total_actual_cost: number total_actual_cost: number
total_account_cost?: number
average_duration_ms: number average_duration_ms: number
} }
......
...@@ -73,11 +73,12 @@ ...@@ -73,11 +73,12 @@
</p> </p>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400"> <p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.stats.accumulatedCost') }} {{ t('admin.accounts.stats.accumulatedCost') }}
<span class="text-gray-400 dark:text-gray-500" <span class="text-gray-400 dark:text-gray-500">
>({{ t('admin.accounts.stats.standardCost') }}: ${{ ({{ t('usage.userBilled') }}: ${{ formatCost(stats.summary.total_user_cost) }} ·
{{ t('admin.accounts.stats.standardCost') }}: ${{
formatCost(stats.summary.total_standard_cost) formatCost(stats.summary.total_standard_cost)
}})</span }})
> </span>
</p> </p>
</div> </div>
...@@ -127,6 +128,9 @@ ...@@ -127,6 +128,9 @@
days: stats.summary.actual_days_used days: stats.summary.actual_days_used
}) })
}} }}
<span class="text-gray-400 dark:text-gray-500">
({{ t('usage.userBilled') }}: ${{ formatCost(stats.summary.avg_daily_user_cost) }})
</span>
</p> </p>
</div> </div>
...@@ -189,13 +193,17 @@ ...@@ -189,13 +193,17 @@
</div> </div>
<div class="space-y-2"> <div class="space-y-2">
<div class="flex items-center justify-between"> <div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ <span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}</span>
t('admin.accounts.stats.cost')
}}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white" <span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.today?.cost || 0) }}</span >${{ formatCost(stats.summary.today?.cost || 0) }}</span
> >
</div> </div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.today?.user_cost || 0) }}</span
>
</div>
<div class="flex items-center justify-between"> <div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ <span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.stats.requests') t('admin.accounts.stats.requests')
...@@ -240,13 +248,17 @@ ...@@ -240,13 +248,17 @@
}}</span> }}</span>
</div> </div>
<div class="flex items-center justify-between"> <div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ <span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}</span>
t('admin.accounts.stats.cost')
}}</span>
<span class="text-sm font-semibold text-orange-600 dark:text-orange-400" <span class="text-sm font-semibold text-orange-600 dark:text-orange-400"
>${{ formatCost(stats.summary.highest_cost_day?.cost || 0) }}</span >${{ formatCost(stats.summary.highest_cost_day?.cost || 0) }}</span
> >
</div> </div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.highest_cost_day?.user_cost || 0) }}</span
>
</div>
<div class="flex items-center justify-between"> <div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ <span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.stats.requests') t('admin.accounts.stats.requests')
...@@ -291,13 +303,17 @@ ...@@ -291,13 +303,17 @@
}}</span> }}</span>
</div> </div>
<div class="flex items-center justify-between"> <div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ <span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}</span>
t('admin.accounts.stats.cost')
}}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white" <span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.highest_request_day?.cost || 0) }}</span >${{ formatCost(stats.summary.highest_request_day?.cost || 0) }}</span
> >
</div> </div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.highest_request_day?.user_cost || 0) }}</span
>
</div>
</div> </div>
</div> </div>
</div> </div>
...@@ -397,13 +413,17 @@ ...@@ -397,13 +413,17 @@
}}</span> }}</span>
</div> </div>
<div class="flex items-center justify-between"> <div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ <span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}</span>
t('admin.accounts.stats.todayCost')
}}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white" <span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.today?.cost || 0) }}</span >${{ formatCost(stats.summary.today?.cost || 0) }}</span
> >
</div> </div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.today?.user_cost || 0) }}</span
>
</div>
</div> </div>
</div> </div>
</div> </div>
...@@ -517,14 +537,24 @@ const trendChartData = computed(() => { ...@@ -517,14 +537,24 @@ const trendChartData = computed(() => {
labels: stats.value.history.map((h) => h.label), labels: stats.value.history.map((h) => h.label),
datasets: [ datasets: [
{ {
label: t('admin.accounts.stats.cost') + ' (USD)', label: t('usage.accountBilled') + ' (USD)',
data: stats.value.history.map((h) => h.cost), data: stats.value.history.map((h) => h.actual_cost),
borderColor: '#3b82f6', borderColor: '#3b82f6',
backgroundColor: 'rgba(59, 130, 246, 0.1)', backgroundColor: 'rgba(59, 130, 246, 0.1)',
fill: true, fill: true,
tension: 0.3, tension: 0.3,
yAxisID: 'y' yAxisID: 'y'
}, },
{
label: t('usage.userBilled') + ' (USD)',
data: stats.value.history.map((h) => h.user_cost),
borderColor: '#10b981',
backgroundColor: 'rgba(16, 185, 129, 0.08)',
fill: false,
tension: 0.3,
borderDash: [5, 5],
yAxisID: 'y'
},
{ {
label: t('admin.accounts.stats.requests'), label: t('admin.accounts.stats.requests'),
data: stats.value.history.map((h) => h.requests), data: stats.value.history.map((h) => h.requests),
...@@ -602,7 +632,7 @@ const lineChartOptions = computed(() => ({ ...@@ -602,7 +632,7 @@ const lineChartOptions = computed(() => ({
}, },
title: { title: {
display: true, display: true,
text: t('admin.accounts.stats.cost') + ' (USD)', text: t('usage.accountBilled') + ' (USD)',
color: '#3b82f6', color: '#3b82f6',
font: { font: {
size: 11 size: 11
......
...@@ -32,15 +32,20 @@ ...@@ -32,15 +32,20 @@
formatTokens(stats.tokens) formatTokens(stats.tokens)
}}</span> }}</span>
</div> </div>
<!-- Cost --> <!-- Cost (Account) -->
<div class="flex items-center gap-1"> <div class="flex items-center gap-1">
<span class="text-gray-500 dark:text-gray-400" <span class="text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}:</span>
>{{ t('admin.accounts.stats.cost') }}:</span
>
<span class="font-medium text-emerald-600 dark:text-emerald-400">{{ <span class="font-medium text-emerald-600 dark:text-emerald-400">{{
formatCurrency(stats.cost) formatCurrency(stats.cost)
}}</span> }}</span>
</div> </div>
<!-- Cost (User/API Key) -->
<div v-if="stats.user_cost != null" class="flex items-center gap-1">
<span class="text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}:</span>
<span class="font-medium text-gray-700 dark:text-gray-300">{{
formatCurrency(stats.user_cost)
}}</span>
</div>
</div> </div>
<!-- No data --> <!-- No data -->
......
...@@ -459,7 +459,7 @@ ...@@ -459,7 +459,7 @@
</div> </div>
<!-- Concurrency & Priority --> <!-- Concurrency & Priority -->
<div class="grid grid-cols-2 gap-4 border-t border-gray-200 pt-4 dark:border-dark-600"> <div class="grid grid-cols-2 gap-4 border-t border-gray-200 pt-4 dark:border-dark-600 lg:grid-cols-3">
<div> <div>
<div class="mb-3 flex items-center justify-between"> <div class="mb-3 flex items-center justify-between">
<label <label
...@@ -516,6 +516,36 @@ ...@@ -516,6 +516,36 @@
aria-labelledby="bulk-edit-priority-label" aria-labelledby="bulk-edit-priority-label"
/> />
</div> </div>
<div>
<div class="mb-3 flex items-center justify-between">
<label
id="bulk-edit-rate-multiplier-label"
class="input-label mb-0"
for="bulk-edit-rate-multiplier-enabled"
>
{{ t('admin.accounts.billingRateMultiplier') }}
</label>
<input
v-model="enableRateMultiplier"
id="bulk-edit-rate-multiplier-enabled"
type="checkbox"
aria-controls="bulk-edit-rate-multiplier"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<input
v-model.number="rateMultiplier"
id="bulk-edit-rate-multiplier"
type="number"
min="0"
step="0.01"
:disabled="!enableRateMultiplier"
class="input"
:class="!enableRateMultiplier && 'cursor-not-allowed opacity-50'"
aria-labelledby="bulk-edit-rate-multiplier-label"
/>
<p class="input-hint">{{ t('admin.accounts.billingRateMultiplierHint') }}</p>
</div>
</div> </div>
<!-- Status --> <!-- Status -->
...@@ -655,6 +685,7 @@ const enableInterceptWarmup = ref(false) ...@@ -655,6 +685,7 @@ const enableInterceptWarmup = ref(false)
const enableProxy = ref(false) const enableProxy = ref(false)
const enableConcurrency = ref(false) const enableConcurrency = ref(false)
const enablePriority = ref(false) const enablePriority = ref(false)
const enableRateMultiplier = ref(false)
const enableStatus = ref(false) const enableStatus = ref(false)
const enableGroups = ref(false) const enableGroups = ref(false)
...@@ -670,6 +701,7 @@ const interceptWarmupRequests = ref(false) ...@@ -670,6 +701,7 @@ const interceptWarmupRequests = ref(false)
const proxyId = ref<number | null>(null) const proxyId = ref<number | null>(null)
const concurrency = ref(1) const concurrency = ref(1)
const priority = ref(1) const priority = ref(1)
const rateMultiplier = ref(1)
const status = ref<'active' | 'inactive'>('active') const status = ref<'active' | 'inactive'>('active')
const groupIds = ref<number[]>([]) const groupIds = ref<number[]>([])
...@@ -863,6 +895,10 @@ const buildUpdatePayload = (): Record<string, unknown> | null => { ...@@ -863,6 +895,10 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
updates.priority = priority.value updates.priority = priority.value
} }
if (enableRateMultiplier.value) {
updates.rate_multiplier = rateMultiplier.value
}
if (enableStatus.value) { if (enableStatus.value) {
updates.status = status.value updates.status = status.value
} }
...@@ -923,6 +959,7 @@ const handleSubmit = async () => { ...@@ -923,6 +959,7 @@ const handleSubmit = async () => {
enableProxy.value || enableProxy.value ||
enableConcurrency.value || enableConcurrency.value ||
enablePriority.value || enablePriority.value ||
enableRateMultiplier.value ||
enableStatus.value || enableStatus.value ||
enableGroups.value enableGroups.value
...@@ -977,6 +1014,7 @@ watch( ...@@ -977,6 +1014,7 @@ watch(
enableProxy.value = false enableProxy.value = false
enableConcurrency.value = false enableConcurrency.value = false
enablePriority.value = false enablePriority.value = false
enableRateMultiplier.value = false
enableStatus.value = false enableStatus.value = false
enableGroups.value = false enableGroups.value = false
...@@ -991,6 +1029,7 @@ watch( ...@@ -991,6 +1029,7 @@ watch(
proxyId.value = null proxyId.value = null
concurrency.value = 1 concurrency.value = 1
priority.value = 1 priority.value = 1
rateMultiplier.value = 1
status.value = 'active' status.value = 'active'
groupIds.value = [] groupIds.value = []
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment