Commit 2a5ef6d3 authored by shaw's avatar shaw
Browse files

Merge PR #279: feat(计费): 账号计费倍率快照与账号口径费用统计

parents 99cbfa15 ec24a3c3
......@@ -80,6 +80,10 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
}
......@@ -291,6 +295,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
} else {
......@@ -999,6 +1007,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args = append(args, *updates.Priority)
idx++
}
if updates.RateMultiplier != nil {
setClauses = append(setClauses, "rate_multiplier = $"+itoa(idx))
args = append(args, *updates.RateMultiplier)
idx++
}
if updates.Status != nil {
setClauses = append(setClauses, "status = $"+itoa(idx))
args = append(args, *updates.Status)
......@@ -1347,6 +1360,8 @@ func accountEntityToService(m *dbent.Account) *service.Account {
return nil
}
rateMultiplier := m.RateMultiplier
return &service.Account{
ID: m.ID,
Name: m.Name,
......@@ -1358,6 +1373,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
RateMultiplier: &rateMultiplier,
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,
......
......@@ -71,7 +71,9 @@ usage_agg AS (
error_base AS (
SELECT
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
platform AS platform,
-- platform is NULL for some early-phase errors (e.g. before routing); map to a sentinel
-- value so platform-level GROUPING SETS don't collide with the overall (platform=NULL) row.
COALESCE(platform, 'unknown') AS platform,
group_id AS group_id,
is_business_limited AS is_business_limited,
error_owner AS error_owner,
......
......@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
type usageLogRepository struct {
client *dbent.Client
......@@ -105,6 +105,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
stream,
duration_ms,
......@@ -120,7 +121,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
......@@ -160,6 +161,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.TotalCost,
log.ActualCost,
rateMultiplier,
log.AccountRateMultiplier,
log.BillingType,
log.Stream,
duration,
......@@ -835,7 +837,9 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
......@@ -849,6 +853,8 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
&stats.Requests,
&stats.Tokens,
&stats.Cost,
&stats.StandardCost,
&stats.UserCost,
); err != nil {
return nil, err
}
......@@ -861,7 +867,9 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
......@@ -875,6 +883,8 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
&stats.Requests,
&stats.Tokens,
&stats.Cost,
&stats.StandardCost,
&stats.UserCost,
); err != nil {
return nil, err
}
......@@ -1454,7 +1464,13 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
query := `
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
SELECT
model,
COUNT(*) as requests,
......@@ -1462,10 +1478,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
`, actualCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
......@@ -1587,12 +1603,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
`, buildWhere(conditions))
stats := &UsageStats{}
var totalAccountCost float64
if err := scanSingleRow(
ctx,
r.sql,
......@@ -1604,10 +1622,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&totalAccountCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
if filters.AccountID > 0 {
stats.TotalAccountCost = &totalAccountCost
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return stats, nil
}
......@@ -1634,7 +1656,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY date
......@@ -1661,7 +1684,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var tokens int64
var cost float64
var actualCost float64
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
var userCost float64
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost, &userCost); err != nil {
return nil, err
}
t, _ := time.Parse("2006-01-02", date)
......@@ -1672,19 +1696,21 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Tokens: tokens,
Cost: cost,
ActualCost: actualCost,
UserCost: userCost,
})
}
if err = rows.Err(); err != nil {
return nil, err
}
var totalActualCost, totalStandardCost float64
var totalAccountCost, totalUserCost, totalStandardCost float64
var totalRequests, totalTokens int64
var highestCostDay, highestRequestDay *AccountUsageHistory
for i := range history {
h := &history[i]
totalActualCost += h.ActualCost
totalAccountCost += h.ActualCost
totalUserCost += h.UserCost
totalStandardCost += h.Cost
totalRequests += h.Requests
totalTokens += h.Tokens
......@@ -1711,11 +1737,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary := AccountUsageSummary{
Days: daysCount,
ActualDaysUsed: actualDaysUsed,
TotalCost: totalActualCost,
TotalCost: totalAccountCost,
TotalUserCost: totalUserCost,
TotalStandardCost: totalStandardCost,
TotalRequests: totalRequests,
TotalTokens: totalTokens,
AvgDailyCost: totalActualCost / float64(actualDaysUsed),
AvgDailyCost: totalAccountCost / float64(actualDaysUsed),
AvgDailyUserCost: totalUserCost / float64(actualDaysUsed),
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
AvgDurationMs: avgDuration,
......@@ -1727,11 +1755,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary.Today = &struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}{
Date: history[i].Date,
Cost: history[i].ActualCost,
UserCost: history[i].UserCost,
Requests: history[i].Requests,
Tokens: history[i].Tokens,
}
......@@ -1744,11 +1774,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
}{
Date: highestCostDay.Date,
Label: highestCostDay.Label,
Cost: highestCostDay.ActualCost,
UserCost: highestCostDay.UserCost,
Requests: highestCostDay.Requests,
}
}
......@@ -1759,11 +1791,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
}{
Date: highestRequestDay.Date,
Label: highestRequestDay.Label,
Requests: highestRequestDay.Requests,
Cost: highestRequestDay.ActualCost,
UserCost: highestRequestDay.UserCost,
}
}
......@@ -2015,6 +2049,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
totalCost float64
actualCost float64
rateMultiplier float64
accountRateMultiplier sql.NullFloat64
billingType int16
stream bool
durationMs sql.NullInt64
......@@ -2048,6 +2083,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&totalCost,
&actualCost,
&rateMultiplier,
&accountRateMultiplier,
&billingType,
&stream,
&durationMs,
......@@ -2080,6 +2116,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
TotalCost: totalCost,
ActualCost: actualCost,
RateMultiplier: rateMultiplier,
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
BillingType: int8(billingType),
Stream: stream,
ImageCount: imageCount,
......@@ -2186,6 +2223,14 @@ func nullInt(v *int) sql.NullInt64 {
return sql.NullInt64{Int64: int64(*v), Valid: true}
}
func nullFloat64Ptr(v sql.NullFloat64) *float64 {
if !v.Valid {
return nil
}
out := v.Float64
return &out
}
func nullString(v *string) sql.NullString {
if v == nil || *v == "" {
return sql.NullString{}
......
......@@ -11,6 +11,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
......@@ -95,6 +96,34 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-mult@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-mult", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-mult"})
m := 0.5
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 2.0,
AccountRateMultiplier: &m,
CreatedAt: timezone.Today().Add(2 * time.Hour),
}
_, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err)
got, err := s.repo.GetByID(s.ctx, log.ID)
s.Require().NoError(err)
s.Require().NotNil(got.AccountRateMultiplier)
s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
}
// --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() {
......@@ -403,12 +432,49 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
createdAt := timezone.Today().Add(1 * time.Hour)
m1 := 1.5
m2 := 0.0
_, err := s.repo.Create(s.ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 2.0,
AccountRateMultiplier: &m1,
CreatedAt: createdAt,
})
s.Require().NoError(err)
_, err = s.repo.Create(s.ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 5,
OutputTokens: 5,
TotalCost: 0.5,
ActualCost: 1.0,
AccountRateMultiplier: &m2,
CreatedAt: createdAt,
})
s.Require().NoError(err)
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
s.Require().NoError(err, "GetAccountTodayStats")
s.Require().Equal(int64(1), stats.Requests)
s.Require().Equal(int64(30), stats.Tokens)
s.Require().Equal(int64(2), stats.Requests)
s.Require().Equal(int64(40), stats.Tokens)
// account cost = SUM(total_cost * account_rate_multiplier)
s.Require().InEpsilon(1.5, stats.Cost, 0.0001)
// standard cost = SUM(total_cost)
s.Require().InEpsilon(1.5, stats.StandardCost, 0.0001)
// user cost = SUM(actual_cost)
s.Require().InEpsilon(3.0, stats.UserCost, 0.0001)
}
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
......
......@@ -241,6 +241,7 @@ func TestAPIContracts(t *testing.T) {
"total_cost": 0.5,
"actual_cost": 0.5,
"rate_multiplier": 1,
"account_rate_multiplier": null,
"billing_type": 0,
"stream": true,
"duration_ms": 100,
......
......@@ -19,6 +19,9 @@ type Account struct {
ProxyID *int64
Concurrency int
Priority int
// RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
// 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
RateMultiplier *float64
Status string
ErrorMessage string
LastUsedAt *time.Time
......@@ -57,6 +60,20 @@ func (a *Account) IsActive() bool {
return a.Status == StatusActive
}
// BillingRateMultiplier 返回账号计费倍率。
// - nil 表示未配置/旧缓存缺字段,按 1.0 处理
// - 允许 0,表示该账号计费为 0
// - 负数属于非法数据,出于安全考虑按 1.0 处理
func (a *Account) BillingRateMultiplier() float64 {
if a == nil || a.RateMultiplier == nil {
return 1.0
}
if *a.RateMultiplier < 0 {
return 1.0
}
return *a.RateMultiplier
}
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
......
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil(t *testing.T) {
var a Account
require.NoError(t, json.Unmarshal([]byte(`{"id":1,"name":"acc","status":"active"}`), &a))
require.Nil(t, a.RateMultiplier)
require.Equal(t, 1.0, a.BillingRateMultiplier())
}
func TestAccount_BillingRateMultiplier_AllowsZero(t *testing.T) {
v := 0.0
a := Account{RateMultiplier: &v}
require.Equal(t, 0.0, a.BillingRateMultiplier())
}
func TestAccount_BillingRateMultiplier_NegativeFallsBackToOne(t *testing.T) {
v := -1.0
a := Account{RateMultiplier: &v}
require.Equal(t, 1.0, a.BillingRateMultiplier())
}
......@@ -67,6 +67,7 @@ type AccountBulkUpdate struct {
ProxyID *int64
Concurrency *int
Priority *int
RateMultiplier *float64
Status *string
Schedulable *bool
Credentials map[string]any
......
......@@ -96,10 +96,16 @@ func NewUsageCache() *UsageCache {
}
// WindowStats 窗口期统计
//
// cost: 账号口径费用(total_cost * account_rate_multiplier)
// standard_cost: 标准费用(total_cost,不含倍率)
// user_cost: 用户/API Key 口径费用(actual_cost,受分组倍率影响)
type WindowStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
StandardCost float64 `json:"standard_cost"`
UserCost float64 `json:"user_cost"`
}
// UsageProgress 使用量进度
......@@ -380,6 +386,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
StandardCost: stats.StandardCost,
UserCost: stats.UserCost,
}
// 缓存窗口统计(1 分钟)
......@@ -406,6 +414,8 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
StandardCost: stats.StandardCost,
UserCost: stats.UserCost,
}, nil
}
......
......@@ -136,6 +136,7 @@ type CreateAccountInput struct {
ProxyID *int64
Concurrency int
Priority int
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
GroupIDs []int64
ExpiresAt *int64
AutoPauseOnExpired *bool
......@@ -153,6 +154,7 @@ type UpdateAccountInput struct {
ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
Status string
GroupIDs *[]int64
ExpiresAt *int64
......@@ -167,6 +169,7 @@ type BulkUpdateAccountsInput struct {
ProxyID *int64
Concurrency *int
Priority *int
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
Status string
Schedulable *bool
GroupIDs *[]int64
......@@ -817,6 +820,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
} else {
account.AutoPauseOnExpired = true
}
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
account.RateMultiplier = input.RateMultiplier
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
}
......@@ -869,6 +878,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Priority != nil {
account.Priority = *input.Priority
}
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
account.RateMultiplier = input.RateMultiplier
}
if input.Status != "" {
account.Status = input.Status
}
......@@ -942,6 +957,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
}
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
}
// Prepare bulk updates for columns and JSONB fields.
repoUpdates := AccountBulkUpdate{
Credentials: input.Credentials,
......@@ -959,6 +980,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.Priority != nil {
repoUpdates.Priority = input.Priority
}
if input.RateMultiplier != nil {
repoUpdates.RateMultiplier = input.RateMultiplier
}
if input.Status != "" {
repoUpdates.Status = &input.Status
}
......
......@@ -1211,6 +1211,72 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
})
t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, RateLimitResetAt: &resetAt},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应跳过限流账号,选择可用账号")
})
t.Run("过滤不可调度账号-过载账号被跳过", func(t *testing.T) {
now := time.Now()
overloadUntil := now.Add(10 * time.Minute)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, OverloadUntil: &overloadUntil},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应跳过过载账号,选择可用账号")
})
}
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
......
......@@ -511,6 +511,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if isExcluded(acc.ID) {
continue
}
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
// re-check schedulability here so recently rate-limited/overloaded accounts
// are not selected again before the bucket is rebuilt.
if !acc.IsSchedulable() {
continue
}
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue
}
......@@ -893,6 +899,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
......@@ -977,6 +988,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
......@@ -2618,6 +2634,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.ImageSize != "" {
imageSize = &result.ImageSize
}
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
......@@ -2635,6 +2652,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
......
......@@ -186,6 +186,11 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
// Check model support
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
......@@ -332,6 +337,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if isExcluded(acc.ID) {
continue
}
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
// re-check schedulability here so recently rate-limited/overloaded accounts
// are not selected again before the bucket is rebuilt.
if !acc.IsSchedulable() {
continue
}
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
......@@ -1432,6 +1443,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
......@@ -1449,6 +1461,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
......
......@@ -3,6 +3,7 @@ package service
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"net/http"
......@@ -15,6 +16,129 @@ import (
"github.com/gin-gonic/gin"
)
type stubOpenAIAccountRepo struct {
AccountRepository
accounts []Account
}
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
return append([]Account(nil), r.accounts...), nil
}
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
return append([]Account(nil), r.accounts...), nil
}
type stubConcurrencyCache struct {
ConcurrencyCache
}
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
}
func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
return nil
}
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
out := make(map[int64]*AccountLoadInfo, len(accounts))
for _, acc := range accounts {
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
}
return out, nil
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
groupID := int64(1)
rateLimited := Account{
ID: 1,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
RateLimitResetAt: &resetAt,
}
available := Account{
ID: 2,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 1,
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil {
t.Fatalf("expected selection with account")
}
if selection.Account.ID != available.ID {
t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
}
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
groupID := int64(1)
rateLimited := Account{
ID: 1,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
RateLimitResetAt: &resetAt,
}
available := Account{
ID: 2,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 1,
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
// concurrencyService is nil, forcing the non-load-batch selection path.
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil {
t.Fatalf("expected selection with account")
}
if selection.Account.ID != available.ID {
t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
}
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
}
func TestOpenAIStreamingTimeout(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
......
......@@ -33,6 +33,8 @@ type UsageLog struct {
TotalCost float64
ActualCost float64
RateMultiplier float64
// AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理)
AccountRateMultiplier *float64
BillingType int8
Stream bool
......
-- Add account billing rate multiplier and per-usage snapshot.
--
-- accounts.rate_multiplier: 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
-- usage_logs.account_rate_multiplier: 每条 usage log 的账号倍率快照,用于实现
-- “倍率调整仅影响之后请求”,并支持同一天分段倍率加权统计。
--
-- 注意:usage_logs.account_rate_multiplier 不做回填、不设置 NOT NULL。
-- 老数据为 NULL 时,统计口径按 1.0 处理(COALESCE)。
ALTER TABLE IF EXISTS accounts
ADD COLUMN IF NOT EXISTS rate_multiplier DECIMAL(10,4) NOT NULL DEFAULT 1.0;
ALTER TABLE IF EXISTS usage_logs
ADD COLUMN IF NOT EXISTS account_rate_multiplier DECIMAL(10,4);
......@@ -16,6 +16,7 @@ export interface AdminUsageStatsResponse {
total_tokens: number
total_cost: number
total_actual_cost: number
total_account_cost?: number
average_duration_ms: number
}
......
......@@ -73,11 +73,12 @@
</p>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.stats.accumulatedCost') }}
<span class="text-gray-400 dark:text-gray-500"
>({{ t('admin.accounts.stats.standardCost') }}: ${{
<span class="text-gray-400 dark:text-gray-500">
({{ t('usage.userBilled') }}: ${{ formatCost(stats.summary.total_user_cost) }} ·
{{ t('admin.accounts.stats.standardCost') }}: ${{
formatCost(stats.summary.total_standard_cost)
}})</span
>
}})
</span>
</p>
</div>
......@@ -127,6 +128,9 @@
days: stats.summary.actual_days_used
})
}}
<span class="text-gray-400 dark:text-gray-500">
({{ t('usage.userBilled') }}: ${{ formatCost(stats.summary.avg_daily_user_cost) }})
</span>
</p>
</div>
......@@ -189,13 +193,17 @@
</div>
<div class="space-y-2">
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.stats.cost')
}}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.today?.cost || 0) }}</span
>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.today?.user_cost || 0) }}</span
>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.stats.requests')
......@@ -240,13 +248,17 @@
}}</span>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.stats.cost')
}}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}</span>
<span class="text-sm font-semibold text-orange-600 dark:text-orange-400"
>${{ formatCost(stats.summary.highest_cost_day?.cost || 0) }}</span
>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.highest_cost_day?.user_cost || 0) }}</span
>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.stats.requests')
......@@ -291,13 +303,17 @@
}}</span>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.stats.cost')
}}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.highest_request_day?.cost || 0) }}</span
>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.highest_request_day?.user_cost || 0) }}</span
>
</div>
</div>
</div>
</div>
......@@ -397,13 +413,17 @@
}}</span>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.stats.todayCost')
}}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.today?.cost || 0) }}</span
>
</div>
<div class="flex items-center justify-between">
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}</span>
<span class="text-sm font-semibold text-gray-900 dark:text-white"
>${{ formatCost(stats.summary.today?.user_cost || 0) }}</span
>
</div>
</div>
</div>
</div>
......@@ -517,14 +537,24 @@ const trendChartData = computed(() => {
labels: stats.value.history.map((h) => h.label),
datasets: [
{
label: t('admin.accounts.stats.cost') + ' (USD)',
data: stats.value.history.map((h) => h.cost),
label: t('usage.accountBilled') + ' (USD)',
data: stats.value.history.map((h) => h.actual_cost),
borderColor: '#3b82f6',
backgroundColor: 'rgba(59, 130, 246, 0.1)',
fill: true,
tension: 0.3,
yAxisID: 'y'
},
{
label: t('usage.userBilled') + ' (USD)',
data: stats.value.history.map((h) => h.user_cost),
borderColor: '#10b981',
backgroundColor: 'rgba(16, 185, 129, 0.08)',
fill: false,
tension: 0.3,
borderDash: [5, 5],
yAxisID: 'y'
},
{
label: t('admin.accounts.stats.requests'),
data: stats.value.history.map((h) => h.requests),
......@@ -602,7 +632,7 @@ const lineChartOptions = computed(() => ({
},
title: {
display: true,
text: t('admin.accounts.stats.cost') + ' (USD)',
text: t('usage.accountBilled') + ' (USD)',
color: '#3b82f6',
font: {
size: 11
......
......@@ -32,15 +32,20 @@
formatTokens(stats.tokens)
}}</span>
</div>
<!-- Cost -->
<!-- Cost (Account) -->
<div class="flex items-center gap-1">
<span class="text-gray-500 dark:text-gray-400"
>{{ t('admin.accounts.stats.cost') }}:</span
>
<span class="text-gray-500 dark:text-gray-400">{{ t('usage.accountBilled') }}:</span>
<span class="font-medium text-emerald-600 dark:text-emerald-400">{{
formatCurrency(stats.cost)
}}</span>
</div>
<!-- Cost (User/API Key) -->
<div v-if="stats.user_cost != null" class="flex items-center gap-1">
<span class="text-gray-500 dark:text-gray-400">{{ t('usage.userBilled') }}:</span>
<span class="font-medium text-gray-700 dark:text-gray-300">{{
formatCurrency(stats.user_cost)
}}</span>
</div>
</div>
<!-- No data -->
......
......@@ -459,7 +459,7 @@
</div>
<!-- Concurrency & Priority -->
<div class="grid grid-cols-2 gap-4 border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="grid grid-cols-2 gap-4 border-t border-gray-200 pt-4 dark:border-dark-600 lg:grid-cols-3">
<div>
<div class="mb-3 flex items-center justify-between">
<label
......@@ -516,6 +516,36 @@
aria-labelledby="bulk-edit-priority-label"
/>
</div>
<div>
<div class="mb-3 flex items-center justify-between">
<label
id="bulk-edit-rate-multiplier-label"
class="input-label mb-0"
for="bulk-edit-rate-multiplier-enabled"
>
{{ t('admin.accounts.billingRateMultiplier') }}
</label>
<input
v-model="enableRateMultiplier"
id="bulk-edit-rate-multiplier-enabled"
type="checkbox"
aria-controls="bulk-edit-rate-multiplier"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<input
v-model.number="rateMultiplier"
id="bulk-edit-rate-multiplier"
type="number"
min="0"
step="0.01"
:disabled="!enableRateMultiplier"
class="input"
:class="!enableRateMultiplier && 'cursor-not-allowed opacity-50'"
aria-labelledby="bulk-edit-rate-multiplier-label"
/>
<p class="input-hint">{{ t('admin.accounts.billingRateMultiplierHint') }}</p>
</div>
</div>
<!-- Status -->
......@@ -655,6 +685,7 @@ const enableInterceptWarmup = ref(false)
const enableProxy = ref(false)
const enableConcurrency = ref(false)
const enablePriority = ref(false)
const enableRateMultiplier = ref(false)
const enableStatus = ref(false)
const enableGroups = ref(false)
......@@ -670,6 +701,7 @@ const interceptWarmupRequests = ref(false)
const proxyId = ref<number | null>(null)
const concurrency = ref(1)
const priority = ref(1)
const rateMultiplier = ref(1)
const status = ref<'active' | 'inactive'>('active')
const groupIds = ref<number[]>([])
......@@ -863,6 +895,10 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
updates.priority = priority.value
}
if (enableRateMultiplier.value) {
updates.rate_multiplier = rateMultiplier.value
}
if (enableStatus.value) {
updates.status = status.value
}
......@@ -923,6 +959,7 @@ const handleSubmit = async () => {
enableProxy.value ||
enableConcurrency.value ||
enablePriority.value ||
enableRateMultiplier.value ||
enableStatus.value ||
enableGroups.value
......@@ -977,6 +1014,7 @@ watch(
enableProxy.value = false
enableConcurrency.value = false
enablePriority.value = false
enableRateMultiplier.value = false
enableStatus.value = false
enableGroups.value = false
......@@ -991,6 +1029,7 @@ watch(
proxyId.value = null
concurrency.value = 1
priority.value = 1
rateMultiplier.value = 1
status.value = 'active'
groupIds.value = []
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment