Commit ee4bfcbb authored by Elysia's avatar Elysia
Browse files

Merge remote-tracking branch 'origin/main'

parents 32d619a5 cac23020
......@@ -8,6 +8,7 @@ import (
"net/http"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
......@@ -95,7 +96,8 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
msg := fmt.Sprintf("API returned status %d: %s", resp.StatusCode, string(body))
return nil, infraerrors.New(http.StatusInternalServerError, "UPSTREAM_ERROR", msg)
}
var usageResp service.ClaudeUsageResponse
......
......@@ -1363,7 +1363,8 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
......@@ -1401,6 +1402,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
......@@ -1664,7 +1667,8 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
......@@ -1747,7 +1751,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
total_requests as requests,
input_tokens,
output_tokens,
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
cache_creation_tokens,
cache_read_tokens,
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
total_cost as cost,
actual_cost
......@@ -1762,7 +1767,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
total_requests as requests,
input_tokens,
output_tokens,
(cache_creation_tokens + cache_read_tokens) as cache_tokens,
cache_creation_tokens,
cache_read_tokens,
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
total_cost as cost,
actual_cost
......@@ -1806,6 +1812,8 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
%s
......@@ -2622,7 +2630,8 @@ func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) {
&row.Requests,
&row.InputTokens,
&row.OutputTokens,
&row.CacheTokens,
&row.CacheCreationTokens,
&row.CacheReadTokens,
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
......@@ -2646,6 +2655,8 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
&row.Requests,
&row.InputTokens,
&row.OutputTokens,
&row.CacheCreationTokens,
&row.CacheReadTokens,
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
......
......@@ -125,7 +125,7 @@ func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testin
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
WithArgs(start, end, requestType).
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"}))
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
require.NoError(t, err)
......@@ -144,7 +144,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
WithArgs(start, end, requestType).
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"}))
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
require.NoError(t, err)
......
......@@ -1096,6 +1096,14 @@ func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map
return errors.New("not implemented")
}
func (s *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
s.bulkUpdateIDs = append([]int64{}, ids...)
return int64(len(ids)), nil
......
......@@ -252,6 +252,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/reset-quota", h.Admin.Account.ResetQuota)
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
......
......@@ -28,6 +28,7 @@ type Account struct {
// RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
// 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
RateMultiplier *float64
LoadFactor *int // 调度负载因子;nil 表示使用 Concurrency
Status string
ErrorMessage string
LastUsedAt *time.Time
......@@ -88,6 +89,19 @@ func (a *Account) BillingRateMultiplier() float64 {
return *a.RateMultiplier
}
func (a *Account) EffectiveLoadFactor() int {
if a == nil {
return 1
}
if a.LoadFactor != nil && *a.LoadFactor > 0 {
return *a.LoadFactor
}
if a.Concurrency > 0 {
return a.Concurrency
}
return 1
}
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
......@@ -1117,6 +1131,38 @@ func (a *Account) GetCacheTTLOverrideTarget() string {
return "5m"
}
// GetQuotaLimit 获取 API Key 账号的配额限制(美元)
// 返回 0 表示未启用
func (a *Account) GetQuotaLimit() float64 {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra["quota_limit"]; ok {
return parseExtraFloat64(v)
}
return 0
}
// GetQuotaUsed 获取 API Key 账号的已用配额(美元)
func (a *Account) GetQuotaUsed() float64 {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra["quota_used"]; ok {
return parseExtraFloat64(v)
}
return 0
}
// IsQuotaExceeded 检查 API Key 账号配额是否已超限
func (a *Account) IsQuotaExceeded() bool {
limit := a.GetQuotaLimit()
if limit <= 0 {
return false
}
return a.GetQuotaUsed() >= limit
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func (a *Account) GetWindowCostLimit() float64 {
......
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func intPtrHelper(v int) *int { return &v }
func TestEffectiveLoadFactor_NilAccount(t *testing.T) {
var a *Account
require.Equal(t, 1, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_NilLoadFactor_PositiveConcurrency(t *testing.T) {
a := &Account{Concurrency: 5}
require.Equal(t, 5, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_NilLoadFactor_ZeroConcurrency(t *testing.T) {
a := &Account{Concurrency: 0}
require.Equal(t, 1, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_PositiveLoadFactor(t *testing.T) {
a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(20)}
require.Equal(t, 20, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_ZeroLoadFactor_FallbackToConcurrency(t *testing.T) {
a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(0)}
require.Equal(t, 5, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_NegativeLoadFactor_FallbackToConcurrency(t *testing.T) {
a := &Account{Concurrency: 3, LoadFactor: intPtrHelper(-1)}
require.Equal(t, 3, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_ZeroLoadFactor_ZeroConcurrency(t *testing.T) {
a := &Account{Concurrency: 0, LoadFactor: intPtrHelper(0)}
require.Equal(t, 1, a.EffectiveLoadFactor())
}
......@@ -68,6 +68,10 @@ type AccountRepository interface {
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error
// ResetQuotaUsed 重置 API Key 账号的配额用量为 0
ResetQuotaUsed(ctx context.Context, id int64) error
}
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
......@@ -78,6 +82,7 @@ type AccountBulkUpdate struct {
Concurrency *int
Priority *int
RateMultiplier *float64
LoadFactor *int
Status *string
Schedulable *bool
Credentials map[string]any
......
......@@ -199,6 +199,14 @@ func (s *accountRepoStub) BulkUpdate(ctx context.Context, ids []int64, updates A
panic("unexpected BulkUpdate call")
}
func (s *accountRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return nil
}
func (s *accountRepoStub) ResetQuotaUsed(ctx context.Context, id int64) error {
return nil
}
// TestAccountService_Delete_NotFound 测试删除不存在的账号时返回正确的错误。
// 预期行为:
// - ExistsByID 返回 false(账号不存在)
......
......@@ -180,7 +180,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
}
if account.Platform == PlatformAntigravity {
return s.testAntigravityAccountConnection(c, account, modelID)
return s.routeAntigravityTest(c, account, modelID)
}
if account.Platform == PlatformSora {
......@@ -1177,6 +1177,18 @@ func truncateSoraErrorBody(body []byte, max int) string {
return soraerror.TruncateBody(body, max)
}
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error {
if account.Type == AccountTypeAPIKey {
if strings.HasPrefix(modelID, "gemini-") {
return s.testGeminiAccountConnection(c, account, modelID)
}
return s.testClaudeAccountConnection(c, account, modelID)
}
return s.testAntigravityAccountConnection(c, account, modelID)
}
// testAntigravityAccountConnection tests an Antigravity account's connection
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
......
......@@ -84,6 +84,7 @@ type AdminService interface {
DeleteRedeemCode(ctx context.Context, id int64) error
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
ResetAccountQuota(ctx context.Context, id int64) error
}
// CreateUserInput represents input for creating a new user via admin operations.
......@@ -195,6 +196,7 @@ type CreateAccountInput struct {
Concurrency int
Priority int
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
LoadFactor *int
GroupIDs []int64
ExpiresAt *int64
AutoPauseOnExpired *bool
......@@ -215,6 +217,7 @@ type UpdateAccountInput struct {
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
LoadFactor *int
Status string
GroupIDs *[]int64
ExpiresAt *int64
......@@ -230,6 +233,7 @@ type BulkUpdateAccountsInput struct {
Concurrency *int
Priority *int
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
LoadFactor *int
Status string
Schedulable *bool
GroupIDs *[]int64
......@@ -1413,6 +1417,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
account.RateMultiplier = input.RateMultiplier
}
if input.LoadFactor != nil && *input.LoadFactor > 0 {
if *input.LoadFactor > 10000 {
return nil, errors.New("load_factor must be <= 10000")
}
account.LoadFactor = input.LoadFactor
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
}
......@@ -1458,6 +1468,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Credentials = input.Credentials
}
if len(input.Extra) > 0 {
// 保留 quota_used,防止编辑账号时意外重置配额用量
if oldQuotaUsed, ok := account.Extra["quota_used"]; ok {
input.Extra["quota_used"] = oldQuotaUsed
}
account.Extra = input.Extra
}
if input.ProxyID != nil {
......@@ -1483,6 +1497,15 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
}
account.RateMultiplier = input.RateMultiplier
}
if input.LoadFactor != nil {
if *input.LoadFactor <= 0 {
account.LoadFactor = nil // 0 或负数表示清除
} else if *input.LoadFactor > 10000 {
return nil, errors.New("load_factor must be <= 10000")
} else {
account.LoadFactor = input.LoadFactor
}
}
if input.Status != "" {
account.Status = input.Status
}
......@@ -1616,6 +1639,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.RateMultiplier != nil {
repoUpdates.RateMultiplier = input.RateMultiplier
}
if input.LoadFactor != nil {
if *input.LoadFactor <= 0 {
repoUpdates.LoadFactor = nil // 0 或负数表示清除
} else if *input.LoadFactor > 10000 {
return nil, errors.New("load_factor must be <= 10000")
} else {
repoUpdates.LoadFactor = input.LoadFactor
}
}
if input.Status != "" {
repoUpdates.Status = &input.Status
}
......@@ -2439,3 +2471,7 @@ func (e *MixedChannelError) Error() string {
return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.",
e.GroupName, e.CurrentPlatform, e.OtherPlatform)
}
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
return s.accountRepo.ResetQuotaUsed(ctx, id)
}
......@@ -43,15 +43,24 @@ type BillingCache interface {
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
type ModelPricing struct {
InputPricePerToken float64 // 每token输入价格 (USD)
OutputPricePerToken float64 // 每token输出价格 (USD)
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
InputPricePerToken float64 // 每token输入价格 (USD)
OutputPricePerToken float64 // 每token输出价格 (USD)
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
}
const (
openAIGPT54LongContextInputThreshold = 272000
openAIGPT54LongContextInputMultiplier = 2.0
openAIGPT54LongContextOutputMultiplier = 1.5
)
// UsageTokens 使用的token数量
type UsageTokens struct {
InputTokens int
......@@ -161,6 +170,35 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok
SupportsCacheBreakdown: false,
}
// OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
s.fallbackPrices["gpt-5.1"] = &ModelPricing{
InputPricePerToken: 1.25e-6, // $1.25 per MTok
OutputPricePerToken: 10e-6, // $10 per MTok
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
CacheReadPricePerToken: 0.125e-6,
SupportsCacheBreakdown: false,
}
// OpenAI GPT-5.4(业务指定价格)
s.fallbackPrices["gpt-5.4"] = &ModelPricing{
InputPricePerToken: 2.5e-6, // $2.5 per MTok
OutputPricePerToken: 15e-6, // $15 per MTok
CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok
CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok
SupportsCacheBreakdown: false,
LongContextInputThreshold: openAIGPT54LongContextInputThreshold,
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
}
// Codex 族兜底统一按 GPT-5.1 Codex 价格计费
s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{
InputPricePerToken: 1.5e-6, // $1.5 per MTok
OutputPricePerToken: 12e-6, // $12 per MTok
CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok
CacheReadPricePerToken: 0.15e-6,
SupportsCacheBreakdown: false,
}
s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
}
// getFallbackPricing 根据模型系列获取回退价格
......@@ -189,12 +227,30 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
}
return s.fallbackPrices["claude-3-haiku"]
}
// Claude 未知型号统一回退到 Sonnet,避免计费中断。
if strings.Contains(modelLower, "claude") {
return s.fallbackPrices["claude-sonnet-4"]
}
if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") {
return s.fallbackPrices["gemini-3.1-pro"]
}
// 默认使用Sonnet价格
return s.fallbackPrices["claude-sonnet-4"]
// OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
normalized := normalizeCodexModel(modelLower)
switch normalized {
case "gpt-5.4":
return s.fallbackPrices["gpt-5.4"]
case "gpt-5.3-codex":
return s.fallbackPrices["gpt-5.3-codex"]
case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
return s.fallbackPrices["gpt-5.1-codex"]
case "gpt-5.1":
return s.fallbackPrices["gpt-5.1"]
}
}
return nil
}
// GetModelPricing 获取模型价格配置
......@@ -212,15 +268,18 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
price5m := litellmPricing.CacheCreationInputTokenCost
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
enableBreakdown := price1h > 0 && price1h > price5m
return &ModelPricing{
InputPricePerToken: litellmPricing.InputCostPerToken,
OutputPricePerToken: litellmPricing.OutputCostPerToken,
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
CacheCreation5mPrice: price5m,
CacheCreation1hPrice: price1h,
SupportsCacheBreakdown: enableBreakdown,
}, nil
return s.applyModelSpecificPricingPolicy(model, &ModelPricing{
InputPricePerToken: litellmPricing.InputCostPerToken,
OutputPricePerToken: litellmPricing.OutputCostPerToken,
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
CacheCreation5mPrice: price5m,
CacheCreation1hPrice: price1h,
SupportsCacheBreakdown: enableBreakdown,
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
}), nil
}
}
......@@ -228,7 +287,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
fallback := s.getFallbackPricing(model)
if fallback != nil {
log.Printf("[Billing] Using fallback pricing for model: %s", model)
return fallback, nil
return s.applyModelSpecificPricingPolicy(model, fallback), nil
}
return nil, fmt.Errorf("pricing not found for model: %s", model)
......@@ -242,12 +301,18 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
}
breakdown := &CostBreakdown{}
inputPricePerToken := pricing.InputPricePerToken
outputPricePerToken := pricing.OutputPricePerToken
if s.shouldApplySessionLongContextPricing(tokens, pricing) {
inputPricePerToken *= pricing.LongContextInputMultiplier
outputPricePerToken *= pricing.LongContextOutputMultiplier
}
// 计算输入token费用(使用per-token价格)
breakdown.InputCost = float64(tokens.InputTokens) * pricing.InputPricePerToken
breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken
// 计算输出token费用
breakdown.OutputCost = float64(tokens.OutputTokens) * pricing.OutputPricePerToken
breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken
// 计算缓存费用
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
......@@ -279,6 +344,45 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
return breakdown, nil
}
func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing {
if pricing == nil {
return nil
}
if !isOpenAIGPT54Model(model) {
return pricing
}
if pricing.LongContextInputThreshold > 0 && pricing.LongContextInputMultiplier > 0 && pricing.LongContextOutputMultiplier > 0 {
return pricing
}
cloned := *pricing
if cloned.LongContextInputThreshold <= 0 {
cloned.LongContextInputThreshold = openAIGPT54LongContextInputThreshold
}
if cloned.LongContextInputMultiplier <= 0 {
cloned.LongContextInputMultiplier = openAIGPT54LongContextInputMultiplier
}
if cloned.LongContextOutputMultiplier <= 0 {
cloned.LongContextOutputMultiplier = openAIGPT54LongContextOutputMultiplier
}
return &cloned
}
func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens, pricing *ModelPricing) bool {
if pricing == nil || pricing.LongContextInputThreshold <= 0 {
return false
}
if pricing.LongContextInputMultiplier <= 1 && pricing.LongContextOutputMultiplier <= 1 {
return false
}
totalInputTokens := tokens.InputTokens + tokens.CacheReadTokens
return totalInputTokens > pricing.LongContextInputThreshold
}
func isOpenAIGPT54Model(model string) bool {
normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model)))
return normalized == "gpt-5.4"
}
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) {
multiplier := s.cfg.Default.RateMultiplier
......
......@@ -133,7 +133,7 @@ func TestGetModelPricing_CaseInsensitive(t *testing.T) {
require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken)
}
func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) {
func TestGetModelPricing_UnknownClaudeModelFallsBackToSonnet(t *testing.T) {
svc := newTestBillingService()
// 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格
......@@ -142,6 +142,93 @@ func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) {
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12)
}
func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) {
svc := newTestBillingService()
pricing, err := svc.GetModelPricing("gpt-unknown-model")
require.Error(t, err)
require.Nil(t, pricing)
require.Contains(t, err.Error(), "pricing not found")
}
func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) {
svc := newTestBillingService()
pricing, err := svc.GetModelPricing("gpt-5.1")
require.NoError(t, err)
require.NotNil(t, pricing)
require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12)
}
func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
svc := newTestBillingService()
pricing, err := svc.GetModelPricing("gpt-5.4")
require.NoError(t, err)
require.NotNil(t, pricing)
require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12)
require.Equal(t, 272000, pricing.LongContextInputThreshold)
require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12)
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
}
func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{
InputTokens: 300000,
OutputTokens: 4000,
}
cost, err := svc.CalculateCost("gpt-5.4-2026-03-05", tokens, 1.0)
require.NoError(t, err)
expectedInput := float64(tokens.InputTokens) * 2.5e-6 * 2.0
expectedOutput := float64(tokens.OutputTokens) * 15e-6 * 1.5
require.InDelta(t, expectedInput, cost.InputCost, 1e-10)
require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10)
require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10)
require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10)
}
func TestGetFallbackPricing_FamilyMatching(t *testing.T) {
svc := newTestBillingService()
tests := []struct {
name string
model string
expectedInput float64
expectNilPricing bool
}{
{name: "empty model", model: " ", expectNilPricing: true},
{name: "claude opus 4.6", model: "claude-opus-4.6-20260201", expectedInput: 5e-6},
{name: "claude opus 4.5 alt separator", model: "claude-opus-4-5-20260101", expectedInput: 5e-6},
{name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6},
{name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6},
{name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true},
{name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6},
{name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6},
{name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6},
{name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6},
{name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6},
{name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true},
{name: "non supported family", model: "qwen-max", expectNilPricing: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pricing := svc.getFallbackPricing(tt.model)
if tt.expectNilPricing {
require.Nil(t, pricing)
return
}
require.NotNil(t, pricing)
require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12)
})
}
}
func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) {
svc := newTestBillingService()
......
......@@ -187,6 +187,14 @@ func (m *mockAccountRepoForPlatform) BulkUpdate(ctx context.Context, ids []int64
return 0, nil
}
func (m *mockAccountRepoForPlatform) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return nil
}
func (m *mockAccountRepoForPlatform) ResetQuotaUsed(ctx context.Context, id int64) error {
return nil
}
// Verify interface implementation
var _ AccountRepository = (*mockAccountRepoForPlatform)(nil)
......
......@@ -1228,6 +1228,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID)
continue
}
// 配额检查
if !s.isAccountSchedulableForQuota(account) {
continue
}
// 窗口费用检查(非粘性会话路径)
if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
filteredWindowCost++
......@@ -1260,6 +1264,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
s.isAccountSchedulableForQuota(stickyAccount) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
......@@ -1311,7 +1316,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, acc := range routingCandidates {
routingLoads = append(routingLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
MaxConcurrency: acc.EffectiveLoadFactor(),
})
}
routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
......@@ -1416,6 +1421,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
s.isAccountSchedulableForQuota(account) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) &&
s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
......@@ -1480,6 +1486,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
// 配额检查
if !s.isAccountSchedulableForQuota(acc) {
continue
}
// 窗口费用检查(非粘性会话路径)
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
......@@ -1499,7 +1509,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
MaxConcurrency: acc.EffectiveLoadFactor(),
})
}
......@@ -2113,6 +2123,15 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
}
// isAccountSchedulableForQuota 检查 API Key 账号是否在配额限制内
// 仅适用于配置了 quota_limit 的 apikey 类型账号
func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool {
if account.Type != AccountTypeAPIKey {
return true
}
return !account.IsQuotaExceeded()
}
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示可调度,false 表示不可调度
......@@ -2590,7 +2609,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
......@@ -2644,6 +2663,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
......@@ -2700,7 +2722,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
return account, nil
}
}
......@@ -2743,6 +2765,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
......@@ -2818,7 +2843,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
......@@ -2874,6 +2899,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
......@@ -2930,7 +2958,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
return account, nil
}
......@@ -2975,6 +3003,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
......@@ -3289,6 +3320,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
if account.Platform == PlatformSora {
return s.isSoraModelSupportedByAccount(account, requestedModel)
}
// OpenAI 透传模式:仅替换认证,允许所有模型
if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() {
return true
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel)
......@@ -6379,6 +6414,89 @@ type APIKeyQuotaUpdater interface {
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
}
// postUsageBillingParams 统一扣费所需的参数
type postUsageBillingParams struct {
Cost *CostBreakdown
User *User
APIKey *APIKey
Account *Account
Subscription *UserSubscription
IsSubscriptionBill bool
AccountRateMultiplier float64
APIKeyService APIKeyQuotaUpdater
}
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
// - 订阅/余额扣费
// - API Key 配额更新
// - API Key 限速用量更新
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
cost := p.Cost
// 1. 订阅 / 余额扣费
if p.IsSubscriptionBill {
if cost.TotalCost > 0 {
if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
}
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
}
} else {
if cost.ActualCost > 0 {
if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil {
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
}
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
}
}
// 2. API Key 配额
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
}
}
// 3. API Key 限速用量
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
}
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost)
}
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.GetQuotaLimit() > 0 {
accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
}
}
// 5. 更新账号最近使用时间
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
}
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
type billingDeps struct {
accountRepo AccountRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService
deferredService *DeferredService
}
func (s *GatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
}
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
......@@ -6542,45 +6660,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
shouldBill := inserted || err != nil
// 根据计费类型执行扣费
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if shouldBill && cost.TotalCost > 0 {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if shouldBill && cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
}
// 异步更新余额缓存
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
}
// 更新 API Key 配额(如果设置了配额限制)
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err)
}
}
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
s.deferredService.ScheduleLastUsedUpdate(account.ID)
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
......@@ -6740,44 +6834,21 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
shouldBill := inserted || err != nil
// 根据计费类型执行扣费
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if shouldBill && cost.TotalCost > 0 {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if shouldBill && cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
}
// 异步更新余额缓存
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
// API Key 独立配额扣费
if input.APIKeyService != nil && apiKey.Quota > 0 {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err)
}
}
}
}
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
s.deferredService.ScheduleLastUsedUpdate(account.ID)
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
......
......@@ -176,6 +176,14 @@ func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64,
return 0, nil
}
func (m *mockAccountRepoForGemini) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return nil
}
func (m *mockAccountRepoForGemini) ResetQuotaUsed(ctx context.Context, id int64) error {
return nil
}
// Verify interface implementation
var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
......
......@@ -342,6 +342,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
}
cfg := s.service.schedulingConfig()
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
if s.service.concurrencyService != nil {
return &AccountSelectionResult{
Account: account,
......@@ -590,7 +591,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
filtered = append(filtered, account)
loadReq = append(loadReq, AccountWithConcurrency{
ID: account.ID,
MaxConcurrency: account.Concurrency,
MaxConcurrency: account.EffectiveLoadFactor(),
})
}
if len(filtered) == 0 {
......@@ -703,6 +704,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
cfg := s.service.schedulingConfig()
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
candidate := selectionOrder[0]
return &AccountSelectionResult{
Account: candidate.account,
......
......@@ -9,6 +9,13 @@ import (
var codexCLIInstructions string
var codexModelMap = map[string]string{
"gpt-5.4": "gpt-5.4",
"gpt-5.4-none": "gpt-5.4",
"gpt-5.4-low": "gpt-5.4",
"gpt-5.4-medium": "gpt-5.4",
"gpt-5.4-high": "gpt-5.4",
"gpt-5.4-xhigh": "gpt-5.4",
"gpt-5.4-chat-latest": "gpt-5.4",
"gpt-5.3": "gpt-5.3-codex",
"gpt-5.3-none": "gpt-5.3-codex",
"gpt-5.3-low": "gpt-5.3-codex",
......@@ -154,6 +161,9 @@ func normalizeCodexModel(model string) string {
normalized := strings.ToLower(modelID)
if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") {
return "gpt-5.4"
}
if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
return "gpt-5.2-codex"
}
......
......@@ -167,6 +167,10 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
func TestNormalizeCodexModel_Gpt53(t *testing.T) {
cases := map[string]string{
"gpt-5.4": "gpt-5.4",
"gpt-5.4-high": "gpt-5.4",
"gpt-5.4-chat-latest": "gpt-5.4",
"gpt 5.4": "gpt-5.4",
"gpt-5.3": "gpt-5.3-codex",
"gpt-5.3-codex": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
......
......@@ -319,6 +319,16 @@ func NewOpenAIGatewayService(
return svc
}
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
}
}
// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。
// 应在应用优雅关闭时调用。
func (s *OpenAIGatewayService) CloseOpenAIWSPool() {
......@@ -1242,7 +1252,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
MaxConcurrency: acc.EffectiveLoadFactor(),
})
}
......@@ -3474,37 +3484,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
shouldBill := inserted || err != nil
// Deduct based on billing type
if isSubscriptionBilling {
if shouldBill && cost.TotalCost > 0 {
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else {
if shouldBill && cost.ActualCost > 0 {
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
}
// Update API key quota if applicable (only for balance mode with quota set)
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err)
}
}
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.openai_gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
s.deferredService.ScheduleLastUsedUpdate(account.ID)
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment