Commit d72ac926 authored by erio's avatar erio
Browse files

feat: image output token billing, channel-mapped billing source, credits balance precheck

- Parse candidatesTokensDetails from Gemini API to separate image/text output tokens
- Add image_output_tokens and image_output_cost to usage_log (migration 089)
- Support per-image-token pricing via output_cost_per_image_token from model pricing data
- Channel pricing ImageOutputPrice override works in token billing mode
- Auto-fill image_output_price in channel pricing form from model defaults
- Add "channel_mapped" billing model source as new default (migration 088)
- Bills by model name after channel mapping, before account mapping
- Fix channel cache error TTL sign error (115s → 5s)
- Fix Update channel only invalidating new groups, not removed groups
- Fix frontend model_mapping clearing sending undefined instead of {}
- Credits balance precheck via shared AccountUsageService cache before injection
- Skip credits injection for accounts with insufficient balance
- Don't mark credits exhausted for "exhausted your capacity on this model" 429s
parent 2555951b
......@@ -139,11 +139,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache, accountUsageService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
......
......@@ -31,7 +31,7 @@ type createChannelRequest struct {
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels bool `json:"restrict_models"`
}
......@@ -42,7 +42,7 @@ type updateChannelRequest struct {
GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels *bool `json:"restrict_models"`
}
......@@ -129,7 +129,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
}
resp.BillingModelSource = ch.BillingModelSource
if resp.BillingModelSource == "" {
resp.BillingModelSource = "requested"
resp.BillingModelSource = "channel_mapped"
}
if resp.GroupIDs == nil {
resp.GroupIDs = []int64{}
......@@ -393,5 +393,6 @@ func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
"output_price": pricing.OutputPricePerToken,
"cache_write_price": pricing.CacheCreationPricePerToken,
"cache_read_price": pricing.CacheReadPricePerToken,
"image_output_price": pricing.ImageOutputPricePerToken,
})
}
......@@ -106,7 +106,7 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
}
resp := channelToResponse(ch)
require.Equal(t, "requested", resp.BillingModelSource)
require.Equal(t, "channel_mapped", resp.BillingModelSource)
require.NotNil(t, resp.GroupIDs)
require.Empty(t, resp.GroupIDs)
require.NotNil(t, resp.ModelMapping)
......
......@@ -125,6 +125,7 @@ type ClaudeUsage struct {
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
}
// ClaudeError Claude 错误响应
......
......@@ -149,6 +149,12 @@ type GeminiCandidate struct {
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
}
// GeminiTokenDetail Gemini token 详情(按模态分类)
type GeminiTokenDetail struct {
Modality string `json:"modality"`
TokenCount int `json:"tokenCount"`
}
// GeminiUsageMetadata Gemini 用量元数据
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount,omitempty"`
......@@ -156,6 +162,18 @@ type GeminiUsageMetadata struct {
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"`
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费)
CandidatesTokensDetails []GeminiTokenDetail `json:"candidatesTokensDetails,omitempty"`
PromptTokensDetails []GeminiTokenDetail `json:"promptTokensDetails,omitempty"`
}
// ImageOutputTokens 从 CandidatesTokensDetails 中提取 IMAGE 模态的 token 数
func (m *GeminiUsageMetadata) ImageOutputTokens() int {
for _, d := range m.CandidatesTokensDetails {
if d.Modality == "IMAGE" {
return d.TokenCount
}
}
return 0
}
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
......
......@@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached
usage.ImageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
}
// 生成响应 ID
......
......@@ -35,6 +35,7 @@ type StreamingProcessor struct {
inputTokens int
outputTokens int
cacheReadTokens int
imageOutputTokens int
}
// NewStreamingProcessor 创建流式响应处理器
......@@ -45,6 +46,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
}
}
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
p.usageMapHook = fn
}
func usageToMap(u ClaudeUsage) map[string]any {
m := map[string]any{
"input_tokens": u.InputTokens,
"output_tokens": u.OutputTokens,
}
if u.CacheCreationInputTokens > 0 {
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
}
if u.CacheReadInputTokens > 0 {
m["cache_read_input_tokens"] = u.CacheReadInputTokens
}
if u.ImageOutputTokens > 0 {
m["image_output_tokens"] = u.ImageOutputTokens
}
return m
}
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
func (p *StreamingProcessor) ProcessLine(line string) []byte {
line = strings.TrimSpace(line)
......@@ -87,6 +110,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
p.cacheReadTokens = cached
p.imageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
}
// 处理 parts
......@@ -127,6 +151,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
ImageOutputTokens: p.imageOutputTokens,
}
if !p.messageStartSent {
......@@ -158,6 +183,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached
usage.ImageOutputTokens = v1Resp.Response.UsageMetadata.ImageOutputTokens()
}
responseID := v1Resp.ResponseID
......@@ -485,6 +511,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
ImageOutputTokens: p.imageOutputTokens,
}
deltaEvent := map[string]any{
......
......@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
......@@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{
"integer", // cache_read_tokens
"integer", // cache_creation_5m_tokens
"integer", // cache_creation_1h_tokens
"integer", // image_output_tokens
"numeric", // image_output_cost
"numeric", // input_cost
"numeric", // output_cost
"numeric", // cache_creation_cost
......@@ -330,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -363,9 +367,9 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$1, $2, $3, $4, $5, $6, $7,
$8, $9,
$10, $11, $12, $13,
$14, $15,
$16, $17, $18, $19, $20, $21,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
......@@ -766,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -797,7 +803,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
created_at
) AS (VALUES `)
args := make([]any, 0, len(keys)*45)
args := make([]any, 0, len(keys)*47)
argPos := 1
for idx, key := range keys {
if idx > 0 {
......@@ -841,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -887,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -973,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -1004,7 +1016,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*44)
args := make([]any, 0, len(preparedList)*46)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
......@@ -1045,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -1091,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -1145,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -1178,9 +1196,9 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$1, $2, $3, $4, $5, $6, $7,
$8, $9,
$10, $11, $12, $13,
$14, $15,
$16, $17, $18, $19, $20, $21,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
......@@ -1248,6 +1266,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
......@@ -4011,6 +4031,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
cacheReadTokens int
cacheCreation5m int
cacheCreation1h int
imageOutputTokens int
imageOutputCost float64
inputCost float64
outputCost float64
cacheCreationCost float64
......@@ -4059,6 +4081,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&cacheReadTokens,
&cacheCreation5m,
&cacheCreation1h,
&imageOutputTokens,
&imageOutputCost,
&inputCost,
&outputCost,
&cacheCreationCost,
......@@ -4105,6 +4129,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
CacheReadTokens: cacheReadTokens,
CacheCreation5mTokens: cacheCreation5m,
CacheCreation1hTokens: cacheCreation1h,
ImageOutputTokens: imageOutputTokens,
ImageOutputCost: imageOutputCost,
InputCost: inputCost,
OutputCost: outputCost,
CacheCreationCost: cacheCreationCost,
......
......@@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
......@@ -133,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
......@@ -447,6 +451,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
4, // cache_read_tokens
5, // cache_creation_5m_tokens
6, // cache_creation_1h_tokens
0, // image_output_tokens
0.0, // image_output_cost
0.1, // input_cost
0.2, // output_cost
0.3, // cache_creation_cost
......@@ -499,6 +505,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0,
sql.NullFloat64{},
......@@ -546,6 +553,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0,
sql.NullFloat64{},
......
......@@ -846,6 +846,15 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
return usage, nil
}
// GetAntigravityCredits 返回账号的 AI Credits 信息,复用 getAntigravityUsage 的缓存。
// 如果缓存存在且 TTL 充足则直接返回;TTL 不足时自动刷新。
func (s *AccountUsageService) GetAntigravityCredits(ctx context.Context, account *Account) (*UsageInfo, error) {
if account == nil || account.Platform != PlatformAntigravity {
return nil, nil
}
return s.getAntigravityUsage(ctx, account)
}
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
func recalcAntigravityRemainingSeconds(info *UsageInfo) {
......
......@@ -19,6 +19,54 @@ const (
creditsExhaustedDuration = 5 * time.Hour
)
// checkAccountCredits 通过共享的 AccountUsageService 缓存检查账号是否有足够的 AI Credits。
// 缓存 TTL 不足时会自动从 Google loadCodeAssist API 刷新。
// 返回 true 表示积分可用。
func (s *AntigravityGatewayService) checkAccountCredits(
ctx context.Context, account *Account, accessToken, proxyURL string,
) bool {
if account == nil || account.ID == 0 {
return false
}
if s.accountUsageService == nil {
return true // 无 usage service 时不阻断
}
usageInfo, err := s.accountUsageService.GetAntigravityCredits(ctx, account)
if err != nil {
logger.LegacyPrintf("service.antigravity_gateway",
"check_credits: get_credits_failed account=%d err=%v", account.ID, err)
return true // 出错时假设有积分,不阻断
}
if usageInfo == nil || len(usageInfo.AICredits) == 0 {
logger.LegacyPrintf("service.antigravity_gateway",
"check_credits: account=%d has_credits=false amount=0 (no credits field)",
account.ID)
return false
}
for _, credit := range usageInfo.AICredits {
if credit.CreditType == "GOOGLE_ONE_AI" {
minimum := credit.MinimumBalance
if minimum <= 0 {
minimum = 5
}
hasCredits := credit.Amount >= minimum
logger.LegacyPrintf("service.antigravity_gateway",
"check_credits: account=%d has_credits=%t amount=%.0f minimum=%.0f",
account.ID, hasCredits, credit.Amount, minimum)
return hasCredits
}
}
logger.LegacyPrintf("service.antigravity_gateway",
"check_credits: account=%d has_credits=false (no GOOGLE_ONE_AI credit)",
account.ID)
return false
}
type antigravity429Category string
const (
......@@ -141,6 +189,13 @@ func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstr
}
// shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。
// 此函数在积分注入后失败时调用(预检查注入 + attemptCreditsOveragesRetry 两条路径)。
// - 429 + 非单模型限流:积分注入后仍 429 → 标记耗尽。
// - 429 + 单模型限流("exhausted your capacity on this model"):该模型免费配额用完,
// 积分注入对此无效,但账号积分对其他模型可能仍可用 → 不标记积分耗尽。
// - 403 等其他 4xx:检查 body 是否包含积分不足的关键词。
//
// clearCreditsExhausted 会在后续成功时自动清除。
func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool {
if reqErr != nil || resp == nil {
return false
......@@ -148,13 +203,16 @@ func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr err
if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout {
return false
}
// 注意:不再检查 isURLLevelRateLimit。此函数仅在积分重试失败后调用,
// 如果注入 enabledCreditTypes 后仍返回 "Resource has been exhausted",
// 说明积分也已耗尽,应该标记。clearCreditsExhausted 会在后续成功时自动清除。
if info := parseAntigravitySmartRetryInfo(respBody); info != nil {
bodyLower := strings.ToLower(string(respBody))
// 积分注入后仍 429
if resp.StatusCode == http.StatusTooManyRequests {
// 单模型配额耗尽:积分注入对此无效,不标记整个账号积分耗尽
if strings.Contains(bodyLower, "exhausted your capacity on this model") {
return false
}
bodyLower := strings.ToLower(string(respBody))
return true
}
// 其他 4xx:关键词匹配(如 403 + "Insufficient credits")
for _, keyword := range creditsExhaustedKeywords {
if strings.Contains(bodyLower, keyword) {
return true
......@@ -181,6 +239,16 @@ func (s *AntigravityGatewayService) attemptCreditsOveragesRetry(
if creditsBody == nil {
return &creditsOveragesRetryResult{handled: false}
}
// Check actual credits balance before attempting retry
if !s.checkAccountCredits(p.ctx, p.account, p.accessToken, p.proxyURL) {
s.setCreditsExhausted(p.ctx, p.account)
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_no_credits model=%s account=%d (skipping credits retry)",
p.prefix, modelKey, p.account.ID)
return &creditsOveragesRetryResult{handled: true}
}
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 credit_overages_retry model=%s account=%d (injecting enabledCreditTypes)",
p.prefix, modelKey, p.account.ID)
......
......@@ -418,7 +418,13 @@ func TestShouldMarkCreditsExhausted(t *testing.T) {
require.True(t, shouldMarkCreditsExhausted(resp, body, nil))
})
t.Run("结构化限流不标记", func(t *testing.T) {
t.Run("单模型配额耗尽不标记(积分对此无效)", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
body := []byte(`{"error":{"code":429,"message":"You have exhausted your capacity on this model. Your quota will reset after 146h11m17s.","status":"RESOURCE_EXHAUSTED"}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
})
t.Run("429 结构化限流也标记(积分注入后仍 429 即为耗尽)", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
......
......@@ -557,7 +557,13 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
if p.requestedModel != "" && p.account.Platform == PlatformAntigravity &&
p.account.IsOveragesEnabled() && !p.account.isCreditsExhausted() &&
p.account.isModelRateLimitedWithContext(p.ctx, p.requestedModel) {
if creditsBody := injectEnabledCreditTypes(p.body); creditsBody != nil {
// Check actual credits balance before injection
if !s.checkAccountCredits(p.ctx, p.account, p.accessToken, p.proxyURL) {
// No credits available - mark as exhausted and skip injection
s.setCreditsExhausted(p.ctx, p.account)
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: no_credits_available account=%d (skipping credits injection)",
p.prefix, p.account.ID)
} else if creditsBody := injectEnabledCreditTypes(p.body); creditsBody != nil {
p.body = creditsBody
overagesInjected = true
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: model_rate_limited_credits_inject model=%s account=%d (injecting enabledCreditTypes)",
......@@ -878,6 +884,7 @@ type AntigravityGatewayService struct {
cache GatewayCache // 用于模型级限流时清除粘性会话绑定
schedulerSnapshot *SchedulerSnapshotService
internal500Cache Internal500CounterCache // INTERNAL 500 渐进惩罚计数器
accountUsageService *AccountUsageService // 共享 usage 缓存,用于积分余额检查
}
func NewAntigravityGatewayService(
......@@ -889,6 +896,7 @@ func NewAntigravityGatewayService(
httpUpstream HTTPUpstream,
settingService *SettingService,
internal500Cache Internal500CounterCache,
accountUsageService *AccountUsageService,
) *AntigravityGatewayService {
return &AntigravityGatewayService{
accountRepo: accountRepo,
......@@ -899,6 +907,7 @@ func NewAntigravityGatewayService(
cache: cache,
schedulerSnapshot: schedulerSnapshot,
internal500Cache: internal500Cache,
accountUsageService: accountUsageService,
}
}
......
......@@ -56,6 +56,7 @@ type ModelPricing struct {
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
ImageOutputPricePerToken float64 // 图片输出 token 价格 (USD)
}
const (
......@@ -94,12 +95,14 @@ type UsageTokens struct {
CacheReadTokens int
CacheCreation5mTokens int
CacheCreation1hTokens int
ImageOutputTokens int
}
// CostBreakdown 费用明细
type CostBreakdown struct {
InputCost float64
OutputCost float64
ImageOutputCost float64
CacheCreationCost float64
CacheReadCost float64
TotalCost float64
......@@ -358,6 +361,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
ImageOutputPricePerToken: litellmPricing.OutputCostPerImageToken,
}), nil
}
}
......@@ -399,6 +403,9 @@ func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing
pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice
pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice
}
if channelPricing.ImageOutputPrice != nil {
pricing.ImageOutputPricePerToken = *channelPricing.ImageOutputPrice
}
return pricing, nil
}
......@@ -489,7 +496,22 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos
}
breakdown.InputCost = float64(input.Tokens.InputTokens) * inputPricePerToken
breakdown.OutputCost = float64(input.Tokens.OutputTokens) * outputPricePerToken
// Separate image output tokens from text output tokens
textOutputTokens := input.Tokens.OutputTokens - input.Tokens.ImageOutputTokens
if textOutputTokens < 0 {
textOutputTokens = 0
}
breakdown.OutputCost = float64(textOutputTokens) * outputPricePerToken
// Image output tokens cost (separate rate from text output)
if input.Tokens.ImageOutputTokens > 0 {
imageOutputPrice := pricing.ImageOutputPricePerToken
if imageOutputPrice == 0 {
imageOutputPrice = outputPricePerToken // fallback to regular output price
}
breakdown.ImageOutputCost = float64(input.Tokens.ImageOutputTokens) * imageOutputPrice
}
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
if input.Tokens.CacheCreation5mTokens == 0 && input.Tokens.CacheCreation1hTokens == 0 && input.Tokens.CacheCreationTokens > 0 {
......@@ -507,11 +529,12 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos
if tierMultiplier != 1.0 {
breakdown.InputCost *= tierMultiplier
breakdown.OutputCost *= tierMultiplier
breakdown.ImageOutputCost *= tierMultiplier
breakdown.CacheCreationCost *= tierMultiplier
breakdown.CacheReadCost *= tierMultiplier
}
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.ImageOutputCost +
breakdown.CacheCreationCost + breakdown.CacheReadCost
breakdown.ActualCost = breakdown.TotalCost * input.RateMultiplier
......@@ -597,8 +620,21 @@ func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens,
// 计算输入token费用(使用per-token价格)
breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken
// 计算输出token费用
breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken
// 计算输出token费用(分离图片输出token)
textOutputTokens := tokens.OutputTokens - tokens.ImageOutputTokens
if textOutputTokens < 0 {
textOutputTokens = 0
}
breakdown.OutputCost = float64(textOutputTokens) * outputPricePerToken
// 图片输出 token 费用
if tokens.ImageOutputTokens > 0 {
imageOutputPrice := pricing.ImageOutputPricePerToken
if imageOutputPrice == 0 {
imageOutputPrice = outputPricePerToken
}
breakdown.ImageOutputCost = float64(tokens.ImageOutputTokens) * imageOutputPrice
}
// 计算缓存费用
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
......@@ -620,12 +656,13 @@ func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens,
if tierMultiplier != 1.0 {
breakdown.InputCost *= tierMultiplier
breakdown.OutputCost *= tierMultiplier
breakdown.ImageOutputCost *= tierMultiplier
breakdown.CacheCreationCost *= tierMultiplier
breakdown.CacheReadCost *= tierMultiplier
}
// 计算总费用
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.ImageOutputCost +
breakdown.CacheCreationCost + breakdown.CacheReadCost
// 应用倍率计算实际费用
......@@ -730,6 +767,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
CacheReadTokens: inRangeCacheTokens,
CacheCreation5mTokens: tokens.CacheCreation5mTokens,
CacheCreation1hTokens: tokens.CacheCreation1hTokens,
ImageOutputTokens: tokens.ImageOutputTokens,
}
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
if err != nil {
......@@ -750,6 +788,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
return &CostBreakdown{
InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
OutputCost: inRangeCost.OutputCost,
ImageOutputCost: inRangeCost.ImageOutputCost,
CacheCreationCost: inRangeCost.CacheCreationCost,
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
......
......@@ -26,6 +26,7 @@ func (m BillingMode) IsValid() bool {
const (
BillingModelSourceRequested = "requested"
BillingModelSourceUpstream = "upstream"
BillingModelSourceChannelMapped = "channel_mapped"
)
// Channel 渠道实体
......@@ -34,7 +35,7 @@ type Channel struct {
Name string
Description string
Status string
BillingModelSource string // "requested" or "upstream"
BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
CreatedAt time.Time
UpdatedAt time.Time
......@@ -180,6 +181,7 @@ func (c *Channel) Clone() *Channel {
type ChannelUsageFields struct {
ChannelID int64 // 渠道 ID(0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前)
BillingModelSource string // 计费模型来源:"requested" / "upstream"
ChannelMappedModel string // 渠道映射后的模型名(无映射时等于 OriginalModel)
BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped"
ModelMappingChain string // 映射链描述,如 "a→b→c"
}
......@@ -97,7 +97,7 @@ type ChannelMappingResult struct {
MappedModel string // 映射后的模型名(无映射时等于原始模型名)
ChannelID int64 // 渠道 ID(0 = 无渠道关联)
Mapped bool // 是否发生了映射
BillingModelSource string // 计费模型来源("requested" / "upstream")
BillingModelSource string // 计费模型来源("requested" / "upstream" / "channel_mapped"
}
// BuildModelMappingChain 根据映射结果和上游实际模型构建映射链描述。
......@@ -119,9 +119,14 @@ func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel str
// ToUsageFields 将渠道映射结果转为使用记录字段
func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields {
channelMappedModel := reqModel
if r.Mapped {
channelMappedModel = r.MappedModel
}
return ChannelUsageFields{
ChannelID: r.ChannelID,
OriginalModel: reqModel,
ChannelMappedModel: channelMappedModel,
BillingModelSource: r.BillingModelSource,
ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel),
}
......@@ -193,7 +198,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
loadedAt: time.Now().Add(-(channelCacheTTL - channelErrorTTL)), // 使剩余 TTL = errorTTL
}
s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err)
......@@ -374,7 +379,7 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
BillingModelSource: ch.BillingModelSource,
}
if result.BillingModelSource == "" {
result.BillingModelSource = BillingModelSourceRequested
result.BillingModelSource = BillingModelSourceChannelMapped
}
platform := cache.groupPlatform[groupID]
......@@ -481,7 +486,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
ModelMapping: input.ModelMapping,
}
if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceRequested
channel.BillingModelSource = BillingModelSourceChannelMapped
}
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
......@@ -565,22 +570,38 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
return nil, err
}
// 先获取旧分组,Update 后旧分组关联已删除,无法再查到
var oldGroupIDs []int64
if s.authCacheInvalidator != nil {
var err2 error
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
if err2 != nil {
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
}
}
if err := s.repo.Update(ctx, channel); err != nil {
return nil, fmt.Errorf("update channel: %w", err)
}
s.invalidateCache()
// 失效关联分组的 auth 缓存
// 失效新旧分组的 auth 缓存
if s.authCacheInvalidator != nil {
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
if err != nil {
slog.Warn("failed to get group IDs for cache invalidation", "channel_id", id, "error", err)
seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
for _, gid := range oldGroupIDs {
if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
for _, gid := range groupIDs {
}
for _, gid := range channel.GroupIDs {
if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
}
return s.repo.GetByID(ctx, id)
}
......
......@@ -196,7 +196,6 @@ func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChanne
return NewChannelService(repo, auth)
}
// makeStandardRepo returns a repo that serves one active channel with anthropic pricing
// for group 1, with the given model pricing and model mapping.
func makeStandardRepo(ch Channel, groupPlatforms map[int64]string) *mockChannelRepository {
......@@ -914,7 +913,7 @@ func TestResolveChannelMapping_DefaultBillingModelSource(t *testing.T) {
svc := newTestChannelService(repo)
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
require.Equal(t, BillingModelSourceRequested, result.BillingModelSource)
require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource)
}
func TestResolveChannelMapping_UpstreamBillingModelSource(t *testing.T) {
......@@ -1451,11 +1450,11 @@ func TestCreate_DefaultBillingModelSource(t *testing.T) {
result, err := svc.Create(context.Background(), &CreateChannelInput{
Name: "new-channel",
BillingModelSource: "", // empty, should default to "requested"
BillingModelSource: "", // empty, should default to "channel_mapped"
})
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, BillingModelSourceRequested, result.BillingModelSource)
require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource)
}
func TestCreate_InvalidatesCache(t *testing.T) {
......
......@@ -483,6 +483,7 @@ type ClaudeUsage struct {
CacheReadInputTokens int `json:"cache_read_input_tokens"`
CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象)
CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象)
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
}
// ForwardResult 转发结果
......@@ -7729,6 +7730,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
var cost *CostBreakdown
// 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
billingModel = input.ChannelMappedModel
}
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
......@@ -7777,6 +7781,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
var err error
if s.resolver != nil && apiKey.Group != nil {
......@@ -7836,8 +7841,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
ImageOutputCost: cost.ImageOutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
......@@ -7976,6 +7983,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
var cost *CostBreakdown
// 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
billingModel = input.ChannelMappedModel
}
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
......@@ -8007,6 +8017,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
var err error
// 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用)
......@@ -8073,8 +8084,10 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
ImageOutputCost: cost.ImageOutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
......
......@@ -2692,12 +2692,27 @@ func extractGeminiUsage(data []byte) *ClaudeUsage {
cand := int(usage.Get("candidatesTokenCount").Int())
cached := int(usage.Get("cachedContentTokenCount").Int())
thoughts := int(usage.Get("thoughtsTokenCount").Int())
// 从 candidatesTokensDetails 提取 IMAGE 模态 token 数
imageTokens := 0
candidateDetails := usage.Get("candidatesTokensDetails")
if candidateDetails.Exists() {
candidateDetails.ForEach(func(_, detail gjson.Result) bool {
if detail.Get("modality").String() == "IMAGE" {
imageTokens = int(detail.Get("tokenCount").Int())
return false
}
return true
})
}
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
return &ClaudeUsage{
InputTokens: prompt - cached,
OutputTokens: cand + thoughts,
CacheReadInputTokens: cached,
ImageOutputTokens: imageTokens,
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment