Commit d72ac926 authored by erio's avatar erio
Browse files

feat: image output token billing, channel-mapped billing source, credits balance precheck

- Parse candidatesTokensDetails from Gemini API to separate image/text output tokens
- Add image_output_tokens and image_output_cost to usage_log (migration 089)
- Support per-image-token pricing via output_cost_per_image_token from model pricing data
- Channel pricing ImageOutputPrice override works in token billing mode
- Auto-fill image_output_price in channel pricing form from model defaults
- Add "channel_mapped" billing model source as new default (migration 088)
- Bills by model name after channel mapping, before account mapping
- Fix channel cache error TTL sign error (115s → 5s)
- Fix Update channel only invalidating new groups, not removed groups
- Fix frontend model_mapping clearing sending undefined instead of {}
- Credits balance precheck via shared AccountUsageService cache before injection
- Skip credits injection for accounts with insufficient balance
- Don't mark credits exhausted for "exhausted your capacity on this model" 429s
parent 2555951b
...@@ -139,11 +139,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -139,11 +139,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient) internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client) tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache, accountUsageService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
......
...@@ -31,7 +31,7 @@ type createChannelRequest struct { ...@@ -31,7 +31,7 @@ type createChannelRequest struct {
GroupIDs []int64 `json:"group_ids"` GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"` ModelPricing []channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"` ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"` BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels bool `json:"restrict_models"` RestrictModels bool `json:"restrict_models"`
} }
...@@ -42,7 +42,7 @@ type updateChannelRequest struct { ...@@ -42,7 +42,7 @@ type updateChannelRequest struct {
GroupIDs *[]int64 `json:"group_ids"` GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"` ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"` BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels *bool `json:"restrict_models"` RestrictModels *bool `json:"restrict_models"`
} }
...@@ -129,7 +129,7 @@ func channelToResponse(ch *service.Channel) *channelResponse { ...@@ -129,7 +129,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
} }
resp.BillingModelSource = ch.BillingModelSource resp.BillingModelSource = ch.BillingModelSource
if resp.BillingModelSource == "" { if resp.BillingModelSource == "" {
resp.BillingModelSource = "requested" resp.BillingModelSource = "channel_mapped"
} }
if resp.GroupIDs == nil { if resp.GroupIDs == nil {
resp.GroupIDs = []int64{} resp.GroupIDs = []int64{}
...@@ -388,10 +388,11 @@ func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) { ...@@ -388,10 +388,11 @@ func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
} }
response.Success(c, gin.H{ response.Success(c, gin.H{
"found": true, "found": true,
"input_price": pricing.InputPricePerToken, "input_price": pricing.InputPricePerToken,
"output_price": pricing.OutputPricePerToken, "output_price": pricing.OutputPricePerToken,
"cache_write_price": pricing.CacheCreationPricePerToken, "cache_write_price": pricing.CacheCreationPricePerToken,
"cache_read_price": pricing.CacheReadPricePerToken, "cache_read_price": pricing.CacheReadPricePerToken,
"image_output_price": pricing.ImageOutputPricePerToken,
}) })
} }
...@@ -36,7 +36,7 @@ func TestChannelToResponse_FullChannel(t *testing.T) { ...@@ -36,7 +36,7 @@ func TestChannelToResponse_FullChannel(t *testing.T) {
RestrictModels: true, RestrictModels: true,
CreatedAt: now, CreatedAt: now,
UpdatedAt: now.Add(time.Hour), UpdatedAt: now.Add(time.Hour),
GroupIDs: []int64{1, 2, 3}, GroupIDs: []int64{1, 2, 3},
ModelPricing: []service.ChannelModelPricing{ ModelPricing: []service.ChannelModelPricing{
{ {
ID: 10, ID: 10,
...@@ -94,8 +94,8 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) { ...@@ -94,8 +94,8 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
BillingModelSource: "", BillingModelSource: "",
CreatedAt: now, CreatedAt: now,
UpdatedAt: now, UpdatedAt: now,
GroupIDs: nil, GroupIDs: nil,
ModelMapping: nil, ModelMapping: nil,
ModelPricing: []service.ChannelModelPricing{ ModelPricing: []service.ChannelModelPricing{
{ {
Platform: "", Platform: "",
...@@ -106,7 +106,7 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) { ...@@ -106,7 +106,7 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
} }
resp := channelToResponse(ch) resp := channelToResponse(ch)
require.Equal(t, "requested", resp.BillingModelSource) require.Equal(t, "channel_mapped", resp.BillingModelSource)
require.NotNil(t, resp.GroupIDs) require.NotNil(t, resp.GroupIDs)
require.Empty(t, resp.GroupIDs) require.Empty(t, resp.GroupIDs)
require.NotNil(t, resp.ModelMapping) require.NotNil(t, resp.ModelMapping)
......
...@@ -125,6 +125,7 @@ type ClaudeUsage struct { ...@@ -125,6 +125,7 @@ type ClaudeUsage struct {
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
} }
// ClaudeError Claude 错误响应 // ClaudeError Claude 错误响应
......
...@@ -149,13 +149,31 @@ type GeminiCandidate struct { ...@@ -149,13 +149,31 @@ type GeminiCandidate struct {
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"` GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
} }
// GeminiTokenDetail Gemini token 详情(按模态分类)
type GeminiTokenDetail struct {
Modality string `json:"modality"`
TokenCount int `json:"tokenCount"`
}
// GeminiUsageMetadata Gemini 用量元数据 // GeminiUsageMetadata Gemini 用量元数据
type GeminiUsageMetadata struct { type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount,omitempty"` PromptTokenCount int `json:"promptTokenCount,omitempty"`
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"` TotalTokenCount int `json:"totalTokenCount,omitempty"`
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费) ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费)
CandidatesTokensDetails []GeminiTokenDetail `json:"candidatesTokensDetails,omitempty"`
PromptTokensDetails []GeminiTokenDetail `json:"promptTokensDetails,omitempty"`
}
// ImageOutputTokens 从 CandidatesTokensDetails 中提取 IMAGE 模态的 token 数
func (m *GeminiUsageMetadata) ImageOutputTokens() int {
for _, d := range m.CandidatesTokensDetails {
if d.Modality == "IMAGE" {
return d.TokenCount
}
}
return 0
} }
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search) // GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
......
...@@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon ...@@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached usage.CacheReadInputTokens = cached
usage.ImageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
} }
// 生成响应 ID // 生成响应 ID
......
...@@ -32,9 +32,10 @@ type StreamingProcessor struct { ...@@ -32,9 +32,10 @@ type StreamingProcessor struct {
groundingChunks []GeminiGroundingChunk groundingChunks []GeminiGroundingChunk
// 累计 usage // 累计 usage
inputTokens int inputTokens int
outputTokens int outputTokens int
cacheReadTokens int cacheReadTokens int
imageOutputTokens int
} }
// NewStreamingProcessor 创建流式响应处理器 // NewStreamingProcessor 创建流式响应处理器
...@@ -45,6 +46,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor { ...@@ -45,6 +46,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
} }
} }
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
p.usageMapHook = fn
}
func usageToMap(u ClaudeUsage) map[string]any {
m := map[string]any{
"input_tokens": u.InputTokens,
"output_tokens": u.OutputTokens,
}
if u.CacheCreationInputTokens > 0 {
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
}
if u.CacheReadInputTokens > 0 {
m["cache_read_input_tokens"] = u.CacheReadInputTokens
}
if u.ImageOutputTokens > 0 {
m["image_output_tokens"] = u.ImageOutputTokens
}
return m
}
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件 // ProcessLine 处理 SSE 行,返回 Claude SSE 事件
func (p *StreamingProcessor) ProcessLine(line string) []byte { func (p *StreamingProcessor) ProcessLine(line string) []byte {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
...@@ -87,6 +110,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { ...@@ -87,6 +110,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
p.cacheReadTokens = cached p.cacheReadTokens = cached
p.imageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
} }
// 处理 parts // 处理 parts
...@@ -127,6 +151,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) { ...@@ -127,6 +151,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
InputTokens: p.inputTokens, InputTokens: p.inputTokens,
OutputTokens: p.outputTokens, OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens, CacheReadInputTokens: p.cacheReadTokens,
ImageOutputTokens: p.imageOutputTokens,
} }
if !p.messageStartSent { if !p.messageStartSent {
...@@ -158,6 +183,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte ...@@ -158,6 +183,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached usage.CacheReadInputTokens = cached
usage.ImageOutputTokens = v1Resp.Response.UsageMetadata.ImageOutputTokens()
} }
responseID := v1Resp.ResponseID responseID := v1Resp.ResponseID
...@@ -485,6 +511,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { ...@@ -485,6 +511,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
InputTokens: p.inputTokens, InputTokens: p.inputTokens,
OutputTokens: p.outputTokens, OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens, CacheReadInputTokens: p.cacheReadTokens,
ImageOutputTokens: p.imageOutputTokens,
} }
deltaEvent := map[string]any{ deltaEvent := map[string]any{
......
...@@ -97,7 +97,7 @@ func TestUnmarshalModelMapping(t *testing.T) { ...@@ -97,7 +97,7 @@ func TestUnmarshalModelMapping(t *testing.T) {
wantNil: true, wantNil: true,
}, },
{ {
name: "valid JSON", name: "valid JSON",
input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`), input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`),
want: map[string]map[string]string{ want: map[string]map[string]string{
"openai": {"gpt-4": "gpt-4-turbo"}, "openai": {"gpt-4": "gpt-4-turbo"},
......
...@@ -28,7 +28,7 @@ import ( ...@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
// usageLogInsertArgTypes must stay in the same order as: // usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args // 1. prepareUsageLogInsert().args
...@@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{ ...@@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{
"integer", // cache_read_tokens "integer", // cache_read_tokens
"integer", // cache_creation_5m_tokens "integer", // cache_creation_5m_tokens
"integer", // cache_creation_1h_tokens "integer", // cache_creation_1h_tokens
"integer", // image_output_tokens
"numeric", // image_output_cost
"numeric", // input_cost "numeric", // input_cost
"numeric", // output_cost "numeric", // output_cost
"numeric", // cache_creation_cost "numeric", // cache_creation_cost
...@@ -330,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -330,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -363,9 +367,9 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -363,9 +367,9 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$1, $2, $3, $4, $5, $6, $7, $1, $2, $3, $4, $5, $6, $7,
$8, $9, $8, $9,
$10, $11, $12, $13, $10, $11, $12, $13,
$14, $15, $14, $15, $16, $17,
$16, $17, $18, $19, $20, $21, $18, $19, $20, $21, $22, $23,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44 $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -766,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -766,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -797,7 +803,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -797,7 +803,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(keys)*45) args := make([]any, 0, len(keys)*47)
argPos := 1 argPos := 1
for idx, key := range keys { for idx, key := range keys {
if idx > 0 { if idx > 0 {
...@@ -841,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -841,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -887,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -887,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -973,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -973,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -1004,7 +1016,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1004,7 +1016,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(preparedList)*44) args := make([]any, 0, len(preparedList)*46)
argPos := 1 argPos := 1
for idx, prepared := range preparedList { for idx, prepared := range preparedList {
if idx > 0 { if idx > 0 {
...@@ -1045,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1045,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -1091,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1091,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -1145,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1145,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -1178,9 +1196,9 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1178,9 +1196,9 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$1, $2, $3, $4, $5, $6, $7, $1, $2, $3, $4, $5, $6, $7,
$8, $9, $8, $9,
$10, $11, $12, $13, $10, $11, $12, $13,
$14, $15, $14, $15, $16, $17,
$16, $17, $18, $19, $20, $21, $18, $19, $20, $21, $22, $23,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44 $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...) `, prepared.args...)
...@@ -1248,6 +1266,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1248,6 +1266,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.CacheReadTokens, log.CacheReadTokens,
log.CacheCreation5mTokens, log.CacheCreation5mTokens,
log.CacheCreation1hTokens, log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost, log.InputCost,
log.OutputCost, log.OutputCost,
log.CacheCreationCost, log.CacheCreationCost,
...@@ -4011,6 +4031,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -4011,6 +4031,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
cacheReadTokens int cacheReadTokens int
cacheCreation5m int cacheCreation5m int
cacheCreation1h int cacheCreation1h int
imageOutputTokens int
imageOutputCost float64
inputCost float64 inputCost float64
outputCost float64 outputCost float64
cacheCreationCost float64 cacheCreationCost float64
...@@ -4059,6 +4081,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -4059,6 +4081,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&cacheReadTokens, &cacheReadTokens,
&cacheCreation5m, &cacheCreation5m,
&cacheCreation1h, &cacheCreation1h,
&imageOutputTokens,
&imageOutputCost,
&inputCost, &inputCost,
&outputCost, &outputCost,
&cacheCreationCost, &cacheCreationCost,
...@@ -4105,6 +4129,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -4105,6 +4129,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
CacheReadTokens: cacheReadTokens, CacheReadTokens: cacheReadTokens,
CacheCreation5mTokens: cacheCreation5m, CacheCreation5mTokens: cacheCreation5m,
CacheCreation1hTokens: cacheCreation1h, CacheCreation1hTokens: cacheCreation1h,
ImageOutputTokens: imageOutputTokens,
ImageOutputCost: imageOutputCost,
InputCost: inputCost, InputCost: inputCost,
OutputCost: outputCost, OutputCost: outputCost,
CacheCreationCost: cacheCreationCost, CacheCreationCost: cacheCreationCost,
......
...@@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { ...@@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.CacheReadTokens, log.CacheReadTokens,
log.CacheCreation5mTokens, log.CacheCreation5mTokens,
log.CacheCreation1hTokens, log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost, log.InputCost,
log.OutputCost, log.OutputCost,
log.CacheCreationCost, log.CacheCreationCost,
...@@ -133,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { ...@@ -133,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.CacheReadTokens, log.CacheReadTokens,
log.CacheCreation5mTokens, log.CacheCreation5mTokens,
log.CacheCreation1hTokens, log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost, log.InputCost,
log.OutputCost, log.OutputCost,
log.CacheCreationCost, log.CacheCreationCost,
...@@ -447,6 +451,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -447,6 +451,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
4, // cache_read_tokens 4, // cache_read_tokens
5, // cache_creation_5m_tokens 5, // cache_creation_5m_tokens
6, // cache_creation_1h_tokens 6, // cache_creation_1h_tokens
0, // image_output_tokens
0.0, // image_output_cost
0.1, // input_cost 0.1, // input_cost
0.2, // output_cost 0.2, // output_cost
0.3, // cache_creation_cost 0.3, // cache_creation_cost
...@@ -499,6 +505,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -499,6 +505,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9, 0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0, 1.0,
sql.NullFloat64{}, sql.NullFloat64{},
...@@ -546,6 +553,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -546,6 +553,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9, 0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0, 1.0,
sql.NullFloat64{}, sql.NullFloat64{},
......
...@@ -846,6 +846,15 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account * ...@@ -846,6 +846,15 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
return usage, nil return usage, nil
} }
// GetAntigravityCredits 返回账号的 AI Credits 信息,复用 getAntigravityUsage 的缓存。
// 如果缓存存在且 TTL 充足则直接返回;TTL 不足时自动刷新。
func (s *AccountUsageService) GetAntigravityCredits(ctx context.Context, account *Account) (*UsageInfo, error) {
if account == nil || account.Platform != PlatformAntigravity {
return nil, nil
}
return s.getAntigravityUsage(ctx, account)
}
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds // recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数 // 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
func recalcAntigravityRemainingSeconds(info *UsageInfo) { func recalcAntigravityRemainingSeconds(info *UsageInfo) {
......
...@@ -19,6 +19,54 @@ const ( ...@@ -19,6 +19,54 @@ const (
creditsExhaustedDuration = 5 * time.Hour creditsExhaustedDuration = 5 * time.Hour
) )
// checkAccountCredits 通过共享的 AccountUsageService 缓存检查账号是否有足够的 AI Credits。
// 缓存 TTL 不足时会自动从 Google loadCodeAssist API 刷新。
// 返回 true 表示积分可用。
func (s *AntigravityGatewayService) checkAccountCredits(
ctx context.Context, account *Account, accessToken, proxyURL string,
) bool {
if account == nil || account.ID == 0 {
return false
}
if s.accountUsageService == nil {
return true // 无 usage service 时不阻断
}
usageInfo, err := s.accountUsageService.GetAntigravityCredits(ctx, account)
if err != nil {
logger.LegacyPrintf("service.antigravity_gateway",
"check_credits: get_credits_failed account=%d err=%v", account.ID, err)
return true // 出错时假设有积分,不阻断
}
if usageInfo == nil || len(usageInfo.AICredits) == 0 {
logger.LegacyPrintf("service.antigravity_gateway",
"check_credits: account=%d has_credits=false amount=0 (no credits field)",
account.ID)
return false
}
for _, credit := range usageInfo.AICredits {
if credit.CreditType == "GOOGLE_ONE_AI" {
minimum := credit.MinimumBalance
if minimum <= 0 {
minimum = 5
}
hasCredits := credit.Amount >= minimum
logger.LegacyPrintf("service.antigravity_gateway",
"check_credits: account=%d has_credits=%t amount=%.0f minimum=%.0f",
account.ID, hasCredits, credit.Amount, minimum)
return hasCredits
}
}
logger.LegacyPrintf("service.antigravity_gateway",
"check_credits: account=%d has_credits=false (no GOOGLE_ONE_AI credit)",
account.ID)
return false
}
type antigravity429Category string type antigravity429Category string
const ( const (
...@@ -141,6 +189,13 @@ func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstr ...@@ -141,6 +189,13 @@ func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstr
} }
// shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。 // shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。
// 此函数在积分注入后失败时调用(预检查注入 + attemptCreditsOveragesRetry 两条路径)。
// - 429 + 非单模型限流:积分注入后仍 429 → 标记耗尽。
// - 429 + 单模型限流("exhausted your capacity on this model"):该模型免费配额用完,
// 积分注入对此无效,但账号积分对其他模型可能仍可用 → 不标记积分耗尽。
// - 403 等其他 4xx:检查 body 是否包含积分不足的关键词。
//
// clearCreditsExhausted 会在后续成功时自动清除。
func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool { func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool {
if reqErr != nil || resp == nil { if reqErr != nil || resp == nil {
return false return false
...@@ -148,13 +203,16 @@ func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr err ...@@ -148,13 +203,16 @@ func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr err
if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout { if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout {
return false return false
} }
// 注意:不再检查 isURLLevelRateLimit。此函数仅在积分重试失败后调用,
// 如果注入 enabledCreditTypes 后仍返回 "Resource has been exhausted",
// 说明积分也已耗尽,应该标记。clearCreditsExhausted 会在后续成功时自动清除。
if info := parseAntigravitySmartRetryInfo(respBody); info != nil {
return false
}
bodyLower := strings.ToLower(string(respBody)) bodyLower := strings.ToLower(string(respBody))
// 积分注入后仍 429
if resp.StatusCode == http.StatusTooManyRequests {
// 单模型配额耗尽:积分注入对此无效,不标记整个账号积分耗尽
if strings.Contains(bodyLower, "exhausted your capacity on this model") {
return false
}
return true
}
// 其他 4xx:关键词匹配(如 403 + "Insufficient credits")
for _, keyword := range creditsExhaustedKeywords { for _, keyword := range creditsExhaustedKeywords {
if strings.Contains(bodyLower, keyword) { if strings.Contains(bodyLower, keyword) {
return true return true
...@@ -181,6 +239,16 @@ func (s *AntigravityGatewayService) attemptCreditsOveragesRetry( ...@@ -181,6 +239,16 @@ func (s *AntigravityGatewayService) attemptCreditsOveragesRetry(
if creditsBody == nil { if creditsBody == nil {
return &creditsOveragesRetryResult{handled: false} return &creditsOveragesRetryResult{handled: false}
} }
// Check actual credits balance before attempting retry
if !s.checkAccountCredits(p.ctx, p.account, p.accessToken, p.proxyURL) {
s.setCreditsExhausted(p.ctx, p.account)
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_no_credits model=%s account=%d (skipping credits retry)",
p.prefix, modelKey, p.account.ID)
return &creditsOveragesRetryResult{handled: true}
}
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel) modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 credit_overages_retry model=%s account=%d (injecting enabledCreditTypes)", logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 credit_overages_retry model=%s account=%d (injecting enabledCreditTypes)",
p.prefix, modelKey, p.account.ID) p.prefix, modelKey, p.account.ID)
......
...@@ -418,7 +418,13 @@ func TestShouldMarkCreditsExhausted(t *testing.T) { ...@@ -418,7 +418,13 @@ func TestShouldMarkCreditsExhausted(t *testing.T) {
require.True(t, shouldMarkCreditsExhausted(resp, body, nil)) require.True(t, shouldMarkCreditsExhausted(resp, body, nil))
}) })
t.Run("结构化限流不标记", func(t *testing.T) { t.Run("单模型配额耗尽不标记(积分对此无效)", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
body := []byte(`{"error":{"code":429,"message":"You have exhausted your capacity on this model. Your quota will reset after 146h11m17s.","status":"RESOURCE_EXHAUSTED"}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
})
t.Run("429 结构化限流也标记(积分注入后仍 429 即为耗尽)", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusTooManyRequests} resp := &http.Response{StatusCode: http.StatusTooManyRequests}
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`) body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil)) require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
......
...@@ -557,7 +557,13 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP ...@@ -557,7 +557,13 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
if p.requestedModel != "" && p.account.Platform == PlatformAntigravity && if p.requestedModel != "" && p.account.Platform == PlatformAntigravity &&
p.account.IsOveragesEnabled() && !p.account.isCreditsExhausted() && p.account.IsOveragesEnabled() && !p.account.isCreditsExhausted() &&
p.account.isModelRateLimitedWithContext(p.ctx, p.requestedModel) { p.account.isModelRateLimitedWithContext(p.ctx, p.requestedModel) {
if creditsBody := injectEnabledCreditTypes(p.body); creditsBody != nil { // Check actual credits balance before injection
if !s.checkAccountCredits(p.ctx, p.account, p.accessToken, p.proxyURL) {
// No credits available - mark as exhausted and skip injection
s.setCreditsExhausted(p.ctx, p.account)
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: no_credits_available account=%d (skipping credits injection)",
p.prefix, p.account.ID)
} else if creditsBody := injectEnabledCreditTypes(p.body); creditsBody != nil {
p.body = creditsBody p.body = creditsBody
overagesInjected = true overagesInjected = true
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: model_rate_limited_credits_inject model=%s account=%d (injecting enabledCreditTypes)", logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: model_rate_limited_credits_inject model=%s account=%d (injecting enabledCreditTypes)",
...@@ -870,14 +876,15 @@ func logPrefix(sessionID, accountName string) string { ...@@ -870,14 +876,15 @@ func logPrefix(sessionID, accountName string) string {
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发 // AntigravityGatewayService 处理 Antigravity 平台的 API 转发
type AntigravityGatewayService struct { type AntigravityGatewayService struct {
accountRepo AccountRepository accountRepo AccountRepository
tokenProvider *AntigravityTokenProvider tokenProvider *AntigravityTokenProvider
rateLimitService *RateLimitService rateLimitService *RateLimitService
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
settingService *SettingService settingService *SettingService
cache GatewayCache // 用于模型级限流时清除粘性会话绑定 cache GatewayCache // 用于模型级限流时清除粘性会话绑定
schedulerSnapshot *SchedulerSnapshotService schedulerSnapshot *SchedulerSnapshotService
internal500Cache Internal500CounterCache // INTERNAL 500 渐进惩罚计数器 internal500Cache Internal500CounterCache // INTERNAL 500 渐进惩罚计数器
accountUsageService *AccountUsageService // 共享 usage 缓存,用于积分余额检查
} }
func NewAntigravityGatewayService( func NewAntigravityGatewayService(
...@@ -889,16 +896,18 @@ func NewAntigravityGatewayService( ...@@ -889,16 +896,18 @@ func NewAntigravityGatewayService(
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
settingService *SettingService, settingService *SettingService,
internal500Cache Internal500CounterCache, internal500Cache Internal500CounterCache,
accountUsageService *AccountUsageService,
) *AntigravityGatewayService { ) *AntigravityGatewayService {
return &AntigravityGatewayService{ return &AntigravityGatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
tokenProvider: tokenProvider, tokenProvider: tokenProvider,
rateLimitService: rateLimitService, rateLimitService: rateLimitService,
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
settingService: settingService, settingService: settingService,
cache: cache, cache: cache,
schedulerSnapshot: schedulerSnapshot, schedulerSnapshot: schedulerSnapshot,
internal500Cache: internal500Cache, internal500Cache: internal500Cache,
accountUsageService: accountUsageService,
} }
} }
......
...@@ -56,6 +56,7 @@ type ModelPricing struct { ...@@ -56,6 +56,7 @@ type ModelPricing struct {
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格 LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率 LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率 LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
ImageOutputPricePerToken float64 // 图片输出 token 价格 (USD)
} }
const ( const (
...@@ -94,12 +95,14 @@ type UsageTokens struct { ...@@ -94,12 +95,14 @@ type UsageTokens struct {
CacheReadTokens int CacheReadTokens int
CacheCreation5mTokens int CacheCreation5mTokens int
CacheCreation1hTokens int CacheCreation1hTokens int
ImageOutputTokens int
} }
// CostBreakdown 费用明细 // CostBreakdown 费用明细
type CostBreakdown struct { type CostBreakdown struct {
InputCost float64 InputCost float64
OutputCost float64 OutputCost float64
ImageOutputCost float64
CacheCreationCost float64 CacheCreationCost float64
CacheReadCost float64 CacheReadCost float64
TotalCost float64 TotalCost float64
...@@ -358,6 +361,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { ...@@ -358,6 +361,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold, LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier, LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier, LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
ImageOutputPricePerToken: litellmPricing.OutputCostPerImageToken,
}), nil }), nil
} }
} }
...@@ -399,6 +403,9 @@ func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing ...@@ -399,6 +403,9 @@ func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing
pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice
pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice
} }
if channelPricing.ImageOutputPrice != nil {
pricing.ImageOutputPricePerToken = *channelPricing.ImageOutputPrice
}
return pricing, nil return pricing, nil
} }
...@@ -489,7 +496,22 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos ...@@ -489,7 +496,22 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos
} }
breakdown.InputCost = float64(input.Tokens.InputTokens) * inputPricePerToken breakdown.InputCost = float64(input.Tokens.InputTokens) * inputPricePerToken
breakdown.OutputCost = float64(input.Tokens.OutputTokens) * outputPricePerToken
// Separate image output tokens from text output tokens
textOutputTokens := input.Tokens.OutputTokens - input.Tokens.ImageOutputTokens
if textOutputTokens < 0 {
textOutputTokens = 0
}
breakdown.OutputCost = float64(textOutputTokens) * outputPricePerToken
// Image output tokens cost (separate rate from text output)
if input.Tokens.ImageOutputTokens > 0 {
imageOutputPrice := pricing.ImageOutputPricePerToken
if imageOutputPrice == 0 {
imageOutputPrice = outputPricePerToken // fallback to regular output price
}
breakdown.ImageOutputCost = float64(input.Tokens.ImageOutputTokens) * imageOutputPrice
}
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
if input.Tokens.CacheCreation5mTokens == 0 && input.Tokens.CacheCreation1hTokens == 0 && input.Tokens.CacheCreationTokens > 0 { if input.Tokens.CacheCreation5mTokens == 0 && input.Tokens.CacheCreation1hTokens == 0 && input.Tokens.CacheCreationTokens > 0 {
...@@ -507,11 +529,12 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos ...@@ -507,11 +529,12 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos
if tierMultiplier != 1.0 { if tierMultiplier != 1.0 {
breakdown.InputCost *= tierMultiplier breakdown.InputCost *= tierMultiplier
breakdown.OutputCost *= tierMultiplier breakdown.OutputCost *= tierMultiplier
breakdown.ImageOutputCost *= tierMultiplier
breakdown.CacheCreationCost *= tierMultiplier breakdown.CacheCreationCost *= tierMultiplier
breakdown.CacheReadCost *= tierMultiplier breakdown.CacheReadCost *= tierMultiplier
} }
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.ImageOutputCost +
breakdown.CacheCreationCost + breakdown.CacheReadCost breakdown.CacheCreationCost + breakdown.CacheReadCost
breakdown.ActualCost = breakdown.TotalCost * input.RateMultiplier breakdown.ActualCost = breakdown.TotalCost * input.RateMultiplier
...@@ -597,8 +620,21 @@ func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens, ...@@ -597,8 +620,21 @@ func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens,
// 计算输入token费用(使用per-token价格) // 计算输入token费用(使用per-token价格)
breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken
// 计算输出token费用 // 计算输出token费用(分离图片输出token)
breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken textOutputTokens := tokens.OutputTokens - tokens.ImageOutputTokens
if textOutputTokens < 0 {
textOutputTokens = 0
}
breakdown.OutputCost = float64(textOutputTokens) * outputPricePerToken
// 图片输出 token 费用
if tokens.ImageOutputTokens > 0 {
imageOutputPrice := pricing.ImageOutputPricePerToken
if imageOutputPrice == 0 {
imageOutputPrice = outputPricePerToken
}
breakdown.ImageOutputCost = float64(tokens.ImageOutputTokens) * imageOutputPrice
}
// 计算缓存费用 // 计算缓存费用
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
...@@ -620,12 +656,13 @@ func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens, ...@@ -620,12 +656,13 @@ func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens,
if tierMultiplier != 1.0 { if tierMultiplier != 1.0 {
breakdown.InputCost *= tierMultiplier breakdown.InputCost *= tierMultiplier
breakdown.OutputCost *= tierMultiplier breakdown.OutputCost *= tierMultiplier
breakdown.ImageOutputCost *= tierMultiplier
breakdown.CacheCreationCost *= tierMultiplier breakdown.CacheCreationCost *= tierMultiplier
breakdown.CacheReadCost *= tierMultiplier breakdown.CacheReadCost *= tierMultiplier
} }
// 计算总费用 // 计算总费用
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.ImageOutputCost +
breakdown.CacheCreationCost + breakdown.CacheReadCost breakdown.CacheCreationCost + breakdown.CacheReadCost
// 应用倍率计算实际费用 // 应用倍率计算实际费用
...@@ -730,6 +767,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage ...@@ -730,6 +767,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
CacheReadTokens: inRangeCacheTokens, CacheReadTokens: inRangeCacheTokens,
CacheCreation5mTokens: tokens.CacheCreation5mTokens, CacheCreation5mTokens: tokens.CacheCreation5mTokens,
CacheCreation1hTokens: tokens.CacheCreation1hTokens, CacheCreation1hTokens: tokens.CacheCreation1hTokens,
ImageOutputTokens: tokens.ImageOutputTokens,
} }
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
if err != nil { if err != nil {
...@@ -750,6 +788,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage ...@@ -750,6 +788,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
return &CostBreakdown{ return &CostBreakdown{
InputCost: inRangeCost.InputCost + outRangeCost.InputCost, InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
OutputCost: inRangeCost.OutputCost, OutputCost: inRangeCost.OutputCost,
ImageOutputCost: inRangeCost.ImageOutputCost,
CacheCreationCost: inRangeCost.CacheCreationCost, CacheCreationCost: inRangeCost.CacheCreationCost,
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost, CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost, TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
......
...@@ -24,8 +24,9 @@ func (m BillingMode) IsValid() bool { ...@@ -24,8 +24,9 @@ func (m BillingMode) IsValid() bool {
} }
const ( const (
BillingModelSourceRequested = "requested" BillingModelSourceRequested = "requested"
BillingModelSourceUpstream = "upstream" BillingModelSourceUpstream = "upstream"
BillingModelSourceChannelMapped = "channel_mapped"
) )
// Channel 渠道实体 // Channel 渠道实体
...@@ -34,7 +35,7 @@ type Channel struct { ...@@ -34,7 +35,7 @@ type Channel struct {
Name string Name string
Description string Description string
Status string Status string
BillingModelSource string // "requested" or "upstream" BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
...@@ -180,6 +181,7 @@ func (c *Channel) Clone() *Channel { ...@@ -180,6 +181,7 @@ func (c *Channel) Clone() *Channel {
type ChannelUsageFields struct { type ChannelUsageFields struct {
ChannelID int64 // 渠道 ID(0 = 无渠道) ChannelID int64 // 渠道 ID(0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前) OriginalModel string // 用户原始请求模型(渠道映射前)
BillingModelSource string // 计费模型来源:"requested" / "upstream" ChannelMappedModel string // 渠道映射后的模型名(无映射时等于 OriginalModel)
BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped"
ModelMappingChain string // 映射链描述,如 "a→b→c" ModelMappingChain string // 映射链描述,如 "a→b→c"
} }
...@@ -97,7 +97,7 @@ type ChannelMappingResult struct { ...@@ -97,7 +97,7 @@ type ChannelMappingResult struct {
MappedModel string // 映射后的模型名(无映射时等于原始模型名) MappedModel string // 映射后的模型名(无映射时等于原始模型名)
ChannelID int64 // 渠道 ID(0 = 无渠道关联) ChannelID int64 // 渠道 ID(0 = 无渠道关联)
Mapped bool // 是否发生了映射 Mapped bool // 是否发生了映射
BillingModelSource string // 计费模型来源("requested" / "upstream") BillingModelSource string // 计费模型来源("requested" / "upstream" / "channel_mapped"
} }
// BuildModelMappingChain 根据映射结果和上游实际模型构建映射链描述。 // BuildModelMappingChain 根据映射结果和上游实际模型构建映射链描述。
...@@ -119,9 +119,14 @@ func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel str ...@@ -119,9 +119,14 @@ func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel str
// ToUsageFields 将渠道映射结果转为使用记录字段 // ToUsageFields 将渠道映射结果转为使用记录字段
func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields { func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields {
channelMappedModel := reqModel
if r.Mapped {
channelMappedModel = r.MappedModel
}
return ChannelUsageFields{ return ChannelUsageFields{
ChannelID: r.ChannelID, ChannelID: r.ChannelID,
OriginalModel: reqModel, OriginalModel: reqModel,
ChannelMappedModel: channelMappedModel,
BillingModelSource: r.BillingModelSource, BillingModelSource: r.BillingModelSource,
ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel), ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel),
} }
...@@ -193,7 +198,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -193,7 +198,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
channelByGroupID: make(map[int64]*Channel), channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string), groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel), byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL loadedAt: time.Now().Add(-(channelCacheTTL - channelErrorTTL)), // 使剩余 TTL = errorTTL
} }
s.cache.Store(errorCache) s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err) return nil, fmt.Errorf("list all channels: %w", err)
...@@ -374,7 +379,7 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6 ...@@ -374,7 +379,7 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
BillingModelSource: ch.BillingModelSource, BillingModelSource: ch.BillingModelSource,
} }
if result.BillingModelSource == "" { if result.BillingModelSource == "" {
result.BillingModelSource = BillingModelSourceRequested result.BillingModelSource = BillingModelSourceChannelMapped
} }
platform := cache.groupPlatform[groupID] platform := cache.groupPlatform[groupID]
...@@ -481,7 +486,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) ...@@ -481,7 +486,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
ModelMapping: input.ModelMapping, ModelMapping: input.ModelMapping,
} }
if channel.BillingModelSource == "" { if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceRequested channel.BillingModelSource = BillingModelSourceChannelMapped
} }
if err := validateNoConflictingModels(channel.ModelPricing); err != nil { if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
...@@ -565,20 +570,36 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan ...@@ -565,20 +570,36 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
return nil, err return nil, err
} }
// 先获取旧分组,Update 后旧分组关联已删除,无法再查到
var oldGroupIDs []int64
if s.authCacheInvalidator != nil {
var err2 error
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
if err2 != nil {
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
}
}
if err := s.repo.Update(ctx, channel); err != nil { if err := s.repo.Update(ctx, channel); err != nil {
return nil, fmt.Errorf("update channel: %w", err) return nil, fmt.Errorf("update channel: %w", err)
} }
s.invalidateCache() s.invalidateCache()
// 失效关联分组的 auth 缓存 // 失效新旧分组的 auth 缓存
if s.authCacheInvalidator != nil { if s.authCacheInvalidator != nil {
groupIDs, err := s.repo.GetGroupIDs(ctx, id) seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
if err != nil { for _, gid := range oldGroupIDs {
slog.Warn("failed to get group IDs for cache invalidation", "channel_id", id, "error", err) if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
} }
for _, gid := range groupIDs { for _, gid := range channel.GroupIDs {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
} }
} }
......
...@@ -16,24 +16,24 @@ import ( ...@@ -16,24 +16,24 @@ import (
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
type mockChannelRepository struct { type mockChannelRepository struct {
listAllFn func(ctx context.Context) ([]Channel, error) listAllFn func(ctx context.Context) ([]Channel, error)
getGroupPlatformsFn func(ctx context.Context, groupIDs []int64) (map[int64]string, error) getGroupPlatformsFn func(ctx context.Context, groupIDs []int64) (map[int64]string, error)
createFn func(ctx context.Context, channel *Channel) error createFn func(ctx context.Context, channel *Channel) error
getByIDFn func(ctx context.Context, id int64) (*Channel, error) getByIDFn func(ctx context.Context, id int64) (*Channel, error)
updateFn func(ctx context.Context, channel *Channel) error updateFn func(ctx context.Context, channel *Channel) error
deleteFn func(ctx context.Context, id int64) error deleteFn func(ctx context.Context, id int64) error
listFn func(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) listFn func(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error)
existsByNameFn func(ctx context.Context, name string) (bool, error) existsByNameFn func(ctx context.Context, name string) (bool, error)
existsByNameExcludingFn func(ctx context.Context, name string, excludeID int64) (bool, error) existsByNameExcludingFn func(ctx context.Context, name string, excludeID int64) (bool, error)
getGroupIDsFn func(ctx context.Context, channelID int64) ([]int64, error) getGroupIDsFn func(ctx context.Context, channelID int64) ([]int64, error)
setGroupIDsFn func(ctx context.Context, channelID int64, groupIDs []int64) error setGroupIDsFn func(ctx context.Context, channelID int64, groupIDs []int64) error
getChannelIDByGroupIDFn func(ctx context.Context, groupID int64) (int64, error) getChannelIDByGroupIDFn func(ctx context.Context, groupID int64) (int64, error)
getGroupsInOtherChannelsFn func(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) getGroupsInOtherChannelsFn func(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error)
listModelPricingFn func(ctx context.Context, channelID int64) ([]ChannelModelPricing, error) listModelPricingFn func(ctx context.Context, channelID int64) ([]ChannelModelPricing, error)
createModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error createModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error
updateModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error updateModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error
deleteModelPricingFn func(ctx context.Context, id int64) error deleteModelPricingFn func(ctx context.Context, id int64) error
replaceModelPricingFn func(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error replaceModelPricingFn func(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
} }
func (m *mockChannelRepository) Create(ctx context.Context, channel *Channel) error { func (m *mockChannelRepository) Create(ctx context.Context, channel *Channel) error {
...@@ -196,7 +196,6 @@ func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChanne ...@@ -196,7 +196,6 @@ func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChanne
return NewChannelService(repo, auth) return NewChannelService(repo, auth)
} }
// makeStandardRepo returns a repo that serves one active channel with anthropic pricing // makeStandardRepo returns a repo that serves one active channel with anthropic pricing
// for group 1, with the given model pricing and model mapping. // for group 1, with the given model pricing and model mapping.
func makeStandardRepo(ch Channel, groupPlatforms map[int64]string) *mockChannelRepository { func makeStandardRepo(ch Channel, groupPlatforms map[int64]string) *mockChannelRepository {
...@@ -907,21 +906,21 @@ func TestResolveChannelMapping_DefaultBillingModelSource(t *testing.T) { ...@@ -907,21 +906,21 @@ func TestResolveChannelMapping_DefaultBillingModelSource(t *testing.T) {
ch := Channel{ ch := Channel{
ID: 1, ID: 1,
Status: StatusActive, Status: StatusActive,
GroupIDs: []int64{10}, GroupIDs: []int64{10},
BillingModelSource: "", // empty BillingModelSource: "", // empty
} }
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
svc := newTestChannelService(repo) svc := newTestChannelService(repo)
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4") result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
require.Equal(t, BillingModelSourceRequested, result.BillingModelSource) require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource)
} }
func TestResolveChannelMapping_UpstreamBillingModelSource(t *testing.T) { func TestResolveChannelMapping_UpstreamBillingModelSource(t *testing.T) {
ch := Channel{ ch := Channel{
ID: 1, ID: 1,
Status: StatusActive, Status: StatusActive,
GroupIDs: []int64{10}, GroupIDs: []int64{10},
BillingModelSource: BillingModelSourceUpstream, BillingModelSource: BillingModelSourceUpstream,
} }
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
...@@ -957,7 +956,7 @@ func TestIsModelRestricted_NoChannel(t *testing.T) { ...@@ -957,7 +956,7 @@ func TestIsModelRestricted_NoChannel(t *testing.T) {
ch := Channel{ ch := Channel{
ID: 1, ID: 1,
Status: StatusActive, Status: StatusActive,
GroupIDs: []int64{10}, GroupIDs: []int64{10},
RestrictModels: true, RestrictModels: true,
} }
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
...@@ -972,7 +971,7 @@ func TestIsModelRestricted_RestrictDisabled(t *testing.T) { ...@@ -972,7 +971,7 @@ func TestIsModelRestricted_RestrictDisabled(t *testing.T) {
ch := Channel{ ch := Channel{
ID: 1, ID: 1,
Status: StatusActive, Status: StatusActive,
GroupIDs: []int64{10}, GroupIDs: []int64{10},
RestrictModels: false, RestrictModels: false,
ModelPricing: []ChannelModelPricing{ ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4"}}, {Platform: "anthropic", Models: []string{"claude-opus-4"}},
...@@ -990,7 +989,7 @@ func TestIsModelRestricted_InactiveChannel(t *testing.T) { ...@@ -990,7 +989,7 @@ func TestIsModelRestricted_InactiveChannel(t *testing.T) {
ch := Channel{ ch := Channel{
ID: 1, ID: 1,
Status: StatusDisabled, Status: StatusDisabled,
GroupIDs: []int64{10}, GroupIDs: []int64{10},
RestrictModels: true, RestrictModels: true,
} }
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
...@@ -1004,7 +1003,7 @@ func TestIsModelRestricted_ModelInPricing(t *testing.T) { ...@@ -1004,7 +1003,7 @@ func TestIsModelRestricted_ModelInPricing(t *testing.T) {
ch := Channel{ ch := Channel{
ID: 1, ID: 1,
Status: StatusActive, Status: StatusActive,
GroupIDs: []int64{10}, GroupIDs: []int64{10},
RestrictModels: true, RestrictModels: true,
ModelPricing: []ChannelModelPricing{ ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4", "claude-sonnet-4"}}, {Platform: "anthropic", Models: []string{"claude-opus-4", "claude-sonnet-4"}},
...@@ -1021,7 +1020,7 @@ func TestIsModelRestricted_ModelInWildcard(t *testing.T) { ...@@ -1021,7 +1020,7 @@ func TestIsModelRestricted_ModelInWildcard(t *testing.T) {
ch := Channel{ ch := Channel{
ID: 1, ID: 1,
Status: StatusActive, Status: StatusActive,
GroupIDs: []int64{10}, GroupIDs: []int64{10},
RestrictModels: true, RestrictModels: true,
ModelPricing: []ChannelModelPricing{ ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-*"}}, {Platform: "anthropic", Models: []string{"claude-*"}},
...@@ -1038,7 +1037,7 @@ func TestIsModelRestricted_ModelNotFound(t *testing.T) { ...@@ -1038,7 +1037,7 @@ func TestIsModelRestricted_ModelNotFound(t *testing.T) {
ch := Channel{ ch := Channel{
ID: 1, ID: 1,
Status: StatusActive, Status: StatusActive,
GroupIDs: []int64{10}, GroupIDs: []int64{10},
RestrictModels: true, RestrictModels: true,
ModelPricing: []ChannelModelPricing{ ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4"}}, {Platform: "anthropic", Models: []string{"claude-opus-4"}},
...@@ -1055,7 +1054,7 @@ func TestIsModelRestricted_CaseInsensitive(t *testing.T) { ...@@ -1055,7 +1054,7 @@ func TestIsModelRestricted_CaseInsensitive(t *testing.T) {
ch := Channel{ ch := Channel{
ID: 1, ID: 1,
Status: StatusActive, Status: StatusActive,
GroupIDs: []int64{10}, GroupIDs: []int64{10},
RestrictModels: true, RestrictModels: true,
ModelPricing: []ChannelModelPricing{ ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4"}}, {Platform: "anthropic", Models: []string{"claude-opus-4"}},
...@@ -1088,7 +1087,7 @@ func TestResolveChannelMappingAndRestrict_ModelInPricing_WithMapping(t *testing. ...@@ -1088,7 +1087,7 @@ func TestResolveChannelMappingAndRestrict_ModelInPricing_WithMapping(t *testing.
ch := Channel{ ch := Channel{
ID: 1, ID: 1,
Status: StatusActive, Status: StatusActive,
GroupIDs: []int64{10}, GroupIDs: []int64{10},
RestrictModels: true, RestrictModels: true,
ModelPricing: []ChannelModelPricing{ ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, {Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
...@@ -1117,7 +1116,7 @@ func TestResolveChannelMappingAndRestrict_ModelNotInPricing_WithMapping(t *testi ...@@ -1117,7 +1116,7 @@ func TestResolveChannelMappingAndRestrict_ModelNotInPricing_WithMapping(t *testi
ch := Channel{ ch := Channel{
ID: 1, ID: 1,
Status: StatusActive, Status: StatusActive,
GroupIDs: []int64{10}, GroupIDs: []int64{10},
RestrictModels: true, RestrictModels: true,
ModelPricing: []ChannelModelPricing{ ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, {Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
...@@ -1142,7 +1141,7 @@ func TestResolveChannelMappingAndRestrict_ModelNotInPricing_NoMapping(t *testing ...@@ -1142,7 +1141,7 @@ func TestResolveChannelMappingAndRestrict_ModelNotInPricing_NoMapping(t *testing
ch := Channel{ ch := Channel{
ID: 1, ID: 1,
Status: StatusActive, Status: StatusActive,
GroupIDs: []int64{10}, GroupIDs: []int64{10},
RestrictModels: true, RestrictModels: true,
ModelPricing: []ChannelModelPricing{ ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, {Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
...@@ -1451,11 +1450,11 @@ func TestCreate_DefaultBillingModelSource(t *testing.T) { ...@@ -1451,11 +1450,11 @@ func TestCreate_DefaultBillingModelSource(t *testing.T) {
result, err := svc.Create(context.Background(), &CreateChannelInput{ result, err := svc.Create(context.Background(), &CreateChannelInput{
Name: "new-channel", Name: "new-channel",
BillingModelSource: "", // empty, should default to "requested" BillingModelSource: "", // empty, should default to "channel_mapped"
}) })
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, BillingModelSourceRequested, result.BillingModelSource) require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource)
} }
func TestCreate_InvalidatesCache(t *testing.T) { func TestCreate_InvalidatesCache(t *testing.T) {
......
...@@ -483,6 +483,7 @@ type ClaudeUsage struct { ...@@ -483,6 +483,7 @@ type ClaudeUsage struct {
CacheReadInputTokens int `json:"cache_read_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"`
CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象) CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象)
CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象) CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象)
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
} }
// ForwardResult 转发结果 // ForwardResult 转发结果
...@@ -7729,6 +7730,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7729,6 +7730,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
var cost *CostBreakdown var cost *CostBreakdown
// 确定计费模型 // 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
billingModel = input.ChannelMappedModel
}
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel billingModel = input.OriginalModel
} }
...@@ -7777,6 +7781,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7777,6 +7781,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
} }
var err error var err error
if s.resolver != nil && apiKey.Group != nil { if s.resolver != nil && apiKey.Group != nil {
...@@ -7836,8 +7841,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7836,8 +7841,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
InputCost: cost.InputCost, InputCost: cost.InputCost,
OutputCost: cost.OutputCost, OutputCost: cost.OutputCost,
ImageOutputCost: cost.ImageOutputCost,
CacheCreationCost: cost.CacheCreationCost, CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost, CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost, TotalCost: cost.TotalCost,
...@@ -7976,6 +7983,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -7976,6 +7983,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
var cost *CostBreakdown var cost *CostBreakdown
// 确定计费模型 // 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
billingModel = input.ChannelMappedModel
}
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel billingModel = input.OriginalModel
} }
...@@ -8007,6 +8017,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -8007,6 +8017,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
} }
var err error var err error
// 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用) // 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用)
...@@ -8073,8 +8084,10 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -8073,8 +8084,10 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
InputCost: cost.InputCost, InputCost: cost.InputCost,
OutputCost: cost.OutputCost, OutputCost: cost.OutputCost,
ImageOutputCost: cost.ImageOutputCost,
CacheCreationCost: cost.CacheCreationCost, CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost, CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost, TotalCost: cost.TotalCost,
......
...@@ -2692,12 +2692,27 @@ func extractGeminiUsage(data []byte) *ClaudeUsage { ...@@ -2692,12 +2692,27 @@ func extractGeminiUsage(data []byte) *ClaudeUsage {
cand := int(usage.Get("candidatesTokenCount").Int()) cand := int(usage.Get("candidatesTokenCount").Int())
cached := int(usage.Get("cachedContentTokenCount").Int()) cached := int(usage.Get("cachedContentTokenCount").Int())
thoughts := int(usage.Get("thoughtsTokenCount").Int()) thoughts := int(usage.Get("thoughtsTokenCount").Int())
// 从 candidatesTokensDetails 提取 IMAGE 模态 token 数
imageTokens := 0
candidateDetails := usage.Get("candidatesTokensDetails")
if candidateDetails.Exists() {
candidateDetails.ForEach(func(_, detail gjson.Result) bool {
if detail.Get("modality").String() == "IMAGE" {
imageTokens = int(detail.Get("tokenCount").Int())
return false
}
return true
})
}
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount, // 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去 // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
return &ClaudeUsage{ return &ClaudeUsage{
InputTokens: prompt - cached, InputTokens: prompt - cached,
OutputTokens: cand + thoughts, OutputTokens: cand + thoughts,
CacheReadInputTokens: cached, CacheReadInputTokens: cached,
ImageOutputTokens: imageTokens,
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment