Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
0b746501
Commit
0b746501
authored
Apr 16, 2026
by
陈曦
Browse files
1. merge upstream v0.1.113 2.提交migration相关文件
parents
45061102
be7551b9
Changes
225
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/usage_billing_repo.go
View file @
0b746501
...
...
@@ -113,9 +113,11 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
}
if
cmd
.
BalanceCost
>
0
{
if
err
:=
deductUsageBillingBalance
(
ctx
,
tx
,
cmd
.
UserID
,
cmd
.
BalanceCost
);
err
!=
nil
{
newBalance
,
err
:=
deductUsageBillingBalance
(
ctx
,
tx
,
cmd
.
UserID
,
cmd
.
BalanceCost
)
if
err
!=
nil
{
return
err
}
result
.
NewBalance
=
&
newBalance
}
if
cmd
.
APIKeyQuotaCost
>
0
{
...
...
@@ -133,9 +135,11 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
}
if
cmd
.
AccountQuotaCost
>
0
&&
(
strings
.
EqualFold
(
cmd
.
AccountType
,
service
.
AccountTypeAPIKey
)
||
strings
.
EqualFold
(
cmd
.
AccountType
,
service
.
AccountTypeBedrock
))
{
if
err
:=
incrementUsageBillingAccountQuota
(
ctx
,
tx
,
cmd
.
AccountID
,
cmd
.
AccountQuotaCost
);
err
!=
nil
{
quotaState
,
err
:=
incrementUsageBillingAccountQuota
(
ctx
,
tx
,
cmd
.
AccountID
,
cmd
.
AccountQuotaCost
)
if
err
!=
nil
{
return
err
}
result
.
QuotaState
=
quotaState
}
return
nil
...
...
@@ -169,24 +173,22 @@ func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscrip
return
service
.
ErrSubscriptionNotFound
}
func
deductUsageBillingBalance
(
ctx
context
.
Context
,
tx
*
sql
.
Tx
,
userID
int64
,
amount
float64
)
error
{
res
,
err
:=
tx
.
ExecContext
(
ctx
,
`
func
deductUsageBillingBalance
(
ctx
context
.
Context
,
tx
*
sql
.
Tx
,
userID
int64
,
amount
float64
)
(
float64
,
error
)
{
var
newBalance
float64
err
:=
tx
.
QueryRowContext
(
ctx
,
`
UPDATE users
SET balance = balance - $1,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
`
,
amount
,
userID
)
if
err
!=
nil
{
return
err
RETURNING balance
`
,
amount
,
userID
)
.
Scan
(
&
newBalance
)
if
errors
.
Is
(
err
,
sql
.
ErrNoRows
)
{
return
0
,
service
.
ErrUserNotFound
}
affected
,
err
:=
res
.
RowsAffected
()
if
err
!=
nil
{
return
err
}
if
affected
>
0
{
return
nil
return
0
,
err
}
return
service
.
ErrUserNotFound
return
newBalance
,
nil
}
func
incrementUsageBillingAPIKeyQuota
(
ctx
context
.
Context
,
tx
*
sql
.
Tx
,
apiKeyID
int64
,
amount
float64
)
(
bool
,
error
)
{
...
...
@@ -240,7 +242,7 @@ func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKe
return
nil
}
func
incrementUsageBillingAccountQuota
(
ctx
context
.
Context
,
tx
*
sql
.
Tx
,
accountID
int64
,
amount
float64
)
error
{
func
incrementUsageBillingAccountQuota
(
ctx
context
.
Context
,
tx
*
sql
.
Tx
,
accountID
int64
,
amount
float64
)
(
*
service
.
AccountQuotaState
,
error
)
{
rows
,
err
:=
tx
.
QueryContext
(
ctx
,
`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
...
...
@@ -248,61 +250,71 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_daily_used',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
CASE WHEN `
+
dailyExpiredExpr
+
`
THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
CASE WHEN `
+
dailyExpiredExpr
+
`
THEN `
+
nowUTC
+
`
ELSE COALESCE(extra->>'quota_daily_start', `
+
nowUTC
+
`) END
)
|| CASE WHEN `
+
dailyExpiredExpr
+
` AND `
+
nextDailyResetAtExpr
+
` IS NOT NULL
THEN jsonb_build_object('quota_daily_reset_at', `
+
nextDailyResetAtExpr
+
`)
ELSE '{}'::jsonb END
ELSE '{}'::jsonb END
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_weekly_used',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
CASE WHEN `
+
weeklyExpiredExpr
+
`
THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
CASE WHEN `
+
weeklyExpiredExpr
+
`
THEN `
+
nowUTC
+
`
ELSE COALESCE(extra->>'quota_weekly_start', `
+
nowUTC
+
`) END
)
|| CASE WHEN `
+
weeklyExpiredExpr
+
` AND `
+
nextWeeklyResetAtExpr
+
` IS NOT NULL
THEN jsonb_build_object('quota_weekly_reset_at', `
+
nextWeeklyResetAtExpr
+
`)
ELSE '{}'::jsonb END
ELSE '{}'::jsonb END
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING
COALESCE((extra->>'quota_used')::numeric, 0),
COALESCE((extra->>'quota_limit')::numeric, 0)`
,
COALESCE((extra->>'quota_limit')::numeric, 0),
COALESCE((extra->>'quota_daily_used')::numeric, 0),
COALESCE((extra->>'quota_daily_limit')::numeric, 0),
COALESCE((extra->>'quota_weekly_used')::numeric, 0),
COALESCE((extra->>'quota_weekly_limit')::numeric, 0)`
,
amount
,
accountID
)
if
err
!=
nil
{
return
err
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
var
newUsed
,
limit
float64
var
state
service
.
AccountQuotaState
if
rows
.
Next
()
{
if
err
:=
rows
.
Scan
(
&
newUsed
,
&
limit
);
err
!=
nil
{
return
err
if
err
:=
rows
.
Scan
(
&
state
.
TotalUsed
,
&
state
.
TotalLimit
,
&
state
.
DailyUsed
,
&
state
.
DailyLimit
,
&
state
.
WeeklyUsed
,
&
state
.
WeeklyLimit
,
);
err
!=
nil
{
return
nil
,
err
}
}
else
{
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
err
return
nil
,
err
}
return
service
.
ErrAccountNotFound
return
nil
,
service
.
ErrAccountNotFound
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
err
return
nil
,
err
}
if
limit
>
0
&&
newUsed
>=
limit
&&
(
new
Used
-
amount
)
<
l
imit
{
if
state
.
TotalLimit
>
0
&&
state
.
TotalUsed
>=
state
.
TotalLimit
&&
(
state
.
Total
Used
-
amount
)
<
state
.
TotalL
imit
{
if
err
:=
enqueueSchedulerOutbox
(
ctx
,
tx
,
service
.
SchedulerOutboxEventAccountChanged
,
&
accountID
,
nil
,
nil
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"repository.usage_billing"
,
"[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v"
,
accountID
,
err
)
return
err
return
nil
,
err
}
}
return
nil
return
&
state
,
nil
}
backend/internal/repository/usage_log_repo.go
View file @
0b746501
...
...
@@ -28,7 +28,7 @@ import (
gocache
"github.com/patrickmn/go-cache"
)
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode,
account_stats_cost,
created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
...
...
@@ -82,6 +82,7 @@ var usageLogInsertArgTypes = [...]string{
"text"
,
// model_mapping_chain
"text"
,
// billing_tier
"text"
,
// billing_mode
"numeric"
,
// account_stats_cost
"timestamptz"
,
// created_at
}
...
...
@@ -360,6 +361,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
...
...
@@ -367,7 +369,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
...
...
@@ -797,6 +799,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
) AS (VALUES `
)
...
...
@@ -873,6 +876,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
)
SELECT
...
...
@@ -920,6 +924,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
...
...
@@ -1007,10 +1012,11 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
) AS (VALUES `
)
args
:=
make
([]
any
,
0
,
len
(
preparedList
)
*
4
5
)
args
:=
make
([]
any
,
0
,
len
(
preparedList
)
*
4
6
)
argPos
:=
1
for
idx
,
prepared
:=
range
preparedList
{
if
idx
>
0
{
...
...
@@ -1080,6 +1086,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
)
SELECT
...
...
@@ -1127,6 +1134,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
...
...
@@ -1182,6 +1190,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
...
...
@@ -1189,7 +1198,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`
,
prepared
.
args
...
)
...
...
@@ -1285,6 +1294,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
modelMappingChain
,
billingTier
,
billingMode
,
log
.
AccountStatsCost
,
// account_stats_cost
createdAt
,
},
}
...
...
@@ -1518,6 +1528,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(account_cost), 0) as total_account_cost,
COALESCE(SUM(total_duration_ms), 0) as total_duration_ms
FROM usage_dashboard_daily
`
...
...
@@ -1534,6 +1545,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
&
stats
.
TotalCacheReadTokens
,
&
stats
.
TotalCost
,
&
stats
.
TotalActualCost
,
&
stats
.
TotalAccountCost
,
&
totalDurationMs
,
);
err
!=
nil
{
return
err
...
...
@@ -1552,6 +1564,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
cache_read_tokens as today_cache_read_tokens,
total_cost as today_cost,
actual_cost as today_actual_cost,
account_cost as today_account_cost,
active_users as active_users
FROM usage_dashboard_daily
WHERE bucket_date = $1::date
...
...
@@ -1568,6 +1581,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
&
stats
.
TodayCacheReadTokens
,
&
stats
.
TodayCost
,
&
stats
.
TodayActualCost
,
&
stats
.
TodayAccountCost
,
&
stats
.
ActiveUsers
,
);
err
!=
nil
{
if
err
!=
sql
.
ErrNoRows
{
...
...
@@ -1603,6 +1617,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
cache_read_tokens,
total_cost,
actual_cost,
COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1) AS account_cost,
COALESCE(duration_ms, 0) AS duration_ms
FROM usage_logs
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
...
...
@@ -1616,6 +1631,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens,
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost,
COALESCE(SUM(account_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_account_cost,
COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms,
COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests,
COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens,
...
...
@@ -1623,7 +1639,8 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens,
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost,
COALESCE(SUM(account_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_account_cost
FROM scoped
`
var
totalDurationMs
int64
...
...
@@ -1639,6 +1656,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
&
stats
.
TotalCacheReadTokens
,
&
stats
.
TotalCost
,
&
stats
.
TotalActualCost
,
&
stats
.
TotalAccountCost
,
&
totalDurationMs
,
&
stats
.
TodayRequests
,
&
stats
.
TodayInputTokens
,
...
...
@@ -1647,6 +1665,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
&
stats
.
TodayCacheReadTokens
,
&
stats
.
TodayCost
,
&
stats
.
TodayActualCost
,
&
stats
.
TodayAccountCost
,
);
err
!=
nil
{
return
err
}
...
...
@@ -1959,7 +1978,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(
COALESCE(account_stats_cost,
total_cost
)
* COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
...
...
@@ -1989,7 +2008,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(
COALESCE(account_stats_cost,
total_cost
)
* COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
...
...
@@ -2026,7 +2045,7 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
account_id,
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(
COALESCE(account_stats_cost,
total_cost
)
* COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
...
...
@@ -2585,7 +2604,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
COALESCE(SUM(actual_cost), 0) as actual_cost,
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as account_cost
FROM usage_logs
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY model
...
...
@@ -2990,8 +3010,9 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
actualCostExpr
:=
"COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if
accountID
>
0
&&
userID
==
0
&&
apiKeyID
==
0
{
actualCostExpr
=
"COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
actualCostExpr
=
"COALESCE(SUM(
COALESCE(account_stats_cost,
total_cost
)
* COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
accountCostExpr
:=
"COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as account_cost"
modelExpr
:=
resolveModelDimensionExpression
(
source
)
query
:=
fmt
.
Sprintf
(
`
...
...
@@ -3004,10 +3025,11 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
%s,
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
,
modelExpr
,
actualCostExpr
)
`
,
modelExpr
,
actualCostExpr
,
accountCostExpr
)
args
:=
[]
any
{
startTime
,
endTime
}
if
userID
>
0
{
...
...
@@ -3062,7 +3084,8 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
COUNT(*) as requests,
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(ul.total_cost), 0) as cost,
COALESCE(SUM(ul.actual_cost), 0) as actual_cost
COALESCE(SUM(ul.actual_cost), 0) as actual_cost,
COALESCE(SUM(COALESCE(ul.account_stats_cost, ul.total_cost) * COALESCE(ul.account_rate_multiplier, 1)), 0) as account_cost
FROM usage_logs ul
LEFT JOIN groups g ON g.id = ul.group_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
...
...
@@ -3113,6 +3136,7 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
&
row
.
TotalTokens
,
&
row
.
Cost
,
&
row
.
ActualCost
,
&
row
.
AccountCost
,
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -3133,7 +3157,8 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
COUNT(*) as requests,
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(ul.total_cost), 0) as cost,
COALESCE(SUM(ul.actual_cost), 0) as actual_cost
COALESCE(SUM(ul.actual_cost), 0) as actual_cost,
COALESCE(SUM(COALESCE(ul.account_stats_cost, ul.total_cost) * COALESCE(ul.account_rate_multiplier, 1)), 0) as account_cost
FROM usage_logs ul
LEFT JOIN users u ON u.id = ul.user_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
...
...
@@ -3204,6 +3229,7 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
&
row
.
TotalTokens
,
&
row
.
Cost
,
&
row
.
ActualCost
,
&
row
.
AccountCost
,
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -3358,7 +3384,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(SUM(
COALESCE(account_stats_cost,
total_cost
)
* COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
...
...
@@ -3382,9 +3408,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
);
err
!=
nil
{
return
nil
,
err
}
if
filters
.
AccountID
>
0
{
stats
.
TotalAccountCost
=
&
totalAccountCost
}
stats
.
TotalAccountCost
=
&
totalAccountCost
stats
.
TotalTokens
=
stats
.
TotalInputTokens
+
stats
.
TotalOutputTokens
+
stats
.
TotalCacheTokens
start
:=
time
.
Unix
(
0
,
0
)
.
UTC
()
...
...
@@ -3433,7 +3457,7 @@ type EndpointStat = usagestats.EndpointStat
func
(
r
*
usageLogRepository
)
getEndpointStatsByColumnWithFilters
(
ctx
context
.
Context
,
endpointColumn
string
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
(
results
[]
EndpointStat
,
err
error
)
{
actualCostExpr
:=
"COALESCE(SUM(actual_cost), 0) as actual_cost"
if
accountID
>
0
&&
userID
==
0
&&
apiKeyID
==
0
{
actualCostExpr
=
"COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
actualCostExpr
=
"COALESCE(SUM(
COALESCE(account_stats_cost,
total_cost
)
* COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query
:=
fmt
.
Sprintf
(
`
...
...
@@ -3500,7 +3524,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
func
(
r
*
usageLogRepository
)
getEndpointPathStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
(
results
[]
EndpointStat
,
err
error
)
{
actualCostExpr
:=
"COALESCE(SUM(actual_cost), 0) as actual_cost"
if
accountID
>
0
&&
userID
==
0
&&
apiKeyID
==
0
{
actualCostExpr
=
"COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
actualCostExpr
=
"COALESCE(SUM(
COALESCE(account_stats_cost,
total_cost
)
* COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query
:=
fmt
.
Sprintf
(
`
...
...
@@ -3591,7 +3615,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(
COALESCE(account_stats_cost,
total_cost
)
* COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
...
...
@@ -4069,6 +4093,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
modelMappingChain
sql
.
NullString
billingTier
sql
.
NullString
billingMode
sql
.
NullString
accountStatsCost
sql
.
NullFloat64
createdAt
time
.
Time
)
...
...
@@ -4118,6 +4143,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&
modelMappingChain
,
&
billingTier
,
&
billingMode
,
&
accountStatsCost
,
&
createdAt
,
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -4214,6 +4240,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if
billingMode
.
Valid
{
log
.
BillingMode
=
&
billingMode
.
String
}
if
accountStatsCost
.
Valid
{
log
.
AccountStatsCost
=
&
accountStatsCost
.
Float64
}
return
log
,
nil
}
...
...
@@ -4257,6 +4286,7 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
&
row
.
TotalTokens
,
&
row
.
Cost
,
&
row
.
ActualCost
,
&
row
.
AccountCost
,
);
err
!=
nil
{
return
nil
,
err
}
...
...
backend/internal/repository/usage_log_repo_integration_test.go
View file @
0b746501
...
...
@@ -753,8 +753,11 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
s
.
Require
()
.
Equal
(
baseStats
.
TotalTokens
+
int64
(
51
),
stats
.
TotalTokens
,
"TotalTokens mismatch"
)
s
.
Require
()
.
Equal
(
baseStats
.
TotalCost
+
2.3
,
stats
.
TotalCost
,
"TotalCost mismatch"
)
s
.
Require
()
.
Equal
(
baseStats
.
TotalActualCost
+
2.0
,
stats
.
TotalActualCost
,
"TotalActualCost mismatch"
)
// account_cost falls back to total_cost when account_stats_cost is NULL
s
.
Require
()
.
Equal
(
baseStats
.
TotalAccountCost
+
2.3
,
stats
.
TotalAccountCost
,
"TotalAccountCost mismatch"
)
s
.
Require
()
.
GreaterOrEqual
(
stats
.
TodayRequests
,
int64
(
1
),
"expected TodayRequests >= 1"
)
s
.
Require
()
.
GreaterOrEqual
(
stats
.
TodayCost
,
0.0
,
"expected TodayCost >= 0"
)
s
.
Require
()
.
GreaterOrEqual
(
stats
.
TodayAccountCost
,
0.0
,
"expected TodayAccountCost >= 0"
)
wantRpm
,
wantTpm
,
err
:=
s
.
repo
.
getPerformanceStats
(
s
.
ctx
,
0
)
s
.
Require
()
.
NoError
(
err
,
"getPerformanceStats"
)
...
...
@@ -833,6 +836,8 @@ func (s *UsageLogRepoSuite) TestDashboardStatsWithRange_Fallback() {
s
.
Require
()
.
Equal
(
int64
(
45
),
stats
.
TotalTokens
)
s
.
Require
()
.
Equal
(
1.5
,
stats
.
TotalCost
)
s
.
Require
()
.
Equal
(
1.4
,
stats
.
TotalActualCost
)
// account_cost = COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1) = total_cost
s
.
Require
()
.
Equal
(
1.5
,
stats
.
TotalAccountCost
)
s
.
Require
()
.
InEpsilon
(
150.0
,
stats
.
AverageDurationMs
,
0.0001
)
}
...
...
backend/internal/repository/usage_log_repo_request_type_test.go
View file @
0b746501
...
...
@@ -85,6 +85,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock
.
AnyArg
(),
// model_mapping_chain
sqlmock
.
AnyArg
(),
// billing_tier
sqlmock
.
AnyArg
(),
// billing_mode
sqlmock
.
AnyArg
(),
// account_stats_cost
createdAt
,
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"id"
,
"created_at"
})
.
AddRow
(
int64
(
99
),
createdAt
))
...
...
@@ -163,6 +164,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock
.
AnyArg
(),
// model_mapping_chain
sqlmock
.
AnyArg
(),
// billing_tier
sqlmock
.
AnyArg
(),
// billing_mode
sqlmock
.
AnyArg
(),
// account_stats_cost
createdAt
,
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"id"
,
"created_at"
})
.
AddRow
(
int64
(
100
),
createdAt
))
...
...
@@ -299,7 +301,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin
mock
.
ExpectQuery
(
"AND
\\
(request_type =
\\
$3 OR
\\
(request_type = 0 AND openai_ws_mode = TRUE
\\
)
\\
)"
)
.
WithArgs
(
start
,
end
,
requestType
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"model"
,
"requests"
,
"input_tokens"
,
"output_tokens"
,
"cache_creation_tokens"
,
"cache_read_tokens"
,
"total_tokens"
,
"cost"
,
"actual_cost"
}))
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"model"
,
"requests"
,
"input_tokens"
,
"output_tokens"
,
"cache_creation_tokens"
,
"cache_read_tokens"
,
"total_tokens"
,
"cost"
,
"actual_cost"
,
"account_cost"
}))
stats
,
err
:=
repo
.
GetModelStatsWithFilters
(
context
.
Background
(),
start
,
end
,
0
,
0
,
0
,
0
,
&
requestType
,
&
stream
,
nil
)
require
.
NoError
(
t
,
err
)
...
...
@@ -344,6 +346,93 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
int64
(
1
),
stats
.
TotalRequests
)
require
.
Equal
(
t
,
int64
(
9
),
stats
.
TotalTokens
)
require
.
NotNil
(
t
,
stats
.
TotalAccountCost
,
"TotalAccountCost should always be returned"
)
require
.
Equal
(
t
,
1.2
,
*
stats
.
TotalAccountCost
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageLogRepositoryGetModelStatsAccountCostColumn
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageLogRepository
{
sql
:
db
}
start
:=
time
.
Date
(
2025
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
mock
.
ExpectQuery
(
"FROM usage_logs"
)
.
WithArgs
(
start
,
end
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"model"
,
"requests"
,
"input_tokens"
,
"output_tokens"
,
"cache_creation_tokens"
,
"cache_read_tokens"
,
"total_tokens"
,
"cost"
,
"actual_cost"
,
"account_cost"
,
})
.
AddRow
(
"claude-opus-4-6"
,
int64
(
10
),
int64
(
100
),
int64
(
200
),
int64
(
5
),
int64
(
3
),
int64
(
308
),
2.5
,
2.0
,
1.8
)
.
AddRow
(
"claude-sonnet-4-6"
,
int64
(
5
),
int64
(
50
),
int64
(
100
),
int64
(
0
),
int64
(
0
),
int64
(
150
),
1.0
,
0.8
,
0.7
))
results
,
err
:=
repo
.
GetModelStatsWithFilters
(
context
.
Background
(),
start
,
end
,
0
,
0
,
0
,
0
,
nil
,
nil
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
results
,
2
)
require
.
Equal
(
t
,
"claude-opus-4-6"
,
results
[
0
]
.
Model
)
require
.
Equal
(
t
,
2.5
,
results
[
0
]
.
Cost
)
require
.
Equal
(
t
,
2.0
,
results
[
0
]
.
ActualCost
)
require
.
Equal
(
t
,
1.8
,
results
[
0
]
.
AccountCost
)
require
.
Equal
(
t
,
"claude-sonnet-4-6"
,
results
[
1
]
.
Model
)
require
.
Equal
(
t
,
0.7
,
results
[
1
]
.
AccountCost
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageLogRepositoryGetGroupStatsAccountCostColumn
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageLogRepository
{
sql
:
db
}
start
:=
time
.
Date
(
2025
,
1
,
1
,
0
,
0
,
0
,
0
,
time
.
UTC
)
end
:=
start
.
Add
(
24
*
time
.
Hour
)
mock
.
ExpectQuery
(
"FROM usage_logs"
)
.
WithArgs
(
start
,
end
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"group_id"
,
"group_name"
,
"requests"
,
"total_tokens"
,
"cost"
,
"actual_cost"
,
"account_cost"
,
})
.
AddRow
(
int64
(
1
),
"azure-cc"
,
int64
(
100
),
int64
(
5000
),
10.0
,
8.5
,
7.2
)
.
AddRow
(
int64
(
2
),
"max"
,
int64
(
50
),
int64
(
2000
),
5.0
,
4.0
,
3.5
))
results
,
err
:=
repo
.
GetGroupStatsWithFilters
(
context
.
Background
(),
start
,
end
,
0
,
0
,
0
,
0
,
nil
,
nil
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
results
,
2
)
require
.
Equal
(
t
,
int64
(
1
),
results
[
0
]
.
GroupID
)
require
.
Equal
(
t
,
"azure-cc"
,
results
[
0
]
.
GroupName
)
require
.
Equal
(
t
,
10.0
,
results
[
0
]
.
Cost
)
require
.
Equal
(
t
,
8.5
,
results
[
0
]
.
ActualCost
)
require
.
Equal
(
t
,
7.2
,
results
[
0
]
.
AccountCost
)
require
.
Equal
(
t
,
int64
(
2
),
results
[
1
]
.
GroupID
)
require
.
Equal
(
t
,
3.5
,
results
[
1
]
.
AccountCost
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestUsageLogRepositoryGetStatsWithFiltersAlwaysReturnsAccountCost
(
t
*
testing
.
T
)
{
db
,
mock
:=
newSQLMock
(
t
)
repo
:=
&
usageLogRepository
{
sql
:
db
}
// No AccountID filter set - TotalAccountCost should still be returned
filters
:=
usagestats
.
UsageLogFilters
{}
mock
.
ExpectQuery
(
"FROM usage_logs"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"total_requests"
,
"total_input_tokens"
,
"total_output_tokens"
,
"total_cache_tokens"
,
"total_cost"
,
"total_actual_cost"
,
"total_account_cost"
,
"avg_duration_ms"
,
})
.
AddRow
(
int64
(
50
),
int64
(
1000
),
int64
(
2000
),
int64
(
100
),
15.0
,
12.5
,
11.0
,
100.0
))
mock
.
ExpectQuery
(
"SELECT COALESCE
\\
(NULLIF
\\
(TRIM
\\
(inbound_endpoint
\\
)"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"endpoint"
,
"requests"
,
"total_tokens"
,
"cost"
,
"actual_cost"
}))
mock
.
ExpectQuery
(
"SELECT COALESCE
\\
(NULLIF
\\
(TRIM
\\
(upstream_endpoint
\\
)"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"endpoint"
,
"requests"
,
"total_tokens"
,
"cost"
,
"actual_cost"
}))
mock
.
ExpectQuery
(
"SELECT CONCAT
\\
("
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"endpoint"
,
"requests"
,
"total_tokens"
,
"cost"
,
"actual_cost"
}))
stats
,
err
:=
repo
.
GetStatsWithFilters
(
context
.
Background
(),
filters
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
stats
.
TotalAccountCost
,
"TotalAccountCost must always be returned, even without AccountID filter"
)
require
.
Equal
(
t
,
11.0
,
*
stats
.
TotalAccountCost
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
...
...
@@ -483,10 +572,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql
.
NullString
{},
sql
.
NullString
{},
false
,
sql
.
NullInt64
{},
// channel_id
sql
.
NullString
{},
// model_mapping_chain
sql
.
NullString
{},
// billing_tier
sql
.
NullString
{},
// billing_mode
sql
.
NullInt64
{},
// channel_id
sql
.
NullString
{},
// model_mapping_chain
sql
.
NullString
{},
// billing_tier
sql
.
NullString
{},
// billing_mode
sql
.
NullFloat64
{},
// account_stats_cost
now
,
}})
require
.
NoError
(
t
,
err
)
...
...
@@ -530,10 +620,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql
.
NullString
{},
sql
.
NullString
{},
false
,
sql
.
NullInt64
{},
// channel_id
sql
.
NullString
{},
// model_mapping_chain
sql
.
NullString
{},
// billing_tier
sql
.
NullString
{},
// billing_mode
sql
.
NullInt64
{},
// channel_id
sql
.
NullString
{},
// model_mapping_chain
sql
.
NullString
{},
// billing_tier
sql
.
NullString
{},
// billing_mode
sql
.
NullFloat64
{},
// account_stats_cost
now
,
}})
require
.
NoError
(
t
,
err
)
...
...
@@ -577,10 +668,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql
.
NullString
{},
sql
.
NullString
{},
false
,
sql
.
NullInt64
{},
// channel_id
sql
.
NullString
{},
// model_mapping_chain
sql
.
NullString
{},
// billing_tier
sql
.
NullString
{},
// billing_mode
sql
.
NullInt64
{},
// channel_id
sql
.
NullString
{},
// model_mapping_chain
sql
.
NullString
{},
// billing_tier
sql
.
NullString
{},
// billing_mode
sql
.
NullFloat64
{},
// account_stats_cost
now
,
}})
require
.
NoError
(
t
,
err
)
...
...
backend/internal/repository/user_group_rate_repo.go
View file @
0b746501
...
...
@@ -100,7 +100,7 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
query
:=
`
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
FROM user_group_rate_multipliers ugr
JOIN users u ON u.id = ugr.user_id
JOIN users u ON u.id = ugr.user_id
AND u.deleted_at IS NULL
WHERE ugr.group_id = $1
ORDER BY ugr.user_id
`
...
...
backend/internal/repository/user_repo.go
View file @
0b746501
...
...
@@ -137,7 +137,7 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
txClient
=
r
.
client
}
update
d
,
err
:=
txClient
.
User
.
UpdateOneID
(
userIn
.
ID
)
.
update
Op
:=
txClient
.
User
.
UpdateOneID
(
userIn
.
ID
)
.
SetEmail
(
userIn
.
Email
)
.
SetUsername
(
userIn
.
Username
)
.
SetNotes
(
userIn
.
Notes
)
.
...
...
@@ -146,7 +146,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance
(
userIn
.
Balance
)
.
SetConcurrency
(
userIn
.
Concurrency
)
.
SetStatus
(
userIn
.
Status
)
.
Save
(
ctx
)
SetBalanceNotifyEnabled
(
userIn
.
BalanceNotifyEnabled
)
.
SetBalanceNotifyThresholdType
(
userIn
.
BalanceNotifyThresholdType
)
.
SetNillableBalanceNotifyThreshold
(
userIn
.
BalanceNotifyThreshold
)
.
SetBalanceNotifyExtraEmails
(
marshalExtraEmails
(
userIn
.
BalanceNotifyExtraEmails
))
.
SetTotalRecharged
(
userIn
.
TotalRecharged
)
if
userIn
.
BalanceNotifyThreshold
==
nil
{
updateOp
=
updateOp
.
ClearBalanceNotifyThreshold
()
}
updated
,
err
:=
updateOp
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
service
.
ErrEmailExists
)
}
...
...
@@ -382,7 +390,12 @@ func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[
func
(
r
*
userRepository
)
UpdateBalance
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
n
,
err
:=
client
.
User
.
Update
()
.
Where
(
dbuser
.
IDEQ
(
id
))
.
AddBalance
(
amount
)
.
Save
(
ctx
)
update
:=
client
.
User
.
Update
()
.
Where
(
dbuser
.
IDEQ
(
id
))
.
AddBalance
(
amount
)
// Track cumulative recharge amount for percentage-based notifications
if
amount
>
0
{
update
=
update
.
AddTotalRecharged
(
amount
)
}
n
,
err
:=
update
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
...
...
@@ -549,6 +562,11 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst
.
UpdatedAt
=
src
.
UpdatedAt
}
// marshalExtraEmails serializes notify email entries to JSON for storage.
func
marshalExtraEmails
(
entries
[]
service
.
NotifyEmailEntry
)
string
{
return
service
.
MarshalNotifyEmails
(
entries
)
}
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func
(
r
*
userRepository
)
UpdateTotpSecret
(
ctx
context
.
Context
,
userID
int64
,
encryptedSecret
*
string
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
...
...
backend/internal/server/api_contract_test.go
View file @
0b746501
...
...
@@ -58,6 +58,11 @@ func TestAPIContracts(t *testing.T) {
"allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z",
"balance_notify_enabled": false,
"balance_notify_threshold_type": "",
"balance_notify_threshold": null,
"balance_notify_extra_emails": null,
"total_recharged": 0,
"run_mode": "standard"
}
}`
,
...
...
@@ -204,11 +209,10 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"claude_code_only": false,
"claude_code_only": false,
"allow_messages_dispatch": false,
"fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
"allow_messages_dispatch": false,
"require_oauth_only": false,
"require_privacy_set": false,
"created_at": "2025-01-02T03:04:05Z",
...
...
@@ -587,26 +591,34 @@ func TestAPIContracts(t *testing.T) {
"enable_cch_signing": false,
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"web_search_emulation_enabled": false,
"custom_menu_items": [],
"custom_endpoints": [],
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,
"payment_daily_limit": 0,
"payment_order_timeout_minutes": 0,
"payment_max_pending_orders": 0,
"payment_enabled_types": null,
"payment_balance_disabled": false,
"payment_balance_recharge_multiplier": 0,
"payment_recharge_fee_rate": 0,
"payment_load_balance_strategy": "",
"payment_product_name_prefix": "",
"payment_product_name_suffix": "",
"payment_help_image_url": "",
"payment_help_text": "",
"payment_enabled_types": null,
"payment_cancel_rate_limit_enabled": false,
"payment_cancel_rate_limit_max": 0,
"payment_cancel_rate_limit_window": 0,
"payment_cancel_rate_limit_unit": "",
"payment_cancel_rate_limit_window_mode": "",
"custom_menu_items": [],
"custom_endpoints": []
"balance_low_notify_enabled": false,
"account_quota_notify_enabled": false,
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
"account_quota_notify_emails": []
}
}`
,
},
...
...
@@ -699,7 +711,7 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode
:
config
.
RunModeStandard
,
}
userService
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
)
userService
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
,
nil
)
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
userRepo
,
groupRepo
,
userSubRepo
,
nil
,
apiKeyCache
,
cfg
)
usageRepo
:=
newStubUsageLogRepo
()
...
...
backend/internal/server/http.go
View file @
0b746501
...
...
@@ -2,12 +2,15 @@
package
server
import
(
"context"
"log"
"log/slog"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -56,6 +59,42 @@ func ProvideRouter(
}
}
// Wire up websearch Manager builder so it initializes on startup and rebuilds on config save.
settingService
.
SetWebSearchManagerBuilder
(
context
.
Background
(),
func
(
cfg
*
service
.
WebSearchEmulationConfig
,
proxyURLs
map
[
int64
]
string
)
{
if
cfg
==
nil
||
!
cfg
.
Enabled
||
len
(
cfg
.
Providers
)
==
0
{
service
.
SetWebSearchManager
(
nil
)
return
}
configs
:=
make
([]
websearch
.
ProviderConfig
,
0
,
len
(
cfg
.
Providers
))
for
_
,
p
:=
range
cfg
.
Providers
{
if
p
.
APIKey
==
""
{
continue
}
pc
:=
websearch
.
ProviderConfig
{
Type
:
p
.
Type
,
APIKey
:
p
.
APIKey
,
QuotaLimit
:
derefInt64
(
p
.
QuotaLimit
),
ExpiresAt
:
p
.
ExpiresAt
,
}
if
p
.
SubscribedAt
!=
nil
{
pc
.
SubscribedAt
=
p
.
SubscribedAt
}
if
p
.
ProxyID
!=
nil
{
pc
.
ProxyID
=
*
p
.
ProxyID
if
u
,
ok
:=
proxyURLs
[
*
p
.
ProxyID
];
ok
{
pc
.
ProxyURL
=
u
}
else
{
// Proxy configured but not found — skip this provider to prevent direct connection.
slog
.
Warn
(
"websearch: proxy not found for provider, skipping"
,
"provider"
,
p
.
Type
,
"proxy_id"
,
*
p
.
ProxyID
)
continue
}
}
configs
=
append
(
configs
,
pc
)
}
service
.
SetWebSearchManager
(
websearch
.
NewManager
(
configs
,
redisClient
))
})
return
SetupRouter
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
opsService
,
settingService
,
cfg
,
redisClient
)
}
...
...
@@ -102,3 +141,10 @@ func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
// 不设置 ReadTimeout,因为大请求体可能需要较长时间读取
}
}
func
derefInt64
(
p
*
int64
)
int64
{
if
p
==
nil
{
return
0
}
return
*
p
}
backend/internal/server/middleware/admin_auth_test.go
View file @
0b746501
...
...
@@ -39,7 +39,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
return
&
clone
,
nil
},
}
userService
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
)
userService
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
,
nil
)
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAdminAuthMiddleware
(
authService
,
userService
,
nil
)))
...
...
backend/internal/server/middleware/jwt_auth_test.go
View file @
0b746501
...
...
@@ -41,7 +41,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
userRepo
:=
&
stubJWTUserRepo
{
users
:
users
}
authSvc
:=
service
.
NewAuthService
(
nil
,
userRepo
,
nil
,
nil
,
cfg
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
userSvc
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
)
userSvc
:=
service
.
NewUserService
(
userRepo
,
nil
,
nil
,
nil
)
mw
:=
NewJWTAuthMiddleware
(
authSvc
,
userSvc
)
r
:=
gin
.
New
()
...
...
backend/internal/server/middleware/security_headers.go
View file @
0b746501
...
...
@@ -18,6 +18,8 @@ const (
NonceTemplate
=
"__CSP_NONCE__"
// CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
CloudflareInsightsDomain
=
"https://static.cloudflareinsights.com"
// StripeDomain is the domain for Stripe.js SDK
StripeDomain
=
"https://*.stripe.com"
)
// GenerateNonce generates a cryptographically secure random nonce.
...
...
@@ -97,8 +99,9 @@ func isAPIRoutePath(c *gin.Context) bool {
strings
.
HasPrefix
(
path
,
"/responses"
)
}
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
// This allows the application to work correctly even if the config file has an older CSP policy.
// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights,
// and Stripe.js domains. This allows the application to work correctly even if the
// config file has an older CSP policy.
func
enhanceCSPPolicy
(
policy
string
)
string
{
// Add nonce placeholder to script-src if not present
if
!
strings
.
Contains
(
policy
,
NonceTemplate
)
&&
!
strings
.
Contains
(
policy
,
"'nonce-"
)
{
...
...
@@ -110,6 +113,12 @@ func enhanceCSPPolicy(policy string) string {
policy
=
addToDirective
(
policy
,
"script-src"
,
CloudflareInsightsDomain
)
}
// Add Stripe.js domain to script-src and frame-src if not present
if
!
strings
.
Contains
(
policy
,
"stripe.com"
)
{
policy
=
addToDirective
(
policy
,
"script-src"
,
StripeDomain
)
policy
=
addToDirective
(
policy
,
"frame-src"
,
StripeDomain
)
}
return
policy
}
...
...
backend/internal/server/routes/admin.go
View file @
0b746501
...
...
@@ -407,6 +407,11 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Beta 策略配置
adminSettings
.
GET
(
"/beta-policy"
,
h
.
Admin
.
Setting
.
GetBetaPolicySettings
)
adminSettings
.
PUT
(
"/beta-policy"
,
h
.
Admin
.
Setting
.
UpdateBetaPolicySettings
)
// Web Search 模拟配置
adminSettings
.
GET
(
"/web-search-emulation"
,
h
.
Admin
.
Setting
.
GetWebSearchEmulationConfig
)
adminSettings
.
PUT
(
"/web-search-emulation"
,
h
.
Admin
.
Setting
.
UpdateWebSearchEmulationConfig
)
adminSettings
.
POST
(
"/web-search-emulation/test"
,
h
.
Admin
.
Setting
.
TestWebSearchEmulation
)
adminSettings
.
POST
(
"/web-search-emulation/reset-usage"
,
h
.
Admin
.
Setting
.
ResetWebSearchUsage
)
}
}
...
...
backend/internal/server/routes/payment.go
View file @
0b746501
...
...
@@ -39,6 +39,7 @@ func RegisterPaymentRoutes(
orders
.
GET
(
"/:id"
,
paymentHandler
.
GetOrder
)
orders
.
POST
(
"/:id/cancel"
,
paymentHandler
.
CancelOrder
)
orders
.
POST
(
"/:id/refund-request"
,
paymentHandler
.
RequestRefund
)
orders
.
GET
(
"/refund-eligible-providers"
,
paymentHandler
.
GetRefundEligibleProviders
)
}
}
...
...
backend/internal/server/routes/user.go
View file @
0b746501
...
...
@@ -26,6 +26,15 @@ func RegisterUserRoutes(
user
.
PUT
(
"/password"
,
h
.
User
.
ChangePassword
)
user
.
PUT
(
""
,
h
.
User
.
UpdateProfile
)
// 通知邮箱管理
notifyEmail
:=
user
.
Group
(
"/notify-email"
)
{
notifyEmail
.
POST
(
"/send-code"
,
h
.
User
.
SendNotifyEmailCode
)
notifyEmail
.
POST
(
"/verify"
,
h
.
User
.
VerifyNotifyEmail
)
notifyEmail
.
PUT
(
"/toggle"
,
h
.
User
.
ToggleNotifyEmail
)
notifyEmail
.
DELETE
(
""
,
h
.
User
.
RemoveNotifyEmail
)
}
// TOTP 双因素认证
totp
:=
user
.
Group
(
"/totp"
)
{
...
...
backend/internal/service/account.go
View file @
0b746501
...
...
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"hash/fnv"
"log/slog"
"reflect"
"sort"
"strconv"
...
...
@@ -969,7 +970,7 @@ func (a *Account) IsOveragesEnabled() bool {
return
false
}
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用
“
自动透传(仅替换认证)
”
。
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用
"
自动透传(仅替换认证)
"
。
//
// 新字段:accounts.extra.openai_passthrough。
// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。
...
...
@@ -1133,7 +1134,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return
resolvedDefault
}
// IsOpenAIWSForceHTTPEnabled 返回账号级
“
强制 HTTP
”
开关。
// IsOpenAIWSForceHTTPEnabled 返回账号级
"
强制 HTTP
"
开关。
// 字段:accounts.extra.openai_ws_force_http。
func
(
a
*
Account
)
IsOpenAIWSForceHTTPEnabled
()
bool
{
if
a
==
nil
||
!
a
.
IsOpenAI
()
||
a
.
Extra
==
nil
{
...
...
@@ -1158,7 +1159,7 @@ func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
return
a
!=
nil
&&
a
.
IsOpenAIOAuth
()
&&
a
.
IsOpenAIPassthroughEnabled
()
}
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用
“
自动透传(仅替换认证)
”
。
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用
"
自动透传(仅替换认证)
"
。
// 字段:accounts.extra.anthropic_passthrough。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func
(
a
*
Account
)
IsAnthropicAPIKeyPassthroughEnabled
()
bool
{
...
...
@@ -1169,7 +1170,42 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
return
ok
&&
enabled
}
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。
// WebSearch 模拟三态常量
const
(
WebSearchModeDefault
=
"default"
// 跟随渠道配置
WebSearchModeEnabled
=
"enabled"
// 强制开启
WebSearchModeDisabled
=
"disabled"
// 强制关闭
)
// GetWebSearchEmulationMode 返回账号的 WebSearch 模拟模式。
// 三态:default(跟随渠道)/ enabled(强制开启)/ disabled(强制关闭)。
// 兼容旧 bool 值:true→enabled, false→default(并记录 debug 日志)。
func
(
a
*
Account
)
GetWebSearchEmulationMode
()
string
{
if
a
==
nil
||
a
.
Platform
!=
PlatformAnthropic
||
a
.
Type
!=
AccountTypeAPIKey
||
a
.
Extra
==
nil
{
return
WebSearchModeDefault
}
raw
:=
a
.
Extra
[
featureKeyWebSearchEmulation
]
// Tolerant: legacy bool values (pre-migration or stale writes)
if
b
,
ok
:=
raw
.
(
bool
);
ok
{
slog
.
Debug
(
"legacy bool web_search_emulation value"
,
"account_id"
,
a
.
ID
,
"value"
,
b
)
if
b
{
return
WebSearchModeEnabled
}
return
WebSearchModeDefault
}
mode
,
ok
:=
raw
.
(
string
)
if
!
ok
{
return
WebSearchModeDefault
}
switch
mode
{
case
WebSearchModeEnabled
,
WebSearchModeDisabled
:
return
mode
default
:
return
WebSearchModeDefault
}
}
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。
// 字段:accounts.extra.codex_cli_only。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func
(
a
*
Account
)
IsCodexCLIOnlyEnabled
()
bool
{
...
...
@@ -1395,6 +1431,19 @@ func (a *Account) getExtraTime(key string) time.Time {
return
time
.
Time
{}
}
// getExtraBool 从 Extra 中读取指定 key 的 bool 值
func
(
a
*
Account
)
getExtraBool
(
key
string
)
bool
{
if
a
.
Extra
==
nil
{
return
false
}
if
v
,
ok
:=
a
.
Extra
[
key
];
ok
{
if
b
,
ok
:=
v
.
(
bool
);
ok
{
return
b
}
}
return
false
}
// getExtraString 从 Extra 中读取指定 key 的字符串值
func
(
a
*
Account
)
getExtraString
(
key
string
)
string
{
if
a
.
Extra
==
nil
{
...
...
@@ -1408,6 +1457,14 @@ func (a *Account) getExtraString(key string) string {
return
""
}
// getExtraStringDefault 从 Extra 中读取指定 key 的字符串值,不存在时返回 defaultVal
func
(
a
*
Account
)
getExtraStringDefault
(
key
,
defaultVal
string
)
string
{
if
v
:=
a
.
getExtraString
(
key
);
v
!=
""
{
return
v
}
return
defaultVal
}
// getExtraInt 从 Extra 中读取指定 key 的 int 值
func
(
a
*
Account
)
getExtraInt
(
key
string
)
int
{
if
a
.
Extra
==
nil
{
...
...
@@ -1464,6 +1521,62 @@ func (a *Account) GetQuotaResetTimezone() string {
return
"UTC"
}
// --- Quota Notification Getters ---
// QuotaNotifyConfig returns the notify configuration for a given quota dimension.
// dim must be one of quotaDimDaily, quotaDimWeekly, quotaDimTotal.
func
(
a
*
Account
)
QuotaNotifyConfig
(
dim
string
)
(
enabled
bool
,
threshold
float64
,
thresholdType
string
)
{
enabled
=
a
.
getExtraBool
(
"quota_notify_"
+
dim
+
"_enabled"
)
threshold
=
a
.
getExtraFloat64
(
"quota_notify_"
+
dim
+
"_threshold"
)
thresholdType
=
a
.
getExtraStringDefault
(
"quota_notify_"
+
dim
+
"_threshold_type"
,
thresholdTypeFixed
)
return
}
func
(
a
*
Account
)
GetQuotaNotifyDailyEnabled
()
bool
{
e
,
_
,
_
:=
a
.
QuotaNotifyConfig
(
quotaDimDaily
)
return
e
}
func
(
a
*
Account
)
GetQuotaNotifyDailyThreshold
()
float64
{
_
,
t
,
_
:=
a
.
QuotaNotifyConfig
(
quotaDimDaily
)
return
t
}
func
(
a
*
Account
)
GetQuotaNotifyDailyThresholdType
()
string
{
_
,
_
,
tt
:=
a
.
QuotaNotifyConfig
(
quotaDimDaily
)
return
tt
}
func
(
a
*
Account
)
GetQuotaNotifyWeeklyEnabled
()
bool
{
e
,
_
,
_
:=
a
.
QuotaNotifyConfig
(
quotaDimWeekly
)
return
e
}
func
(
a
*
Account
)
GetQuotaNotifyWeeklyThreshold
()
float64
{
_
,
t
,
_
:=
a
.
QuotaNotifyConfig
(
quotaDimWeekly
)
return
t
}
func
(
a
*
Account
)
GetQuotaNotifyWeeklyThresholdType
()
string
{
_
,
_
,
tt
:=
a
.
QuotaNotifyConfig
(
quotaDimWeekly
)
return
tt
}
func
(
a
*
Account
)
GetQuotaNotifyTotalEnabled
()
bool
{
e
,
_
,
_
:=
a
.
QuotaNotifyConfig
(
quotaDimTotal
)
return
e
}
func
(
a
*
Account
)
GetQuotaNotifyTotalThreshold
()
float64
{
_
,
t
,
_
:=
a
.
QuotaNotifyConfig
(
quotaDimTotal
)
return
t
}
func
(
a
*
Account
)
GetQuotaNotifyTotalThresholdType
()
string
{
_
,
_
,
tt
:=
a
.
QuotaNotifyConfig
(
quotaDimTotal
)
return
tt
}
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
func
nextFixedDailyReset
(
hour
int
,
tz
*
time
.
Location
,
after
time
.
Time
)
time
.
Time
{
t
:=
after
.
In
(
tz
)
...
...
backend/internal/service/account_stats_pricing.go
0 → 100644
View file @
0b746501
package
service
import
(
"context"
"strings"
)
// resolveAccountStatsCost 计算账号统计定价费用。
// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。
//
// 优先级(先命中为准):
// 1. 自定义规则(始终尝试,不依赖 ApplyPricingToAccountStats 开关)
// 2. ApplyPricingToAccountStats 启用时,直接使用本次请求的客户计费(倍率前的 totalCost)
// 3. 模型定价文件(LiteLLM)中上游模型的默认价格
// 4. nil → 走默认公式(total_cost × account_rate_multiplier)
//
// upstreamModel 是最终发往上游的模型 ID。
// totalCost 是本次请求的客户计费(倍率前),用于优先级 2。
func
resolveAccountStatsCost
(
ctx
context
.
Context
,
channelService
*
ChannelService
,
billingService
*
BillingService
,
accountID
int64
,
groupID
int64
,
upstreamModel
string
,
tokens
UsageTokens
,
requestCount
int
,
totalCost
float64
,
)
*
float64
{
if
channelService
==
nil
||
upstreamModel
==
""
{
return
nil
}
channel
,
err
:=
channelService
.
GetChannelForGroup
(
ctx
,
groupID
)
if
err
!=
nil
||
channel
==
nil
{
return
nil
}
platform
:=
channelService
.
GetGroupPlatform
(
ctx
,
groupID
)
// 优先级 1:自定义规则(始终尝试)
if
cost
:=
tryCustomRules
(
channel
,
accountID
,
groupID
,
platform
,
upstreamModel
,
tokens
,
requestCount
);
cost
!=
nil
{
return
cost
}
// 优先级 2:渠道开启"应用模型定价到账号统计"时,直接使用客户计费(倍率前)
if
channel
.
ApplyPricingToAccountStats
{
cost
:=
totalCost
if
cost
<=
0
{
return
nil
}
return
&
cost
}
// 优先级 3:模型定价文件(LiteLLM)默认价格
if
billingService
!=
nil
{
return
tryModelFilePricing
(
billingService
,
upstreamModel
,
tokens
)
}
return
nil
}
// tryModelFilePricing 使用模型定价文件(LiteLLM/fallback)中的标准价格计算费用。
func
tryModelFilePricing
(
billingService
*
BillingService
,
model
string
,
tokens
UsageTokens
)
*
float64
{
pricing
,
err
:=
billingService
.
GetModelPricing
(
model
)
if
err
!=
nil
||
pricing
==
nil
{
return
nil
}
cost
:=
float64
(
tokens
.
InputTokens
)
*
pricing
.
InputPricePerToken
+
float64
(
tokens
.
OutputTokens
)
*
pricing
.
OutputPricePerToken
+
float64
(
tokens
.
CacheCreationTokens
)
*
pricing
.
CacheCreationPricePerToken
+
float64
(
tokens
.
CacheReadTokens
)
*
pricing
.
CacheReadPricePerToken
+
float64
(
tokens
.
ImageOutputTokens
)
*
pricing
.
ImageOutputPricePerToken
if
cost
<=
0
{
return
nil
}
return
&
cost
}
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
func
tryCustomRules
(
channel
*
Channel
,
accountID
,
groupID
int64
,
platform
,
model
string
,
tokens
UsageTokens
,
requestCount
int
,
)
*
float64
{
modelLower
:=
strings
.
ToLower
(
model
)
for
_
,
rule
:=
range
channel
.
AccountStatsPricingRules
{
if
!
matchAccountStatsRule
(
&
rule
,
accountID
,
groupID
)
{
continue
}
pricing
:=
findPricingForModel
(
rule
.
Pricing
,
platform
,
modelLower
)
if
pricing
==
nil
{
continue
// 规则匹配但模型不在规则定价中,继续下一条
}
return
calculateStatsCost
(
pricing
,
tokens
,
requestCount
)
}
return
nil
}
// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。
// 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。
// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。
func
matchAccountStatsRule
(
rule
*
AccountStatsPricingRule
,
accountID
,
groupID
int64
)
bool
{
if
len
(
rule
.
AccountIDs
)
==
0
&&
len
(
rule
.
GroupIDs
)
==
0
{
return
false
}
for
_
,
id
:=
range
rule
.
AccountIDs
{
if
id
==
accountID
{
return
true
}
}
for
_
,
id
:=
range
rule
.
GroupIDs
{
if
id
==
groupID
{
return
true
}
}
return
false
}
// findPricingForModel 在定价列表中查找匹配的模型定价。
// 先精确匹配,再通配符匹配(按配置顺序,先匹配先使用)。
func
findPricingForModel
(
pricingList
[]
ChannelModelPricing
,
platform
,
modelLower
string
)
*
ChannelModelPricing
{
// 精确匹配优先
for
i
:=
range
pricingList
{
p
:=
&
pricingList
[
i
]
if
!
isPlatformMatch
(
platform
,
p
.
Platform
)
{
continue
}
for
_
,
m
:=
range
p
.
Models
{
if
strings
.
ToLower
(
m
)
==
modelLower
{
return
p
}
}
}
// 通配符匹配:按配置顺序,先匹配先使用
for
i
:=
range
pricingList
{
p
:=
&
pricingList
[
i
]
if
!
isPlatformMatch
(
platform
,
p
.
Platform
)
{
continue
}
for
_
,
m
:=
range
p
.
Models
{
ml
:=
strings
.
ToLower
(
m
)
if
!
strings
.
HasSuffix
(
ml
,
"*"
)
{
continue
}
prefix
:=
strings
.
TrimSuffix
(
ml
,
"*"
)
if
strings
.
HasPrefix
(
modelLower
,
prefix
)
{
return
p
}
}
}
return
nil
}
// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。
func
isPlatformMatch
(
queryPlatform
,
pricingPlatform
string
)
bool
{
if
queryPlatform
==
""
||
pricingPlatform
==
""
{
return
true
}
return
queryPlatform
==
pricingPlatform
}
// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。
func
calculateStatsCost
(
pricing
*
ChannelModelPricing
,
tokens
UsageTokens
,
requestCount
int
)
*
float64
{
if
pricing
==
nil
{
return
nil
}
switch
pricing
.
BillingMode
{
case
BillingModePerRequest
,
BillingModeImage
:
return
calculatePerRequestStatsCost
(
pricing
,
requestCount
)
default
:
return
calculateTokenStatsCost
(
pricing
,
tokens
)
}
}
// calculatePerRequestStatsCost 按次/图片计费。
func
calculatePerRequestStatsCost
(
pricing
*
ChannelModelPricing
,
requestCount
int
)
*
float64
{
if
pricing
.
PerRequestPrice
==
nil
||
*
pricing
.
PerRequestPrice
<=
0
{
return
nil
}
cost
:=
*
pricing
.
PerRequestPrice
*
float64
(
requestCount
)
return
&
cost
}
// calculateTokenStatsCost Token 计费。
// If the pricing has intervals, find the matching interval by total token count
// and use its prices instead of the flat pricing fields.
func
calculateTokenStatsCost
(
pricing
*
ChannelModelPricing
,
tokens
UsageTokens
)
*
float64
{
p
:=
pricing
if
len
(
pricing
.
Intervals
)
>
0
{
totalTokens
:=
tokens
.
InputTokens
+
tokens
.
OutputTokens
+
tokens
.
CacheCreationTokens
+
tokens
.
CacheReadTokens
if
iv
:=
FindMatchingInterval
(
pricing
.
Intervals
,
totalTokens
);
iv
!=
nil
{
p
=
&
ChannelModelPricing
{
InputPrice
:
iv
.
InputPrice
,
OutputPrice
:
iv
.
OutputPrice
,
CacheWritePrice
:
iv
.
CacheWritePrice
,
CacheReadPrice
:
iv
.
CacheReadPrice
,
PerRequestPrice
:
iv
.
PerRequestPrice
,
}
}
}
deref
:=
func
(
ptr
*
float64
)
float64
{
if
ptr
==
nil
{
return
0
}
return
*
ptr
}
cost
:=
float64
(
tokens
.
InputTokens
)
*
deref
(
p
.
InputPrice
)
+
float64
(
tokens
.
OutputTokens
)
*
deref
(
p
.
OutputPrice
)
+
float64
(
tokens
.
CacheCreationTokens
)
*
deref
(
p
.
CacheWritePrice
)
+
float64
(
tokens
.
CacheReadTokens
)
*
deref
(
p
.
CacheReadPrice
)
+
float64
(
tokens
.
ImageOutputTokens
)
*
deref
(
p
.
ImageOutputPrice
)
if
cost
<=
0
{
return
nil
}
return
&
cost
}
// applyAccountStatsCost resolves the account stats cost for a usage log entry.
// It resolves the upstream model (falling back to the requested model) and calls
// the 4-level priority chain via resolveAccountStatsCost.
func
applyAccountStatsCost
(
ctx
context
.
Context
,
usageLog
*
UsageLog
,
cs
*
ChannelService
,
bs
*
BillingService
,
accountID
int64
,
groupID
int64
,
upstreamModel
,
requestedModel
string
,
tokens
UsageTokens
,
totalCost
float64
,
)
{
model
:=
upstreamModel
if
model
==
""
{
model
=
requestedModel
}
usageLog
.
AccountStatsCost
=
resolveAccountStatsCost
(
ctx
,
cs
,
bs
,
accountID
,
groupID
,
model
,
tokens
,
1
,
totalCost
,
)
}
backend/internal/service/account_stats_pricing_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// matchAccountStatsRule
// ---------------------------------------------------------------------------
func
TestMatchAccountStatsRule_BothEmpty_NoMatch
(
t
*
testing
.
T
)
{
rule
:=
&
AccountStatsPricingRule
{}
require
.
False
(
t
,
matchAccountStatsRule
(
rule
,
1
,
10
))
}
func
TestMatchAccountStatsRule_AccountIDMatch
(
t
*
testing
.
T
)
{
rule
:=
&
AccountStatsPricingRule
{
AccountIDs
:
[]
int64
{
1
,
2
,
3
}}
require
.
True
(
t
,
matchAccountStatsRule
(
rule
,
2
,
999
))
}
func
TestMatchAccountStatsRule_GroupIDMatch
(
t
*
testing
.
T
)
{
rule
:=
&
AccountStatsPricingRule
{
GroupIDs
:
[]
int64
{
10
,
20
}}
require
.
True
(
t
,
matchAccountStatsRule
(
rule
,
999
,
20
))
}
func
TestMatchAccountStatsRule_BothConfigured_AccountMatch
(
t
*
testing
.
T
)
{
rule
:=
&
AccountStatsPricingRule
{
AccountIDs
:
[]
int64
{
1
,
2
},
GroupIDs
:
[]
int64
{
10
,
20
},
}
require
.
True
(
t
,
matchAccountStatsRule
(
rule
,
2
,
999
))
}
func
TestMatchAccountStatsRule_BothConfigured_GroupMatch
(
t
*
testing
.
T
)
{
rule
:=
&
AccountStatsPricingRule
{
AccountIDs
:
[]
int64
{
1
,
2
},
GroupIDs
:
[]
int64
{
10
,
20
},
}
require
.
True
(
t
,
matchAccountStatsRule
(
rule
,
999
,
10
))
}
func
TestMatchAccountStatsRule_BothConfigured_NeitherMatch
(
t
*
testing
.
T
)
{
rule
:=
&
AccountStatsPricingRule
{
AccountIDs
:
[]
int64
{
1
,
2
},
GroupIDs
:
[]
int64
{
10
,
20
},
}
require
.
False
(
t
,
matchAccountStatsRule
(
rule
,
999
,
999
))
}
// ---------------------------------------------------------------------------
// findPricingForModel
// ---------------------------------------------------------------------------
func
TestFindPricingForModel
(
t
*
testing
.
T
)
{
exactPricing
:=
ChannelModelPricing
{
ID
:
1
,
Models
:
[]
string
{
"claude-opus-4"
},
}
wildcardPricing
:=
ChannelModelPricing
{
ID
:
2
,
Models
:
[]
string
{
"claude-*"
},
}
platformPricing
:=
ChannelModelPricing
{
ID
:
3
,
Platform
:
"openai"
,
Models
:
[]
string
{
"gpt-4o"
},
}
emptyPlatformPricing
:=
ChannelModelPricing
{
ID
:
4
,
Models
:
[]
string
{
"gemini-2.5-pro"
},
}
tests
:=
[]
struct
{
name
string
list
[]
ChannelModelPricing
platform
string
model
string
wantID
int64
wantNil
bool
}{
{
name
:
"exact match"
,
list
:
[]
ChannelModelPricing
{
exactPricing
},
platform
:
"anthropic"
,
model
:
"claude-opus-4"
,
wantID
:
1
,
},
{
name
:
"exact match case insensitive"
,
list
:
[]
ChannelModelPricing
{{
ID
:
5
,
Models
:
[]
string
{
"Claude-Opus-4"
}}},
platform
:
""
,
model
:
"claude-opus-4"
,
wantID
:
5
,
},
{
name
:
"wildcard match"
,
list
:
[]
ChannelModelPricing
{
wildcardPricing
},
platform
:
"anthropic"
,
model
:
"claude-opus-4"
,
wantID
:
2
,
},
{
name
:
"exact match takes priority over wildcard"
,
list
:
[]
ChannelModelPricing
{
wildcardPricing
,
exactPricing
},
platform
:
"anthropic"
,
model
:
"claude-opus-4"
,
wantID
:
1
,
},
{
name
:
"platform mismatch skipped"
,
list
:
[]
ChannelModelPricing
{
platformPricing
},
platform
:
"anthropic"
,
model
:
"gpt-4o"
,
wantNil
:
true
,
},
{
name
:
"empty platform in pricing matches any"
,
list
:
[]
ChannelModelPricing
{
emptyPlatformPricing
},
platform
:
"gemini"
,
model
:
"gemini-2.5-pro"
,
wantID
:
4
,
},
{
name
:
"empty platform in query matches any pricing platform"
,
list
:
[]
ChannelModelPricing
{
platformPricing
},
platform
:
""
,
model
:
"gpt-4o"
,
wantID
:
3
,
},
{
name
:
"no match at all"
,
list
:
[]
ChannelModelPricing
{
exactPricing
,
wildcardPricing
},
platform
:
"anthropic"
,
model
:
"gpt-4o"
,
wantNil
:
true
,
},
{
name
:
"empty list returns nil"
,
list
:
nil
,
model
:
"claude-opus-4"
,
wantNil
:
true
,
},
{
name
:
"wildcard matches by config order (first match wins)"
,
list
:
[]
ChannelModelPricing
{
{
ID
:
10
,
Models
:
[]
string
{
"claude-*"
}},
{
ID
:
11
,
Models
:
[]
string
{
"claude-opus-*"
}},
},
platform
:
""
,
model
:
"claude-opus-4"
,
wantID
:
10
,
// config order: "claude-*" is first and matches, so it wins
},
{
name
:
"shorter wildcard used when longer does not match"
,
list
:
[]
ChannelModelPricing
{
{
ID
:
10
,
Models
:
[]
string
{
"claude-*"
}},
{
ID
:
11
,
Models
:
[]
string
{
"claude-opus-*"
}},
},
platform
:
""
,
model
:
"claude-sonnet-4"
,
wantID
:
10
,
// only "claude-*" matches
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
findPricingForModel
(
tt
.
list
,
tt
.
platform
,
tt
.
model
)
if
tt
.
wantNil
{
require
.
Nil
(
t
,
result
)
return
}
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
tt
.
wantID
,
result
.
ID
)
})
}
}
// ---------------------------------------------------------------------------
// calculateStatsCost
// ---------------------------------------------------------------------------
func
TestCalculateStatsCost_NilPricing
(
t
*
testing
.
T
)
{
result
:=
calculateStatsCost
(
nil
,
UsageTokens
{},
1
)
require
.
Nil
(
t
,
result
)
}
func
TestCalculateStatsCost_TokenBilling
(
t
*
testing
.
T
)
{
pricing
:=
&
ChannelModelPricing
{
BillingMode
:
BillingModeToken
,
InputPrice
:
testPtrFloat64
(
0.001
),
OutputPrice
:
testPtrFloat64
(
0.002
),
}
tokens
:=
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
,
}
result
:=
calculateStatsCost
(
pricing
,
tokens
,
1
)
require
.
NotNil
(
t
,
result
)
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
require
.
InDelta
(
t
,
0.2
,
*
result
,
1e-12
)
}
func
TestCalculateStatsCost_TokenBilling_WithCache
(
t
*
testing
.
T
)
{
pricing
:=
&
ChannelModelPricing
{
BillingMode
:
BillingModeToken
,
InputPrice
:
testPtrFloat64
(
0.001
),
OutputPrice
:
testPtrFloat64
(
0.002
),
CacheWritePrice
:
testPtrFloat64
(
0.003
),
CacheReadPrice
:
testPtrFloat64
(
0.0005
),
}
tokens
:=
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
,
CacheCreationTokens
:
200
,
CacheReadTokens
:
300
,
}
result
:=
calculateStatsCost
(
pricing
,
tokens
,
1
)
require
.
NotNil
(
t
,
result
)
// 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
// = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
require
.
InDelta
(
t
,
0.95
,
*
result
,
1e-12
)
}
func
TestCalculateStatsCost_TokenBilling_WithImageOutput
(
t
*
testing
.
T
)
{
pricing
:=
&
ChannelModelPricing
{
BillingMode
:
BillingModeToken
,
InputPrice
:
testPtrFloat64
(
0.001
),
OutputPrice
:
testPtrFloat64
(
0.002
),
ImageOutputPrice
:
testPtrFloat64
(
0.01
),
}
tokens
:=
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
,
ImageOutputTokens
:
10
,
}
result
:=
calculateStatsCost
(
pricing
,
tokens
,
1
)
require
.
NotNil
(
t
,
result
)
// 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
require
.
InDelta
(
t
,
0.3
,
*
result
,
1e-12
)
}
func
TestCalculateStatsCost_TokenBilling_PartialPricesNil
(
t
*
testing
.
T
)
{
pricing
:=
&
ChannelModelPricing
{
BillingMode
:
BillingModeToken
,
InputPrice
:
testPtrFloat64
(
0.001
),
// OutputPrice, CacheWritePrice, etc. are all nil → treated as 0
}
tokens
:=
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
,
CacheCreationTokens
:
200
,
}
result
:=
calculateStatsCost
(
pricing
,
tokens
,
1
)
require
.
NotNil
(
t
,
result
)
// Only input contributes: 100*0.001 = 0.1
require
.
InDelta
(
t
,
0.1
,
*
result
,
1e-12
)
}
func
TestCalculateStatsCost_TokenBilling_AllTokensZero
(
t
*
testing
.
T
)
{
pricing
:=
&
ChannelModelPricing
{
BillingMode
:
BillingModeToken
,
InputPrice
:
testPtrFloat64
(
0.001
),
OutputPrice
:
testPtrFloat64
(
0.002
),
}
tokens
:=
UsageTokens
{}
// all zeros
result
:=
calculateStatsCost
(
pricing
,
tokens
,
1
)
// totalCost == 0 → returns nil (does not override, falls back to default formula)
require
.
Nil
(
t
,
result
)
}
func
TestCalculateStatsCost_PerRequestBilling
(
t
*
testing
.
T
)
{
pricing
:=
&
ChannelModelPricing
{
BillingMode
:
BillingModePerRequest
,
PerRequestPrice
:
testPtrFloat64
(
0.05
),
}
tokens
:=
UsageTokens
{
InputTokens
:
999
,
OutputTokens
:
999
}
result
:=
calculateStatsCost
(
pricing
,
tokens
,
3
)
require
.
NotNil
(
t
,
result
)
// 0.05 * 3 = 0.15
require
.
InDelta
(
t
,
0.15
,
*
result
,
1e-12
)
}
func
TestCalculateStatsCost_PerRequestBilling_PriceNil
(
t
*
testing
.
T
)
{
pricing
:=
&
ChannelModelPricing
{
BillingMode
:
BillingModePerRequest
,
// PerRequestPrice is nil
}
result
:=
calculateStatsCost
(
pricing
,
UsageTokens
{},
1
)
require
.
Nil
(
t
,
result
)
}
func
TestCalculateStatsCost_PerRequestBilling_PriceZero
(
t
*
testing
.
T
)
{
pricing
:=
&
ChannelModelPricing
{
BillingMode
:
BillingModePerRequest
,
PerRequestPrice
:
testPtrFloat64
(
0
),
}
result
:=
calculateStatsCost
(
pricing
,
UsageTokens
{},
1
)
// price == 0 → condition *pricing.PerRequestPrice > 0 is false → returns nil
require
.
Nil
(
t
,
result
)
}
func
TestCalculateStatsCost_ImageBilling
(
t
*
testing
.
T
)
{
pricing
:=
&
ChannelModelPricing
{
BillingMode
:
BillingModeImage
,
PerRequestPrice
:
testPtrFloat64
(
0.10
),
}
result
:=
calculateStatsCost
(
pricing
,
UsageTokens
{},
2
)
require
.
NotNil
(
t
,
result
)
// 0.10 * 2 = 0.20
require
.
InDelta
(
t
,
0.20
,
*
result
,
1e-12
)
}
func
TestCalculateStatsCost_ImageBilling_PriceNil
(
t
*
testing
.
T
)
{
pricing
:=
&
ChannelModelPricing
{
BillingMode
:
BillingModeImage
,
// PerRequestPrice is nil
}
result
:=
calculateStatsCost
(
pricing
,
UsageTokens
{},
1
)
require
.
Nil
(
t
,
result
)
}
func
TestCalculateStatsCost_DefaultBillingMode_FallsToToken
(
t
*
testing
.
T
)
{
// BillingMode is empty string (default) → falls into token billing
pricing
:=
&
ChannelModelPricing
{
InputPrice
:
testPtrFloat64
(
0.001
),
OutputPrice
:
testPtrFloat64
(
0.002
),
}
tokens
:=
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
,
}
result
:=
calculateStatsCost
(
pricing
,
tokens
,
1
)
require
.
NotNil
(
t
,
result
)
require
.
InDelta
(
t
,
0.2
,
*
result
,
1e-12
)
}
// ---------------------------------------------------------------------------
// tryCustomRules — 多规则顺序测试
// ---------------------------------------------------------------------------
func
TestTryCustomRules_FirstMatchWins
(
t
*
testing
.
T
)
{
channel
:=
&
Channel
{
AccountStatsPricingRules
:
[]
AccountStatsPricingRule
{
{
GroupIDs
:
[]
int64
{
1
},
Pricing
:
[]
ChannelModelPricing
{
{
ID
:
100
,
Models
:
[]
string
{
"claude-opus-4"
},
InputPrice
:
testPtrFloat64
(
0.01
),
OutputPrice
:
testPtrFloat64
(
0.02
)},
},
},
{
GroupIDs
:
[]
int64
{
1
},
Pricing
:
[]
ChannelModelPricing
{
{
ID
:
200
,
Models
:
[]
string
{
"claude-opus-4"
},
InputPrice
:
testPtrFloat64
(
0.99
),
OutputPrice
:
testPtrFloat64
(
0.99
)},
},
},
},
}
tokens
:=
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
}
result
:=
tryCustomRules
(
channel
,
999
,
1
,
""
,
"claude-opus-4"
,
tokens
,
1
)
require
.
NotNil
(
t
,
result
)
// 应使用第一条规则的价格:100*0.01 + 50*0.02 = 2.0
require
.
InDelta
(
t
,
2.0
,
*
result
,
1e-12
)
}
func
TestTryCustomRules_SkipsNonMatchingRules
(
t
*
testing
.
T
)
{
channel
:=
&
Channel
{
AccountStatsPricingRules
:
[]
AccountStatsPricingRule
{
{
AccountIDs
:
[]
int64
{
888
},
// 不匹配
Pricing
:
[]
ChannelModelPricing
{
{
ID
:
100
,
Models
:
[]
string
{
"claude-opus-4"
},
InputPrice
:
testPtrFloat64
(
0.99
)},
},
},
{
GroupIDs
:
[]
int64
{
1
},
// 匹配
Pricing
:
[]
ChannelModelPricing
{
{
ID
:
200
,
Models
:
[]
string
{
"claude-opus-4"
},
InputPrice
:
testPtrFloat64
(
0.05
)},
},
},
},
}
tokens
:=
UsageTokens
{
InputTokens
:
100
}
result
:=
tryCustomRules
(
channel
,
999
,
1
,
""
,
"claude-opus-4"
,
tokens
,
1
)
require
.
NotNil
(
t
,
result
)
// 跳过规则1(账号不匹配),使用规则2:100*0.05 = 5.0
require
.
InDelta
(
t
,
5.0
,
*
result
,
1e-12
)
}
func
TestTryCustomRules_NoMatch_ReturnsNil
(
t
*
testing
.
T
)
{
channel
:=
&
Channel
{
AccountStatsPricingRules
:
[]
AccountStatsPricingRule
{
{
AccountIDs
:
[]
int64
{
888
},
Pricing
:
[]
ChannelModelPricing
{
{
ID
:
100
,
Models
:
[]
string
{
"claude-opus-4"
},
InputPrice
:
testPtrFloat64
(
0.01
)},
},
},
},
}
tokens
:=
UsageTokens
{
InputTokens
:
100
}
result
:=
tryCustomRules
(
channel
,
999
,
2
,
""
,
"claude-opus-4"
,
tokens
,
1
)
require
.
Nil
(
t
,
result
)
// 账号和分组都不匹配
}
func
TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext
(
t
*
testing
.
T
)
{
channel
:=
&
Channel
{
AccountStatsPricingRules
:
[]
AccountStatsPricingRule
{
{
GroupIDs
:
[]
int64
{
1
},
Pricing
:
[]
ChannelModelPricing
{
{
ID
:
100
,
Models
:
[]
string
{
"gpt-4o"
},
InputPrice
:
testPtrFloat64
(
0.01
)},
// 模型不匹配
},
},
{
GroupIDs
:
[]
int64
{
1
},
Pricing
:
[]
ChannelModelPricing
{
{
ID
:
200
,
Models
:
[]
string
{
"claude-opus-4"
},
InputPrice
:
testPtrFloat64
(
0.05
)},
// 模型匹配
},
},
},
}
tokens
:=
UsageTokens
{
InputTokens
:
100
}
result
:=
tryCustomRules
(
channel
,
999
,
1
,
""
,
"claude-opus-4"
,
tokens
,
1
)
require
.
NotNil
(
t
,
result
)
require
.
InDelta
(
t
,
5.0
,
*
result
,
1e-12
)
// 使用规则2
}
// ---------------------------------------------------------------------------
// tryModelFilePricing
// ---------------------------------------------------------------------------
// newTestBillingServiceWithPrices creates a BillingService with pre-populated
// fallback prices for testing. No config or pricing service is needed.
// The key must match what getFallbackPricing resolves to for a given model name.
// E.g., model "claude-sonnet-4" resolves to key "claude-sonnet-4".
func
newTestBillingServiceWithPrices
(
prices
map
[
string
]
*
ModelPricing
)
*
BillingService
{
return
&
BillingService
{
fallbackPrices
:
prices
,
}
}
func
TestTryModelFilePricing_Success
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingServiceWithPrices
(
map
[
string
]
*
ModelPricing
{
"claude-sonnet-4"
:
{
InputPricePerToken
:
0.001
,
OutputPricePerToken
:
0.002
,
},
})
tokens
:=
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
}
result
:=
tryModelFilePricing
(
bs
,
"claude-sonnet-4"
,
tokens
)
require
.
NotNil
(
t
,
result
)
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
require
.
InDelta
(
t
,
0.2
,
*
result
,
1e-12
)
}
func
TestTryModelFilePricing_PricingNotFound
(
t
*
testing
.
T
)
{
// "nonexistent-model" does not match any fallback pattern
bs
:=
newTestBillingServiceWithPrices
(
map
[
string
]
*
ModelPricing
{})
tokens
:=
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
}
result
:=
tryModelFilePricing
(
bs
,
"nonexistent-model"
,
tokens
)
require
.
Nil
(
t
,
result
)
}
func
TestTryModelFilePricing_NilFallback
(
t
*
testing
.
T
)
{
// getFallbackPricing returns nil when key maps to nil
bs
:=
newTestBillingServiceWithPrices
(
map
[
string
]
*
ModelPricing
{
"claude-sonnet-4"
:
nil
,
})
tokens
:=
UsageTokens
{
InputTokens
:
100
}
result
:=
tryModelFilePricing
(
bs
,
"claude-sonnet-4"
,
tokens
)
require
.
Nil
(
t
,
result
)
}
func
TestTryModelFilePricing_ZeroCost
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingServiceWithPrices
(
map
[
string
]
*
ModelPricing
{
"claude-sonnet-4"
:
{
InputPricePerToken
:
0.001
,
OutputPricePerToken
:
0.002
,
},
})
tokens
:=
UsageTokens
{}
// all zero tokens → cost = 0 → nil
result
:=
tryModelFilePricing
(
bs
,
"claude-sonnet-4"
,
tokens
)
require
.
Nil
(
t
,
result
)
}
func
TestTryModelFilePricing_WithImageOutput
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingServiceWithPrices
(
map
[
string
]
*
ModelPricing
{
"claude-sonnet-4"
:
{
InputPricePerToken
:
0.001
,
OutputPricePerToken
:
0.002
,
ImageOutputPricePerToken
:
0.01
,
},
})
tokens
:=
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
,
ImageOutputTokens
:
10
,
}
result
:=
tryModelFilePricing
(
bs
,
"claude-sonnet-4"
,
tokens
)
require
.
NotNil
(
t
,
result
)
// 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
require
.
InDelta
(
t
,
0.3
,
*
result
,
1e-12
)
}
func
TestTryModelFilePricing_WithCacheTokens
(
t
*
testing
.
T
)
{
bs
:=
newTestBillingServiceWithPrices
(
map
[
string
]
*
ModelPricing
{
"claude-sonnet-4"
:
{
InputPricePerToken
:
0.001
,
OutputPricePerToken
:
0.002
,
CacheCreationPricePerToken
:
0.003
,
CacheReadPricePerToken
:
0.0005
,
},
})
tokens
:=
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
,
CacheCreationTokens
:
200
,
CacheReadTokens
:
300
,
}
result
:=
tryModelFilePricing
(
bs
,
"claude-sonnet-4"
,
tokens
)
require
.
NotNil
(
t
,
result
)
// 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
// = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
require
.
InDelta
(
t
,
0.95
,
*
result
,
1e-12
)
}
// ---------------------------------------------------------------------------
// resolveAccountStatsCost — integration tests covering the 4-level priority chain
// ---------------------------------------------------------------------------
func
TestResolveAccountStatsCost_NilChannelService
(
t
*
testing
.
T
)
{
result
:=
resolveAccountStatsCost
(
context
.
Background
(),
nil
,
// channelService is nil
newTestBillingServiceWithPrices
(
map
[
string
]
*
ModelPricing
{}),
1
,
1
,
"claude-sonnet-4"
,
UsageTokens
{
InputTokens
:
100
},
1
,
0.5
,
)
require
.
Nil
(
t
,
result
)
}
func
TestResolveAccountStatsCost_EmptyUpstreamModel
(
t
*
testing
.
T
)
{
cs
:=
newTestChannelServiceForStats
(
t
,
&
Channel
{
ID
:
1
,
Status
:
StatusActive
,
},
1
,
""
)
result
:=
resolveAccountStatsCost
(
context
.
Background
(),
cs
,
newTestBillingServiceWithPrices
(
map
[
string
]
*
ModelPricing
{}),
1
,
1
,
""
,
// empty upstream model
UsageTokens
{
InputTokens
:
100
},
1
,
0.5
,
)
require
.
Nil
(
t
,
result
)
}
func
TestResolveAccountStatsCost_GetChannelForGroupReturnsNil
(
t
*
testing
.
T
)
{
// Group 99 is NOT in the cache, so GetChannelForGroup returns nil
cs
:=
newTestChannelServiceForStats
(
t
,
&
Channel
{
ID
:
1
,
Status
:
StatusActive
,
},
1
,
""
)
result
:=
resolveAccountStatsCost
(
context
.
Background
(),
cs
,
newTestBillingServiceWithPrices
(
map
[
string
]
*
ModelPricing
{}),
1
,
99
,
"claude-sonnet-4"
,
// groupID 99 has no channel
UsageTokens
{
InputTokens
:
100
},
1
,
0.5
,
)
require
.
Nil
(
t
,
result
)
}
func
TestResolveAccountStatsCost_HitsCustomRule
(
t
*
testing
.
T
)
{
channel
:=
&
Channel
{
ID
:
1
,
Status
:
StatusActive
,
AccountStatsPricingRules
:
[]
AccountStatsPricingRule
{
{
GroupIDs
:
[]
int64
{
10
},
Pricing
:
[]
ChannelModelPricing
{
{
ID
:
100
,
Models
:
[]
string
{
"claude-sonnet-4"
},
InputPrice
:
testPtrFloat64
(
0.01
),
OutputPrice
:
testPtrFloat64
(
0.02
),
},
},
},
},
}
cs
:=
newTestChannelServiceForStats
(
t
,
channel
,
10
,
"anthropic"
)
tokens
:=
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
}
result
:=
resolveAccountStatsCost
(
context
.
Background
(),
cs
,
nil
,
// billingService not needed when custom rule hits
1
,
10
,
"claude-sonnet-4"
,
tokens
,
1
,
999.0
,
// totalCost ignored because custom rule hits
)
require
.
NotNil
(
t
,
result
)
// 100*0.01 + 50*0.02 = 1.0 + 1.0 = 2.0
require
.
InDelta
(
t
,
2.0
,
*
result
,
1e-12
)
}
func
TestResolveAccountStatsCost_ApplyPricingToAccountStats_UsesTotalCost
(
t
*
testing
.
T
)
{
channel
:=
&
Channel
{
ID
:
1
,
Status
:
StatusActive
,
ApplyPricingToAccountStats
:
true
,
// No custom rules
}
cs
:=
newTestChannelServiceForStats
(
t
,
channel
,
10
,
"anthropic"
)
tokens
:=
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
}
result
:=
resolveAccountStatsCost
(
context
.
Background
(),
cs
,
nil
,
1
,
10
,
"claude-sonnet-4"
,
tokens
,
1
,
0.75
,
// totalCost = 0.75
)
require
.
NotNil
(
t
,
result
)
require
.
InDelta
(
t
,
0.75
,
*
result
,
1e-12
)
}
func
TestResolveAccountStatsCost_ApplyPricingToAccountStats_ZeroTotalCost_ReturnsNil
(
t
*
testing
.
T
)
{
channel
:=
&
Channel
{
ID
:
1
,
Status
:
StatusActive
,
ApplyPricingToAccountStats
:
true
,
}
cs
:=
newTestChannelServiceForStats
(
t
,
channel
,
10
,
"anthropic"
)
result
:=
resolveAccountStatsCost
(
context
.
Background
(),
cs
,
nil
,
1
,
10
,
"claude-sonnet-4"
,
UsageTokens
{},
1
,
0.0
,
// totalCost = 0
)
require
.
Nil
(
t
,
result
)
}
func
TestResolveAccountStatsCost_FallsBackToLiteLLM
(
t
*
testing
.
T
)
{
channel
:=
&
Channel
{
ID
:
1
,
Status
:
StatusActive
,
ApplyPricingToAccountStats
:
false
,
// not enabled
// No custom rules
}
cs
:=
newTestChannelServiceForStats
(
t
,
channel
,
10
,
"anthropic"
)
bs
:=
newTestBillingServiceWithPrices
(
map
[
string
]
*
ModelPricing
{
"claude-sonnet-4"
:
{
InputPricePerToken
:
0.001
,
OutputPricePerToken
:
0.002
,
},
})
tokens
:=
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
}
result
:=
resolveAccountStatsCost
(
context
.
Background
(),
cs
,
bs
,
1
,
10
,
"claude-sonnet-4"
,
tokens
,
1
,
999.0
,
// totalCost ignored
)
require
.
NotNil
(
t
,
result
)
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
require
.
InDelta
(
t
,
0.2
,
*
result
,
1e-12
)
}
func
TestResolveAccountStatsCost_AllMiss_ReturnsNil
(
t
*
testing
.
T
)
{
channel
:=
&
Channel
{
ID
:
1
,
Status
:
StatusActive
,
ApplyPricingToAccountStats
:
false
,
// No custom rules
}
cs
:=
newTestChannelServiceForStats
(
t
,
channel
,
10
,
"anthropic"
)
// BillingService with no pricing for the model
bs
:=
newTestBillingServiceWithPrices
(
map
[
string
]
*
ModelPricing
{})
tokens
:=
UsageTokens
{
InputTokens
:
100
,
OutputTokens
:
50
}
result
:=
resolveAccountStatsCost
(
context
.
Background
(),
cs
,
bs
,
1
,
10
,
"totally-unknown-model"
,
tokens
,
1
,
0.0
,
)
require
.
Nil
(
t
,
result
)
}
func
TestResolveAccountStatsCost_NilBillingService_SkipsLiteLLM
(
t
*
testing
.
T
)
{
channel
:=
&
Channel
{
ID
:
1
,
Status
:
StatusActive
,
ApplyPricingToAccountStats
:
false
,
}
cs
:=
newTestChannelServiceForStats
(
t
,
channel
,
10
,
"anthropic"
)
result
:=
resolveAccountStatsCost
(
context
.
Background
(),
cs
,
nil
,
// billingService is nil
1
,
10
,
"claude-sonnet-4"
,
UsageTokens
{
InputTokens
:
100
},
1
,
0.0
,
)
require
.
Nil
(
t
,
result
)
}
func
TestResolveAccountStatsCost_CustomRulePriorityOverApplyPricing
(
t
*
testing
.
T
)
{
// Both custom rule and ApplyPricingToAccountStats are configured;
// custom rule should take precedence.
channel
:=
&
Channel
{
ID
:
1
,
Status
:
StatusActive
,
ApplyPricingToAccountStats
:
true
,
AccountStatsPricingRules
:
[]
AccountStatsPricingRule
{
{
GroupIDs
:
[]
int64
{
10
},
Pricing
:
[]
ChannelModelPricing
{
{
ID
:
100
,
Models
:
[]
string
{
"claude-sonnet-4"
},
InputPrice
:
testPtrFloat64
(
0.05
),
},
},
},
},
}
cs
:=
newTestChannelServiceForStats
(
t
,
channel
,
10
,
"anthropic"
)
tokens
:=
UsageTokens
{
InputTokens
:
100
}
result
:=
resolveAccountStatsCost
(
context
.
Background
(),
cs
,
nil
,
1
,
10
,
"claude-sonnet-4"
,
tokens
,
1
,
99.0
,
// totalCost = 99.0 (would be used if ApplyPricing wins)
)
require
.
NotNil
(
t
,
result
)
// Custom rule: 100*0.05 = 5.0 (NOT 99.0 from totalCost)
require
.
InDelta
(
t
,
5.0
,
*
result
,
1e-12
)
}
// ---------------------------------------------------------------------------
// helpers for resolveAccountStatsCost tests
// ---------------------------------------------------------------------------
// newTestChannelServiceForStats creates a ChannelService with a single channel
// mapped to the given groupID, suitable for resolveAccountStatsCost tests.
func
newTestChannelServiceForStats
(
t
*
testing
.
T
,
channel
*
Channel
,
groupID
int64
,
platform
string
)
*
ChannelService
{
t
.
Helper
()
cache
:=
newEmptyChannelCache
()
cache
.
channelByGroupID
[
groupID
]
=
channel
cache
.
groupPlatform
[
groupID
]
=
platform
cs
:=
&
ChannelService
{}
cache
.
loadedAt
=
time
.
Now
()
cs
.
cache
.
Store
(
cache
)
return
cs
}
backend/internal/service/account_test_service.go
View file @
0b746501
...
...
@@ -515,22 +515,10 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
_
=
s
.
accountRepo
.
UpdateExtra
(
ctx
,
account
.
ID
,
updates
)
mergeAccountExtra
(
account
,
updates
)
}
if
snapshot
:=
ParseCodexRateLimitHeaders
(
resp
.
Header
);
snapshot
!=
nil
{
if
resetAt
:=
codexRateLimitResetAtFromSnapshot
(
snapshot
,
time
.
Now
());
resetAt
!=
nil
{
_
=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
*
resetAt
)
account
.
RateLimitResetAt
=
resetAt
}
}
}
if
resp
.
StatusCode
!=
http
.
StatusOK
{
body
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
if
isOAuth
&&
s
.
accountRepo
!=
nil
{
if
resetAt
:=
(
&
RateLimitService
{})
.
calculateOpenAI429ResetTime
(
resp
.
Header
);
resetAt
!=
nil
{
_
=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
*
resetAt
)
account
.
RateLimitResetAt
=
resetAt
}
}
// 401 Unauthorized: 标记账号为永久错误
if
resp
.
StatusCode
==
http
.
StatusUnauthorized
&&
s
.
accountRepo
!=
nil
{
errMsg
:=
fmt
.
Sprintf
(
"Authentication failed (401): %s"
,
string
(
body
))
...
...
backend/internal/service/account_test_service_openai_test.go
View file @
0b746501
...
...
@@ -111,7 +111,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
require
.
Contains
(
t
,
recorder
.
Body
.
String
(),
"test_complete"
)
}
func
TestAccountTestService_OpenAI429PersistsSnapshot
And
RateLimit
(
t
*
testing
.
T
)
{
func
TestAccountTestService_OpenAI429PersistsSnapshot
Without
RateLimit
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
ctx
,
_
:=
newTestContext
()
...
...
@@ -138,10 +138,7 @@ func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T)
require
.
Error
(
t
,
err
)
require
.
NotEmpty
(
t
,
repo
.
updatedExtra
)
require
.
Equal
(
t
,
100.0
,
repo
.
updatedExtra
[
"codex_5h_used_percent"
])
require
.
Equal
(
t
,
int64
(
88
),
repo
.
rateLimitedID
)
require
.
NotNil
(
t
,
repo
.
rateLimitedAt
)
require
.
NotNil
(
t
,
account
.
RateLimitResetAt
)
if
account
.
RateLimitResetAt
!=
nil
&&
repo
.
rateLimitedAt
!=
nil
{
require
.
WithinDuration
(
t
,
*
repo
.
rateLimitedAt
,
*
account
.
RateLimitResetAt
,
time
.
Second
)
}
require
.
Zero
(
t
,
repo
.
rateLimitedID
)
require
.
Nil
(
t
,
repo
.
rateLimitedAt
)
require
.
Nil
(
t
,
account
.
RateLimitResetAt
)
}
backend/internal/service/account_usage_service.go
View file @
0b746501
...
...
@@ -499,7 +499,6 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
if
account
==
nil
{
return
usage
,
nil
}
syncOpenAICodexRateLimitFromExtra
(
ctx
,
s
.
accountRepo
,
account
,
now
)
if
progress
:=
buildCodexUsageProgressFromExtra
(
account
.
Extra
,
"5h"
,
now
);
progress
!=
nil
{
usage
.
FiveHour
=
progress
...
...
@@ -509,11 +508,8 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
}
if
shouldRefreshOpenAICodexSnapshot
(
account
,
usage
,
now
)
&&
s
.
shouldProbeOpenAICodexSnapshot
(
account
.
ID
,
now
)
{
if
updates
,
resetAt
,
err
:=
s
.
probeOpenAICodexSnapshot
(
ctx
,
account
);
err
==
nil
&&
(
len
(
updates
)
>
0
||
resetAt
!=
nil
)
{
if
updates
,
err
:=
s
.
probeOpenAICodexSnapshot
(
ctx
,
account
);
err
==
nil
&&
len
(
updates
)
>
0
{
mergeAccountExtra
(
account
,
updates
)
if
resetAt
!=
nil
{
account
.
RateLimitResetAt
=
resetAt
}
if
usage
.
UpdatedAt
==
nil
{
usage
.
UpdatedAt
=
&
now
}
...
...
@@ -594,26 +590,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no
return
true
}
func
(
s
*
AccountUsageService
)
probeOpenAICodexSnapshot
(
ctx
context
.
Context
,
account
*
Account
)
(
map
[
string
]
any
,
*
time
.
Time
,
error
)
{
func
(
s
*
AccountUsageService
)
probeOpenAICodexSnapshot
(
ctx
context
.
Context
,
account
*
Account
)
(
map
[
string
]
any
,
error
)
{
if
account
==
nil
||
!
account
.
IsOAuth
()
{
return
nil
,
nil
,
nil
return
nil
,
nil
}
accessToken
:=
account
.
GetOpenAIAccessToken
()
if
accessToken
==
""
{
return
nil
,
nil
,
fmt
.
Errorf
(
"no access token available"
)
return
nil
,
fmt
.
Errorf
(
"no access token available"
)
}
modelID
:=
openaipkg
.
DefaultTestModel
payload
:=
createOpenAITestPayload
(
modelID
,
true
)
payloadBytes
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"marshal openai probe payload: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"marshal openai probe payload: %w"
,
err
)
}
reqCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
15
*
time
.
Second
)
defer
cancel
()
req
,
err
:=
http
.
NewRequestWithContext
(
reqCtx
,
http
.
MethodPost
,
chatgptCodexURL
,
bytes
.
NewReader
(
payloadBytes
))
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"create openai probe request: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"create openai probe request: %w"
,
err
)
}
req
.
Host
=
"chatgpt.com"
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
...
...
@@ -642,67 +638,51 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
ResponseHeaderTimeout
:
10
*
time
.
Second
,
})
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"build openai probe client: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"build openai probe client: %w"
,
err
)
}
resp
,
err
:=
client
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"openai codex probe request failed: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"openai codex probe request failed: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
updates
,
resetAt
,
err
:=
extractOpenAICodexProbe
Snapshot
(
resp
)
updates
,
err
:=
extractOpenAICodexProbe
Updates
(
resp
)
if
err
!=
nil
{
return
nil
,
nil
,
err
return
nil
,
err
}
if
len
(
updates
)
>
0
||
resetAt
!=
nil
{
s
.
persistOpenAICodexProbeSnapshot
(
account
.
ID
,
updates
,
resetAt
)
return
updates
,
resetAt
,
nil
if
len
(
updates
)
>
0
{
s
.
persistOpenAICodexProbeSnapshot
(
account
.
ID
,
updates
)
return
updates
,
nil
}
return
nil
,
nil
,
nil
return
nil
,
nil
}
func
(
s
*
AccountUsageService
)
persistOpenAICodexProbeSnapshot
(
accountID
int64
,
updates
map
[
string
]
any
,
resetAt
*
time
.
Time
)
{
func
(
s
*
AccountUsageService
)
persistOpenAICodexProbeSnapshot
(
accountID
int64
,
updates
map
[
string
]
any
)
{
if
s
==
nil
||
s
.
accountRepo
==
nil
||
accountID
<=
0
{
return
}
if
len
(
updates
)
==
0
&&
resetAt
==
nil
{
if
len
(
updates
)
==
0
{
return
}
go
func
()
{
updateCtx
,
updateCancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
updateCancel
()
if
len
(
updates
)
>
0
{
_
=
s
.
accountRepo
.
UpdateExtra
(
updateCtx
,
accountID
,
updates
)
}
if
resetAt
!=
nil
{
_
=
s
.
accountRepo
.
SetRateLimited
(
updateCtx
,
accountID
,
*
resetAt
)
}
_
=
s
.
accountRepo
.
UpdateExtra
(
updateCtx
,
accountID
,
updates
)
}()
}
func
extractOpenAICodexProbe
Snapshot
(
resp
*
http
.
Response
)
(
map
[
string
]
any
,
*
time
.
Time
,
error
)
{
func
extractOpenAICodexProbe
Updates
(
resp
*
http
.
Response
)
(
map
[
string
]
any
,
error
)
{
if
resp
==
nil
{
return
nil
,
nil
,
nil
return
nil
,
nil
}
if
snapshot
:=
ParseCodexRateLimitHeaders
(
resp
.
Header
);
snapshot
!=
nil
{
baseTime
:=
time
.
Now
()
updates
:=
buildCodexUsageExtraUpdates
(
snapshot
,
baseTime
)
resetAt
:=
codexRateLimitResetAtFromSnapshot
(
snapshot
,
baseTime
)
if
len
(
updates
)
>
0
{
return
updates
,
resetAt
,
nil
}
return
nil
,
resetAt
,
nil
return
buildCodexUsageExtraUpdates
(
snapshot
,
time
.
Now
()),
nil
}
if
resp
.
StatusCode
<
200
||
resp
.
StatusCode
>=
300
{
return
nil
,
nil
,
fmt
.
Errorf
(
"openai codex probe returned status %d"
,
resp
.
StatusCode
)
return
nil
,
fmt
.
Errorf
(
"openai codex probe returned status %d"
,
resp
.
StatusCode
)
}
return
nil
,
nil
,
nil
}
func
extractOpenAICodexProbeUpdates
(
resp
*
http
.
Response
)
(
map
[
string
]
any
,
error
)
{
updates
,
_
,
err
:=
extractOpenAICodexProbeSnapshot
(
resp
)
return
updates
,
err
return
nil
,
nil
}
func
mergeAccountExtra
(
account
*
Account
,
updates
map
[
string
]
any
)
{
...
...
Prev
1
2
3
4
5
6
7
8
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment