Commit d72ac926 authored by erio's avatar erio
Browse files

feat: image output token billing, channel-mapped billing source, credits balance precheck

- Parse candidatesTokensDetails from Gemini API to separate image/text output tokens
- Add image_output_tokens and image_output_cost to usage_log (migration 089)
- Support per-image-token pricing via output_cost_per_image_token from model pricing data
- Channel pricing ImageOutputPrice override works in token billing mode
- Auto-fill image_output_price in channel pricing form from model defaults
- Add "channel_mapped" billing model source as new default (migration 088)
- Bills by model name after channel mapping, before account mapping
- Fix sign error in the channel cache error TTL (115s → 5s)
- Fix Update channel only invalidating new groups, not removed groups
- Fix frontend sending undefined instead of {} when model_mapping is cleared
- Credits balance precheck via shared AccountUsageService cache before injection
- Skip credits injection for accounts with insufficient balance
- Don't mark credits exhausted for "exhausted your capacity on this model" 429s
parent 2555951b
...@@ -139,11 +139,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -139,11 +139,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient) internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client) tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache, accountUsageService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
......
...@@ -31,7 +31,7 @@ type createChannelRequest struct { ...@@ -31,7 +31,7 @@ type createChannelRequest struct {
GroupIDs []int64 `json:"group_ids"` GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"` ModelPricing []channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"` ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"` BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels bool `json:"restrict_models"` RestrictModels bool `json:"restrict_models"`
} }
...@@ -42,7 +42,7 @@ type updateChannelRequest struct { ...@@ -42,7 +42,7 @@ type updateChannelRequest struct {
GroupIDs *[]int64 `json:"group_ids"` GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"` ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"` BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels *bool `json:"restrict_models"` RestrictModels *bool `json:"restrict_models"`
} }
...@@ -129,7 +129,7 @@ func channelToResponse(ch *service.Channel) *channelResponse { ...@@ -129,7 +129,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
} }
resp.BillingModelSource = ch.BillingModelSource resp.BillingModelSource = ch.BillingModelSource
if resp.BillingModelSource == "" { if resp.BillingModelSource == "" {
resp.BillingModelSource = "requested" resp.BillingModelSource = "channel_mapped"
} }
if resp.GroupIDs == nil { if resp.GroupIDs == nil {
resp.GroupIDs = []int64{} resp.GroupIDs = []int64{}
...@@ -393,5 +393,6 @@ func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) { ...@@ -393,5 +393,6 @@ func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
"output_price": pricing.OutputPricePerToken, "output_price": pricing.OutputPricePerToken,
"cache_write_price": pricing.CacheCreationPricePerToken, "cache_write_price": pricing.CacheCreationPricePerToken,
"cache_read_price": pricing.CacheReadPricePerToken, "cache_read_price": pricing.CacheReadPricePerToken,
"image_output_price": pricing.ImageOutputPricePerToken,
}) })
} }
...@@ -106,7 +106,7 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) { ...@@ -106,7 +106,7 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
} }
resp := channelToResponse(ch) resp := channelToResponse(ch)
require.Equal(t, "requested", resp.BillingModelSource) require.Equal(t, "channel_mapped", resp.BillingModelSource)
require.NotNil(t, resp.GroupIDs) require.NotNil(t, resp.GroupIDs)
require.Empty(t, resp.GroupIDs) require.Empty(t, resp.GroupIDs)
require.NotNil(t, resp.ModelMapping) require.NotNil(t, resp.ModelMapping)
......
...@@ -125,6 +125,7 @@ type ClaudeUsage struct { ...@@ -125,6 +125,7 @@ type ClaudeUsage struct {
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
} }
// ClaudeError Claude 错误响应 // ClaudeError Claude 错误响应
......
...@@ -149,6 +149,12 @@ type GeminiCandidate struct { ...@@ -149,6 +149,12 @@ type GeminiCandidate struct {
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"` GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
} }
// GeminiTokenDetail is a single per-modality entry of a Gemini token usage
// breakdown (tokens classified by modality).
type GeminiTokenDetail struct {
// Modality is the modality label decoded from the "modality" JSON key
// (for example "IMAGE").
Modality string `json:"modality"`
// TokenCount is the number of tokens decoded from the "tokenCount" JSON key
// for this modality.
TokenCount int `json:"tokenCount"`
}
// GeminiUsageMetadata Gemini 用量元数据 // GeminiUsageMetadata Gemini 用量元数据
type GeminiUsageMetadata struct { type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount,omitempty"` PromptTokenCount int `json:"promptTokenCount,omitempty"`
...@@ -156,6 +162,18 @@ type GeminiUsageMetadata struct { ...@@ -156,6 +162,18 @@ type GeminiUsageMetadata struct {
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"` TotalTokenCount int `json:"totalTokenCount,omitempty"`
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费) ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费)
CandidatesTokensDetails []GeminiTokenDetail `json:"candidatesTokensDetails,omitempty"`
PromptTokensDetails []GeminiTokenDetail `json:"promptTokensDetails,omitempty"`
}
// ImageOutputTokens 从 CandidatesTokensDetails 中提取 IMAGE 模态的 token 数
func (m *GeminiUsageMetadata) ImageOutputTokens() int {
for _, d := range m.CandidatesTokensDetails {
if d.Modality == "IMAGE" {
return d.TokenCount
}
}
return 0
} }
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search) // GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
......
...@@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon ...@@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached usage.CacheReadInputTokens = cached
usage.ImageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
} }
// 生成响应 ID // 生成响应 ID
......
...@@ -35,6 +35,7 @@ type StreamingProcessor struct { ...@@ -35,6 +35,7 @@ type StreamingProcessor struct {
inputTokens int inputTokens int
outputTokens int outputTokens int
cacheReadTokens int cacheReadTokens int
imageOutputTokens int
} }
// NewStreamingProcessor 创建流式响应处理器 // NewStreamingProcessor 创建流式响应处理器
...@@ -45,6 +46,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor { ...@@ -45,6 +46,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
} }
} }
// SetUsageMapHook installs fn as the processor's usage-map hook by storing it
// in p.usageMapHook, replacing whatever value was set before.
//
// The hook is optional and is intended to modify usage maps before they are
// emitted; where the processor invokes it is not visible in this section.
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
p.usageMapHook = fn
}
// usageToMap converts a ClaudeUsage into a generic map keyed by the usage
// field names.
//
// "input_tokens" and "output_tokens" are always present, even when zero.
// The cache-creation, cache-read and image-output counters are added only
// when they are strictly positive.
func usageToMap(u ClaudeUsage) map[string]any {
	out := map[string]any{}
	out["input_tokens"] = u.InputTokens
	out["output_tokens"] = u.OutputTokens

	// Optional counters: emitted only when positive.
	optional := [...]struct {
		key   string
		count int
	}{
		{"cache_creation_input_tokens", u.CacheCreationInputTokens},
		{"cache_read_input_tokens", u.CacheReadInputTokens},
		{"image_output_tokens", u.ImageOutputTokens},
	}
	for _, field := range optional {
		if field.count > 0 {
			out[field.key] = field.count
		}
	}
	return out
}
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件 // ProcessLine 处理 SSE 行,返回 Claude SSE 事件
func (p *StreamingProcessor) ProcessLine(line string) []byte { func (p *StreamingProcessor) ProcessLine(line string) []byte {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
...@@ -87,6 +110,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { ...@@ -87,6 +110,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
p.cacheReadTokens = cached p.cacheReadTokens = cached
p.imageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
} }
// 处理 parts // 处理 parts
...@@ -127,6 +151,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) { ...@@ -127,6 +151,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
InputTokens: p.inputTokens, InputTokens: p.inputTokens,
OutputTokens: p.outputTokens, OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens, CacheReadInputTokens: p.cacheReadTokens,
ImageOutputTokens: p.imageOutputTokens,
} }
if !p.messageStartSent { if !p.messageStartSent {
...@@ -158,6 +183,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte ...@@ -158,6 +183,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached usage.CacheReadInputTokens = cached
usage.ImageOutputTokens = v1Resp.Response.UsageMetadata.ImageOutputTokens()
} }
responseID := v1Resp.ResponseID responseID := v1Resp.ResponseID
...@@ -485,6 +511,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { ...@@ -485,6 +511,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
InputTokens: p.inputTokens, InputTokens: p.inputTokens,
OutputTokens: p.outputTokens, OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens, CacheReadInputTokens: p.cacheReadTokens,
ImageOutputTokens: p.imageOutputTokens,
} }
deltaEvent := map[string]any{ deltaEvent := map[string]any{
......
...@@ -28,7 +28,7 @@ import ( ...@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
// usageLogInsertArgTypes must stay in the same order as: // usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args // 1. prepareUsageLogInsert().args
...@@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{ ...@@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{
"integer", // cache_read_tokens "integer", // cache_read_tokens
"integer", // cache_creation_5m_tokens "integer", // cache_creation_5m_tokens
"integer", // cache_creation_1h_tokens "integer", // cache_creation_1h_tokens
"integer", // image_output_tokens
"numeric", // image_output_cost
"numeric", // input_cost "numeric", // input_cost
"numeric", // output_cost "numeric", // output_cost
"numeric", // cache_creation_cost "numeric", // cache_creation_cost
...@@ -330,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -330,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -363,9 +367,9 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -363,9 +367,9 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$1, $2, $3, $4, $5, $6, $7, $1, $2, $3, $4, $5, $6, $7,
$8, $9, $8, $9,
$10, $11, $12, $13, $10, $11, $12, $13,
$14, $15, $14, $15, $16, $17,
$16, $17, $18, $19, $20, $21, $18, $19, $20, $21, $22, $23,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44 $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -766,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -766,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -797,7 +803,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -797,7 +803,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(keys)*45) args := make([]any, 0, len(keys)*47)
argPos := 1 argPos := 1
for idx, key := range keys { for idx, key := range keys {
if idx > 0 { if idx > 0 {
...@@ -841,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -841,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -887,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -887,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -973,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -973,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -1004,7 +1016,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1004,7 +1016,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(preparedList)*44) args := make([]any, 0, len(preparedList)*46)
argPos := 1 argPos := 1
for idx, prepared := range preparedList { for idx, prepared := range preparedList {
if idx > 0 { if idx > 0 {
...@@ -1045,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1045,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -1091,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1091,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -1145,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1145,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -1178,9 +1196,9 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1178,9 +1196,9 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$1, $2, $3, $4, $5, $6, $7, $1, $2, $3, $4, $5, $6, $7,
$8, $9, $8, $9,
$10, $11, $12, $13, $10, $11, $12, $13,
$14, $15, $14, $15, $16, $17,
$16, $17, $18, $19, $20, $21, $18, $19, $20, $21, $22, $23,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44 $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...) `, prepared.args...)
...@@ -1248,6 +1266,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1248,6 +1266,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.CacheReadTokens, log.CacheReadTokens,
log.CacheCreation5mTokens, log.CacheCreation5mTokens,
log.CacheCreation1hTokens, log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost, log.InputCost,
log.OutputCost, log.OutputCost,
log.CacheCreationCost, log.CacheCreationCost,
...@@ -4011,6 +4031,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -4011,6 +4031,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
cacheReadTokens int cacheReadTokens int
cacheCreation5m int cacheCreation5m int
cacheCreation1h int cacheCreation1h int
imageOutputTokens int
imageOutputCost float64
inputCost float64 inputCost float64
outputCost float64 outputCost float64
cacheCreationCost float64 cacheCreationCost float64
...@@ -4059,6 +4081,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -4059,6 +4081,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&cacheReadTokens, &cacheReadTokens,
&cacheCreation5m, &cacheCreation5m,
&cacheCreation1h, &cacheCreation1h,
&imageOutputTokens,
&imageOutputCost,
&inputCost, &inputCost,
&outputCost, &outputCost,
&cacheCreationCost, &cacheCreationCost,
...@@ -4105,6 +4129,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -4105,6 +4129,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
CacheReadTokens: cacheReadTokens, CacheReadTokens: cacheReadTokens,
CacheCreation5mTokens: cacheCreation5m, CacheCreation5mTokens: cacheCreation5m,
CacheCreation1hTokens: cacheCreation1h, CacheCreation1hTokens: cacheCreation1h,
ImageOutputTokens: imageOutputTokens,
ImageOutputCost: imageOutputCost,
InputCost: inputCost, InputCost: inputCost,
OutputCost: outputCost, OutputCost: outputCost,
CacheCreationCost: cacheCreationCost, CacheCreationCost: cacheCreationCost,
......
...@@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { ...@@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.CacheReadTokens, log.CacheReadTokens,
log.CacheCreation5mTokens, log.CacheCreation5mTokens,
log.CacheCreation1hTokens, log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost, log.InputCost,
log.OutputCost, log.OutputCost,
log.CacheCreationCost, log.CacheCreationCost,
...@@ -133,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { ...@@ -133,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.CacheReadTokens, log.CacheReadTokens,
log.CacheCreation5mTokens, log.CacheCreation5mTokens,
log.CacheCreation1hTokens, log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost, log.InputCost,
log.OutputCost, log.OutputCost,
log.CacheCreationCost, log.CacheCreationCost,
...@@ -447,6 +451,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -447,6 +451,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
4, // cache_read_tokens 4, // cache_read_tokens
5, // cache_creation_5m_tokens 5, // cache_creation_5m_tokens
6, // cache_creation_1h_tokens 6, // cache_creation_1h_tokens
0, // image_output_tokens
0.0, // image_output_cost
0.1, // input_cost 0.1, // input_cost
0.2, // output_cost 0.2, // output_cost
0.3, // cache_creation_cost 0.3, // cache_creation_cost
...@@ -499,6 +505,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -499,6 +505,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9, 0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0, 1.0,
sql.NullFloat64{}, sql.NullFloat64{},
...@@ -546,6 +553,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -546,6 +553,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9, 0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0, 1.0,
sql.NullFloat64{}, sql.NullFloat64{},
......
...@@ -846,6 +846,15 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account * ...@@ -846,6 +846,15 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
return usage, nil return usage, nil
} }
// GetAntigravityCredits returns the usage information (which carries the AI
// Credits data) for an Antigravity account by delegating to
// getAntigravityUsage.
//
// It returns (nil, nil) when account is nil or its platform is not
// PlatformAntigravity. Any caching or refresh behavior is that of
// getAntigravityUsage, which is not visible in this section.
func (s *AccountUsageService) GetAntigravityCredits(ctx context.Context, account *Account) (*UsageInfo, error) {
	if account != nil && account.Platform == PlatformAntigravity {
		return s.getAntigravityUsage(ctx, account)
	}
	return nil, nil
}
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds // recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数 // 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
func recalcAntigravityRemainingSeconds(info *UsageInfo) { func recalcAntigravityRemainingSeconds(info *UsageInfo) {
......
...@@ -19,6 +19,54 @@ const ( ...@@ -19,6 +19,54 @@ const (
creditsExhaustedDuration = 5 * time.Hour creditsExhaustedDuration = 5 * time.Hour
) )
// checkAccountCredits reports whether the account has enough AI Credits,
// using the credits data obtained through the shared AccountUsageService.
//
// Results:
//   - false when account is nil or has a zero ID;
//   - true when no usage service is configured, or when fetching the credits
//     fails (the check is best-effort and must not block the request);
//   - false when no credits data is returned or no "GOOGLE_ONE_AI" entry exists;
//   - otherwise whether that entry's Amount reaches its MinimumBalance, with a
//     fallback threshold of 5 when MinimumBalance is not positive.
//
// accessToken and proxyURL are accepted but not used by this body.
func (s *AntigravityGatewayService) checkAccountCredits(
	ctx context.Context, account *Account, accessToken, proxyURL string,
) bool {
	switch {
	case account == nil || account.ID == 0:
		return false
	case s.accountUsageService == nil:
		// Without a usage service there is nothing to check; do not block.
		return true
	}

	info, fetchErr := s.accountUsageService.GetAntigravityCredits(ctx, account)
	if fetchErr != nil {
		logger.LegacyPrintf("service.antigravity_gateway",
			"check_credits: get_credits_failed account=%d err=%v", account.ID, fetchErr)
		// On error assume credits are available rather than blocking.
		return true
	}

	if info == nil || len(info.AICredits) == 0 {
		logger.LegacyPrintf("service.antigravity_gateway",
			"check_credits: account=%d has_credits=false amount=0 (no credits field)",
			account.ID)
		return false
	}

	// Only the first GOOGLE_ONE_AI entry decides the outcome.
	for _, entry := range info.AICredits {
		if entry.CreditType != "GOOGLE_ONE_AI" {
			continue
		}
		threshold := entry.MinimumBalance
		if threshold <= 0 {
			threshold = 5
		}
		sufficient := entry.Amount >= threshold
		logger.LegacyPrintf("service.antigravity_gateway",
			"check_credits: account=%d has_credits=%t amount=%.0f minimum=%.0f",
			account.ID, sufficient, entry.Amount, threshold)
		return sufficient
	}

	logger.LegacyPrintf("service.antigravity_gateway",
		"check_credits: account=%d has_credits=false (no GOOGLE_ONE_AI credit)",
		account.ID)
	return false
}
type antigravity429Category string type antigravity429Category string
const ( const (
...@@ -141,6 +189,13 @@ func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstr ...@@ -141,6 +189,13 @@ func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstr
} }
// shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。 // shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。
// 此函数在积分注入后失败时调用(预检查注入 + attemptCreditsOveragesRetry 两条路径)。
// - 429 + 非单模型限流:积分注入后仍 429 → 标记耗尽。
// - 429 + 单模型限流("exhausted your capacity on this model"):该模型免费配额用完,
// 积分注入对此无效,但账号积分对其他模型可能仍可用 → 不标记积分耗尽。
// - 403 等其他 4xx:检查 body 是否包含积分不足的关键词。
//
// clearCreditsExhausted 会在后续成功时自动清除。
func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool { func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool {
if reqErr != nil || resp == nil { if reqErr != nil || resp == nil {
return false return false
...@@ -148,13 +203,16 @@ func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr err ...@@ -148,13 +203,16 @@ func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr err
if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout { if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout {
return false return false
} }
// 注意:不再检查 isURLLevelRateLimit。此函数仅在积分重试失败后调用, bodyLower := strings.ToLower(string(respBody))
// 如果注入 enabledCreditTypes 后仍返回 "Resource has been exhausted", // 积分注入后仍 429
// 说明积分也已耗尽,应该标记。clearCreditsExhausted 会在后续成功时自动清除。 if resp.StatusCode == http.StatusTooManyRequests {
if info := parseAntigravitySmartRetryInfo(respBody); info != nil { // 单模型配额耗尽:积分注入对此无效,不标记整个账号积分耗尽
if strings.Contains(bodyLower, "exhausted your capacity on this model") {
return false return false
} }
bodyLower := strings.ToLower(string(respBody)) return true
}
// 其他 4xx:关键词匹配(如 403 + "Insufficient credits")
for _, keyword := range creditsExhaustedKeywords { for _, keyword := range creditsExhaustedKeywords {
if strings.Contains(bodyLower, keyword) { if strings.Contains(bodyLower, keyword) {
return true return true
...@@ -181,6 +239,16 @@ func (s *AntigravityGatewayService) attemptCreditsOveragesRetry( ...@@ -181,6 +239,16 @@ func (s *AntigravityGatewayService) attemptCreditsOveragesRetry(
if creditsBody == nil { if creditsBody == nil {
return &creditsOveragesRetryResult{handled: false} return &creditsOveragesRetryResult{handled: false}
} }
// Check actual credits balance before attempting retry
if !s.checkAccountCredits(p.ctx, p.account, p.accessToken, p.proxyURL) {
s.setCreditsExhausted(p.ctx, p.account)
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_no_credits model=%s account=%d (skipping credits retry)",
p.prefix, modelKey, p.account.ID)
return &creditsOveragesRetryResult{handled: true}
}
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel) modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 credit_overages_retry model=%s account=%d (injecting enabledCreditTypes)", logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 credit_overages_retry model=%s account=%d (injecting enabledCreditTypes)",
p.prefix, modelKey, p.account.ID) p.prefix, modelKey, p.account.ID)
......
...@@ -418,7 +418,13 @@ func TestShouldMarkCreditsExhausted(t *testing.T) { ...@@ -418,7 +418,13 @@ func TestShouldMarkCreditsExhausted(t *testing.T) {
require.True(t, shouldMarkCreditsExhausted(resp, body, nil)) require.True(t, shouldMarkCreditsExhausted(resp, body, nil))
}) })
t.Run("结构化限流不标记", func(t *testing.T) { t.Run("单模型配额耗尽不标记(积分对此无效)", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
body := []byte(`{"error":{"code":429,"message":"You have exhausted your capacity on this model. Your quota will reset after 146h11m17s.","status":"RESOURCE_EXHAUSTED"}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
})
t.Run("429 结构化限流也标记(积分注入后仍 429 即为耗尽)", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusTooManyRequests} resp := &http.Response{StatusCode: http.StatusTooManyRequests}
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`) body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil)) require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
......
...@@ -557,7 +557,13 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP ...@@ -557,7 +557,13 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
if p.requestedModel != "" && p.account.Platform == PlatformAntigravity && if p.requestedModel != "" && p.account.Platform == PlatformAntigravity &&
p.account.IsOveragesEnabled() && !p.account.isCreditsExhausted() && p.account.IsOveragesEnabled() && !p.account.isCreditsExhausted() &&
p.account.isModelRateLimitedWithContext(p.ctx, p.requestedModel) { p.account.isModelRateLimitedWithContext(p.ctx, p.requestedModel) {
if creditsBody := injectEnabledCreditTypes(p.body); creditsBody != nil { // Check actual credits balance before injection
if !s.checkAccountCredits(p.ctx, p.account, p.accessToken, p.proxyURL) {
// No credits available - mark as exhausted and skip injection
s.setCreditsExhausted(p.ctx, p.account)
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: no_credits_available account=%d (skipping credits injection)",
p.prefix, p.account.ID)
} else if creditsBody := injectEnabledCreditTypes(p.body); creditsBody != nil {
p.body = creditsBody p.body = creditsBody
overagesInjected = true overagesInjected = true
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: model_rate_limited_credits_inject model=%s account=%d (injecting enabledCreditTypes)", logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: model_rate_limited_credits_inject model=%s account=%d (injecting enabledCreditTypes)",
...@@ -878,6 +884,7 @@ type AntigravityGatewayService struct { ...@@ -878,6 +884,7 @@ type AntigravityGatewayService struct {
cache GatewayCache // 用于模型级限流时清除粘性会话绑定 cache GatewayCache // 用于模型级限流时清除粘性会话绑定
schedulerSnapshot *SchedulerSnapshotService schedulerSnapshot *SchedulerSnapshotService
internal500Cache Internal500CounterCache // INTERNAL 500 渐进惩罚计数器 internal500Cache Internal500CounterCache // INTERNAL 500 渐进惩罚计数器
accountUsageService *AccountUsageService // 共享 usage 缓存,用于积分余额检查
} }
func NewAntigravityGatewayService( func NewAntigravityGatewayService(
...@@ -889,6 +896,7 @@ func NewAntigravityGatewayService( ...@@ -889,6 +896,7 @@ func NewAntigravityGatewayService(
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
settingService *SettingService, settingService *SettingService,
internal500Cache Internal500CounterCache, internal500Cache Internal500CounterCache,
accountUsageService *AccountUsageService,
) *AntigravityGatewayService { ) *AntigravityGatewayService {
return &AntigravityGatewayService{ return &AntigravityGatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -899,6 +907,7 @@ func NewAntigravityGatewayService( ...@@ -899,6 +907,7 @@ func NewAntigravityGatewayService(
cache: cache, cache: cache,
schedulerSnapshot: schedulerSnapshot, schedulerSnapshot: schedulerSnapshot,
internal500Cache: internal500Cache, internal500Cache: internal500Cache,
accountUsageService: accountUsageService,
} }
} }
......
...@@ -56,6 +56,7 @@ type ModelPricing struct { ...@@ -56,6 +56,7 @@ type ModelPricing struct {
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格 LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率 LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率 LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
ImageOutputPricePerToken float64 // 图片输出 token 价格 (USD)
} }
const ( const (
...@@ -94,12 +95,14 @@ type UsageTokens struct { ...@@ -94,12 +95,14 @@ type UsageTokens struct {
CacheReadTokens int CacheReadTokens int
CacheCreation5mTokens int CacheCreation5mTokens int
CacheCreation1hTokens int CacheCreation1hTokens int
ImageOutputTokens int
} }
// CostBreakdown 费用明细 // CostBreakdown 费用明细
type CostBreakdown struct { type CostBreakdown struct {
InputCost float64 InputCost float64
OutputCost float64 OutputCost float64
ImageOutputCost float64
CacheCreationCost float64 CacheCreationCost float64
CacheReadCost float64 CacheReadCost float64
TotalCost float64 TotalCost float64
...@@ -358,6 +361,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { ...@@ -358,6 +361,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold, LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier, LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier, LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
ImageOutputPricePerToken: litellmPricing.OutputCostPerImageToken,
}), nil }), nil
} }
} }
...@@ -399,6 +403,9 @@ func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing ...@@ -399,6 +403,9 @@ func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing
pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice
pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice
} }
if channelPricing.ImageOutputPrice != nil {
pricing.ImageOutputPricePerToken = *channelPricing.ImageOutputPrice
}
return pricing, nil return pricing, nil
} }
...@@ -489,7 +496,22 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos ...@@ -489,7 +496,22 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos
} }
breakdown.InputCost = float64(input.Tokens.InputTokens) * inputPricePerToken breakdown.InputCost = float64(input.Tokens.InputTokens) * inputPricePerToken
breakdown.OutputCost = float64(input.Tokens.OutputTokens) * outputPricePerToken
// Separate image output tokens from text output tokens
textOutputTokens := input.Tokens.OutputTokens - input.Tokens.ImageOutputTokens
if textOutputTokens < 0 {
textOutputTokens = 0
}
breakdown.OutputCost = float64(textOutputTokens) * outputPricePerToken
// Image output tokens cost (separate rate from text output)
if input.Tokens.ImageOutputTokens > 0 {
imageOutputPrice := pricing.ImageOutputPricePerToken
if imageOutputPrice == 0 {
imageOutputPrice = outputPricePerToken // fallback to regular output price
}
breakdown.ImageOutputCost = float64(input.Tokens.ImageOutputTokens) * imageOutputPrice
}
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
if input.Tokens.CacheCreation5mTokens == 0 && input.Tokens.CacheCreation1hTokens == 0 && input.Tokens.CacheCreationTokens > 0 { if input.Tokens.CacheCreation5mTokens == 0 && input.Tokens.CacheCreation1hTokens == 0 && input.Tokens.CacheCreationTokens > 0 {
...@@ -507,11 +529,12 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos ...@@ -507,11 +529,12 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos
if tierMultiplier != 1.0 { if tierMultiplier != 1.0 {
breakdown.InputCost *= tierMultiplier breakdown.InputCost *= tierMultiplier
breakdown.OutputCost *= tierMultiplier breakdown.OutputCost *= tierMultiplier
breakdown.ImageOutputCost *= tierMultiplier
breakdown.CacheCreationCost *= tierMultiplier breakdown.CacheCreationCost *= tierMultiplier
breakdown.CacheReadCost *= tierMultiplier breakdown.CacheReadCost *= tierMultiplier
} }
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.ImageOutputCost +
breakdown.CacheCreationCost + breakdown.CacheReadCost breakdown.CacheCreationCost + breakdown.CacheReadCost
breakdown.ActualCost = breakdown.TotalCost * input.RateMultiplier breakdown.ActualCost = breakdown.TotalCost * input.RateMultiplier
...@@ -597,8 +620,21 @@ func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens, ...@@ -597,8 +620,21 @@ func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens,
// 计算输入token费用(使用per-token价格) // 计算输入token费用(使用per-token价格)
breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken
// 计算输出token费用 // 计算输出token费用(分离图片输出token)
breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken textOutputTokens := tokens.OutputTokens - tokens.ImageOutputTokens
if textOutputTokens < 0 {
textOutputTokens = 0
}
breakdown.OutputCost = float64(textOutputTokens) * outputPricePerToken
// 图片输出 token 费用
if tokens.ImageOutputTokens > 0 {
imageOutputPrice := pricing.ImageOutputPricePerToken
if imageOutputPrice == 0 {
imageOutputPrice = outputPricePerToken
}
breakdown.ImageOutputCost = float64(tokens.ImageOutputTokens) * imageOutputPrice
}
// 计算缓存费用 // 计算缓存费用
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
...@@ -620,12 +656,13 @@ func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens, ...@@ -620,12 +656,13 @@ func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens,
if tierMultiplier != 1.0 { if tierMultiplier != 1.0 {
breakdown.InputCost *= tierMultiplier breakdown.InputCost *= tierMultiplier
breakdown.OutputCost *= tierMultiplier breakdown.OutputCost *= tierMultiplier
breakdown.ImageOutputCost *= tierMultiplier
breakdown.CacheCreationCost *= tierMultiplier breakdown.CacheCreationCost *= tierMultiplier
breakdown.CacheReadCost *= tierMultiplier breakdown.CacheReadCost *= tierMultiplier
} }
// 计算总费用 // 计算总费用
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.ImageOutputCost +
breakdown.CacheCreationCost + breakdown.CacheReadCost breakdown.CacheCreationCost + breakdown.CacheReadCost
// 应用倍率计算实际费用 // 应用倍率计算实际费用
...@@ -730,6 +767,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage ...@@ -730,6 +767,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
CacheReadTokens: inRangeCacheTokens, CacheReadTokens: inRangeCacheTokens,
CacheCreation5mTokens: tokens.CacheCreation5mTokens, CacheCreation5mTokens: tokens.CacheCreation5mTokens,
CacheCreation1hTokens: tokens.CacheCreation1hTokens, CacheCreation1hTokens: tokens.CacheCreation1hTokens,
ImageOutputTokens: tokens.ImageOutputTokens,
} }
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
if err != nil { if err != nil {
...@@ -750,6 +788,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage ...@@ -750,6 +788,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
return &CostBreakdown{ return &CostBreakdown{
InputCost: inRangeCost.InputCost + outRangeCost.InputCost, InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
OutputCost: inRangeCost.OutputCost, OutputCost: inRangeCost.OutputCost,
ImageOutputCost: inRangeCost.ImageOutputCost,
CacheCreationCost: inRangeCost.CacheCreationCost, CacheCreationCost: inRangeCost.CacheCreationCost,
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost, CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost, TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
......
...@@ -26,6 +26,7 @@ func (m BillingMode) IsValid() bool { ...@@ -26,6 +26,7 @@ func (m BillingMode) IsValid() bool {
const ( const (
BillingModelSourceRequested = "requested" BillingModelSourceRequested = "requested"
BillingModelSourceUpstream = "upstream" BillingModelSourceUpstream = "upstream"
BillingModelSourceChannelMapped = "channel_mapped"
) )
// Channel 渠道实体 // Channel 渠道实体
...@@ -34,7 +35,7 @@ type Channel struct { ...@@ -34,7 +35,7 @@ type Channel struct {
Name string Name string
Description string Description string
Status string Status string
BillingModelSource string // "requested" or "upstream" BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
...@@ -180,6 +181,7 @@ func (c *Channel) Clone() *Channel { ...@@ -180,6 +181,7 @@ func (c *Channel) Clone() *Channel {
type ChannelUsageFields struct { type ChannelUsageFields struct {
ChannelID int64 // 渠道 ID(0 = 无渠道) ChannelID int64 // 渠道 ID(0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前) OriginalModel string // 用户原始请求模型(渠道映射前)
BillingModelSource string // 计费模型来源:"requested" / "upstream" ChannelMappedModel string // 渠道映射后的模型名(无映射时等于 OriginalModel)
BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped"
ModelMappingChain string // 映射链描述,如 "a→b→c" ModelMappingChain string // 映射链描述,如 "a→b→c"
} }
...@@ -97,7 +97,7 @@ type ChannelMappingResult struct { ...@@ -97,7 +97,7 @@ type ChannelMappingResult struct {
MappedModel string // 映射后的模型名(无映射时等于原始模型名) MappedModel string // 映射后的模型名(无映射时等于原始模型名)
ChannelID int64 // 渠道 ID(0 = 无渠道关联) ChannelID int64 // 渠道 ID(0 = 无渠道关联)
Mapped bool // 是否发生了映射 Mapped bool // 是否发生了映射
BillingModelSource string // 计费模型来源("requested" / "upstream") BillingModelSource string // 计费模型来源("requested" / "upstream" / "channel_mapped"
} }
// BuildModelMappingChain 根据映射结果和上游实际模型构建映射链描述。 // BuildModelMappingChain 根据映射结果和上游实际模型构建映射链描述。
...@@ -119,9 +119,14 @@ func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel str ...@@ -119,9 +119,14 @@ func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel str
// ToUsageFields 将渠道映射结果转为使用记录字段 // ToUsageFields 将渠道映射结果转为使用记录字段
func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields { func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields {
channelMappedModel := reqModel
if r.Mapped {
channelMappedModel = r.MappedModel
}
return ChannelUsageFields{ return ChannelUsageFields{
ChannelID: r.ChannelID, ChannelID: r.ChannelID,
OriginalModel: reqModel, OriginalModel: reqModel,
ChannelMappedModel: channelMappedModel,
BillingModelSource: r.BillingModelSource, BillingModelSource: r.BillingModelSource,
ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel), ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel),
} }
...@@ -193,7 +198,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -193,7 +198,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
channelByGroupID: make(map[int64]*Channel), channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string), groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel), byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL loadedAt: time.Now().Add(-(channelCacheTTL - channelErrorTTL)), // 使剩余 TTL = errorTTL
} }
s.cache.Store(errorCache) s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err) return nil, fmt.Errorf("list all channels: %w", err)
...@@ -374,7 +379,7 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6 ...@@ -374,7 +379,7 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
BillingModelSource: ch.BillingModelSource, BillingModelSource: ch.BillingModelSource,
} }
if result.BillingModelSource == "" { if result.BillingModelSource == "" {
result.BillingModelSource = BillingModelSourceRequested result.BillingModelSource = BillingModelSourceChannelMapped
} }
platform := cache.groupPlatform[groupID] platform := cache.groupPlatform[groupID]
...@@ -481,7 +486,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) ...@@ -481,7 +486,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
ModelMapping: input.ModelMapping, ModelMapping: input.ModelMapping,
} }
if channel.BillingModelSource == "" { if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceRequested channel.BillingModelSource = BillingModelSourceChannelMapped
} }
if err := validateNoConflictingModels(channel.ModelPricing); err != nil { if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
...@@ -565,22 +570,38 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan ...@@ -565,22 +570,38 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
return nil, err return nil, err
} }
// 先获取旧分组,Update 后旧分组关联已删除,无法再查到
var oldGroupIDs []int64
if s.authCacheInvalidator != nil {
var err2 error
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
if err2 != nil {
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
}
}
if err := s.repo.Update(ctx, channel); err != nil { if err := s.repo.Update(ctx, channel); err != nil {
return nil, fmt.Errorf("update channel: %w", err) return nil, fmt.Errorf("update channel: %w", err)
} }
s.invalidateCache() s.invalidateCache()
// 失效关联分组的 auth 缓存 // 失效新旧分组的 auth 缓存
if s.authCacheInvalidator != nil { if s.authCacheInvalidator != nil {
groupIDs, err := s.repo.GetGroupIDs(ctx, id) seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
if err != nil { for _, gid := range oldGroupIDs {
slog.Warn("failed to get group IDs for cache invalidation", "channel_id", id, "error", err) if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
} }
for _, gid := range groupIDs { }
for _, gid := range channel.GroupIDs {
if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
} }
} }
}
return s.repo.GetByID(ctx, id) return s.repo.GetByID(ctx, id)
} }
......
...@@ -196,7 +196,6 @@ func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChanne ...@@ -196,7 +196,6 @@ func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChanne
return NewChannelService(repo, auth) return NewChannelService(repo, auth)
} }
// makeStandardRepo returns a repo that serves one active channel with anthropic pricing // makeStandardRepo returns a repo that serves one active channel with anthropic pricing
// for group 1, with the given model pricing and model mapping. // for group 1, with the given model pricing and model mapping.
func makeStandardRepo(ch Channel, groupPlatforms map[int64]string) *mockChannelRepository { func makeStandardRepo(ch Channel, groupPlatforms map[int64]string) *mockChannelRepository {
...@@ -914,7 +913,7 @@ func TestResolveChannelMapping_DefaultBillingModelSource(t *testing.T) { ...@@ -914,7 +913,7 @@ func TestResolveChannelMapping_DefaultBillingModelSource(t *testing.T) {
svc := newTestChannelService(repo) svc := newTestChannelService(repo)
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4") result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
require.Equal(t, BillingModelSourceRequested, result.BillingModelSource) require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource)
} }
func TestResolveChannelMapping_UpstreamBillingModelSource(t *testing.T) { func TestResolveChannelMapping_UpstreamBillingModelSource(t *testing.T) {
...@@ -1451,11 +1450,11 @@ func TestCreate_DefaultBillingModelSource(t *testing.T) { ...@@ -1451,11 +1450,11 @@ func TestCreate_DefaultBillingModelSource(t *testing.T) {
result, err := svc.Create(context.Background(), &CreateChannelInput{ result, err := svc.Create(context.Background(), &CreateChannelInput{
Name: "new-channel", Name: "new-channel",
BillingModelSource: "", // empty, should default to "requested" BillingModelSource: "", // empty, should default to "channel_mapped"
}) })
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, BillingModelSourceRequested, result.BillingModelSource) require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource)
} }
func TestCreate_InvalidatesCache(t *testing.T) { func TestCreate_InvalidatesCache(t *testing.T) {
......
...@@ -483,6 +483,7 @@ type ClaudeUsage struct { ...@@ -483,6 +483,7 @@ type ClaudeUsage struct {
CacheReadInputTokens int `json:"cache_read_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"`
CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象) CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象)
CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象) CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象)
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
} }
// ForwardResult 转发结果 // ForwardResult 转发结果
...@@ -7729,6 +7730,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7729,6 +7730,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
var cost *CostBreakdown var cost *CostBreakdown
// 确定计费模型 // 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
billingModel = input.ChannelMappedModel
}
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel billingModel = input.OriginalModel
} }
...@@ -7777,6 +7781,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7777,6 +7781,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
} }
var err error var err error
if s.resolver != nil && apiKey.Group != nil { if s.resolver != nil && apiKey.Group != nil {
...@@ -7836,8 +7841,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7836,8 +7841,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
InputCost: cost.InputCost, InputCost: cost.InputCost,
OutputCost: cost.OutputCost, OutputCost: cost.OutputCost,
ImageOutputCost: cost.ImageOutputCost,
CacheCreationCost: cost.CacheCreationCost, CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost, CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost, TotalCost: cost.TotalCost,
...@@ -7976,6 +7983,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -7976,6 +7983,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
var cost *CostBreakdown var cost *CostBreakdown
// 确定计费模型 // 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
billingModel = input.ChannelMappedModel
}
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel billingModel = input.OriginalModel
} }
...@@ -8007,6 +8017,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -8007,6 +8017,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
} }
var err error var err error
// 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用) // 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用)
...@@ -8073,8 +8084,10 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -8073,8 +8084,10 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
InputCost: cost.InputCost, InputCost: cost.InputCost,
OutputCost: cost.OutputCost, OutputCost: cost.OutputCost,
ImageOutputCost: cost.ImageOutputCost,
CacheCreationCost: cost.CacheCreationCost, CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost, CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost, TotalCost: cost.TotalCost,
......
...@@ -2692,12 +2692,27 @@ func extractGeminiUsage(data []byte) *ClaudeUsage { ...@@ -2692,12 +2692,27 @@ func extractGeminiUsage(data []byte) *ClaudeUsage {
cand := int(usage.Get("candidatesTokenCount").Int()) cand := int(usage.Get("candidatesTokenCount").Int())
cached := int(usage.Get("cachedContentTokenCount").Int()) cached := int(usage.Get("cachedContentTokenCount").Int())
thoughts := int(usage.Get("thoughtsTokenCount").Int()) thoughts := int(usage.Get("thoughtsTokenCount").Int())
// 从 candidatesTokensDetails 提取 IMAGE 模态 token 数
imageTokens := 0
candidateDetails := usage.Get("candidatesTokensDetails")
if candidateDetails.Exists() {
candidateDetails.ForEach(func(_, detail gjson.Result) bool {
if detail.Get("modality").String() == "IMAGE" {
imageTokens = int(detail.Get("tokenCount").Int())
return false
}
return true
})
}
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount, // 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去 // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
return &ClaudeUsage{ return &ClaudeUsage{
InputTokens: prompt - cached, InputTokens: prompt - cached,
OutputTokens: cand + thoughts, OutputTokens: cand + thoughts,
CacheReadInputTokens: cached, CacheReadInputTokens: cached,
ImageOutputTokens: imageTokens,
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment