Commit 7b83d6e7 authored by 陈曦's avatar 陈曦
Browse files

Merge remote-tracking branch 'upstream/main'

parents daa2e6df dbb248df
......@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
......@@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{
"integer", // cache_read_tokens
"integer", // cache_creation_5m_tokens
"integer", // cache_creation_1h_tokens
"integer", // image_output_tokens
"numeric", // image_output_cost
"numeric", // input_cost
"numeric", // output_cost
"numeric", // cache_creation_cost
......@@ -77,6 +79,10 @@ var usageLogInsertArgTypes = [...]string{
"text", // inbound_endpoint
"text", // upstream_endpoint
"boolean", // cache_ttl_overridden
"bigint", // channel_id
"text", // model_mapping_chain
"text", // billing_tier
"text", // billing_mode
"timestamptz", // created_at
}
......@@ -326,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -350,14 +358,18 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
$8, $9,
$10, $11, $12, $13,
$14, $15,
$16, $17, $18, $19, $20, $21,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
......@@ -758,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -782,10 +796,14 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
) AS (VALUES `)
args := make([]any, 0, len(keys)*39)
args := make([]any, 0, len(keys)*47)
argPos := 1
for idx, key := range keys {
if idx > 0 {
......@@ -829,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -853,6 +873,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
)
SELECT
......@@ -871,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -895,6 +921,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
......@@ -953,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -977,10 +1009,14 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*40)
args := make([]any, 0, len(preparedList)*46)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
......@@ -1021,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -1045,6 +1083,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
)
SELECT
......@@ -1063,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -1087,6 +1131,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
......@@ -1113,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
......@@ -1137,14 +1187,18 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
$8, $9,
$10, $11, $12, $13,
$14, $15,
$16, $17, $18, $19, $20, $21,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
......@@ -1176,6 +1230,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint)
channelID := nullInt64(log.ChannelID)
modelMappingChain := nullString(log.ModelMappingChain)
billingTier := nullString(log.BillingTier)
billingMode := nullString(log.BillingMode)
requestedModel := strings.TrimSpace(log.RequestedModel)
if requestedModel == "" {
requestedModel = strings.TrimSpace(log.Model)
......@@ -1208,6 +1266,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
......@@ -1232,6 +1292,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
inboundEndpoint,
upstreamEndpoint,
log.CacheTTLOverridden,
channelID,
modelMappingChain,
billingTier,
billingMode,
createdAt,
},
}
......@@ -2564,8 +2628,8 @@ type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin)
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
conditions := make([]string, 0, 8)
args := make([]any, 0, 8)
conditions := make([]string, 0, 9)
args := make([]any, 0, 9)
if filters.UserID > 0 {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
......@@ -2589,6 +2653,10 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
if filters.BillingMode != "" {
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
args = append(args, filters.BillingMode)
}
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
......@@ -3096,6 +3164,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
args = append(args, dim.Endpoint)
}
if dim.UserID > 0 {
query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1)
args = append(args, dim.UserID)
}
if dim.APIKeyID > 0 {
query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1)
args = append(args, dim.APIKeyID)
}
if dim.AccountID > 0 {
query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1)
args = append(args, dim.AccountID)
}
if dim.RequestType != nil {
query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1)
args = append(args, *dim.RequestType)
}
if dim.Stream != nil {
query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1)
args = append(args, *dim.Stream)
}
if dim.BillingType != nil {
query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1)
args = append(args, *dim.BillingType)
}
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
if limit > 0 {
......@@ -3256,6 +3348,10 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
if filters.BillingMode != "" {
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
args = append(args, filters.BillingMode)
}
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
......@@ -3935,6 +4031,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
cacheReadTokens int
cacheCreation5m int
cacheCreation1h int
imageOutputTokens int
imageOutputCost float64
inputCost float64
outputCost float64
cacheCreationCost float64
......@@ -3959,6 +4057,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
inboundEndpoint sql.NullString
upstreamEndpoint sql.NullString
cacheTTLOverridden bool
channelID sql.NullInt64
modelMappingChain sql.NullString
billingTier sql.NullString
billingMode sql.NullString
createdAt time.Time
)
......@@ -3979,6 +4081,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&cacheReadTokens,
&cacheCreation5m,
&cacheCreation1h,
&imageOutputTokens,
&imageOutputCost,
&inputCost,
&outputCost,
&cacheCreationCost,
......@@ -4003,6 +4107,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&inboundEndpoint,
&upstreamEndpoint,
&cacheTTLOverridden,
&channelID,
&modelMappingChain,
&billingTier,
&billingMode,
&createdAt,
); err != nil {
return nil, err
......@@ -4021,6 +4129,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
CacheReadTokens: cacheReadTokens,
CacheCreation5mTokens: cacheCreation5m,
CacheCreation1hTokens: cacheCreation1h,
ImageOutputTokens: imageOutputTokens,
ImageOutputCost: imageOutputCost,
InputCost: inputCost,
OutputCost: outputCost,
CacheCreationCost: cacheCreationCost,
......@@ -4087,6 +4197,19 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if upstreamModel.Valid {
log.UpstreamModel = &upstreamModel.String
}
if channelID.Valid {
value := channelID.Int64
log.ChannelID = &value
}
if modelMappingChain.Valid {
log.ModelMappingChain = &modelMappingChain.String
}
if billingTier.Valid {
log.BillingTier = &billingTier.String
}
if billingMode.Valid {
log.BillingMode = &billingMode.String
}
return log, nil
}
......
......@@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
......@@ -80,6 +82,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // inbound_endpoint
sqlmock.AnyArg(), // upstream_endpoint
log.CacheTTLOverridden,
sqlmock.AnyArg(), // channel_id
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
......@@ -129,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
......@@ -153,6 +161,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.CacheTTLOverridden,
sqlmock.AnyArg(), // channel_id
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
......@@ -439,6 +451,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
4, // cache_read_tokens
5, // cache_creation_5m_tokens
6, // cache_creation_1h_tokens
0, // image_output_tokens
0.0, // image_output_cost
0.1, // input_cost
0.2, // output_cost
0.3, // cache_creation_cost
......@@ -463,6 +477,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now,
}})
require.NoError(t, err)
......@@ -487,6 +505,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0,
sql.NullFloat64{},
......@@ -506,6 +525,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now,
}})
require.NoError(t, err)
......@@ -530,6 +553,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0,
sql.NullFloat64{},
......@@ -549,6 +573,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now,
}})
require.NoError(t, err)
......
......@@ -74,6 +74,7 @@ var ProviderSet = wire.NewSet(
NewUserGroupRateRepository,
NewErrorPassthroughRepository,
NewTLSFingerprintProfileRepository,
NewChannelRepository,
// Cache implementations
NewGatewayCache,
......
......@@ -87,6 +87,9 @@ func RegisterAdminRoutes(
// 定时测试计划
registerScheduledTestRoutes(admin, h)
// 渠道管理
registerChannelRoutes(admin, h)
}
}
......@@ -567,3 +570,15 @@ func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Hand
profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete)
}
}
// registerChannelRoutes mounts the channel endpoints on a "/channels"
// sub-group of the supplied admin router group.
//
// Routes, relative to admin, in registration order:
//
//	GET    /channels                -> Channel.List
//	GET    /channels/model-pricing  -> Channel.GetModelDefaultPricing
//	GET    /channels/:id            -> Channel.GetByID
//	POST   /channels                -> Channel.Create
//	PUT    /channels/:id            -> Channel.Update
//	DELETE /channels/:id            -> Channel.Delete
func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
	group := admin.Group("/channels")
	channel := h.Admin.Channel

	// Read endpoints. The literal "/model-pricing" path is registered
	// ahead of the "/:id" parameter route.
	group.GET("", channel.List)
	group.GET("/model-pricing", channel.GetModelDefaultPricing)
	group.GET("/:id", channel.GetByID)

	// Write endpoints.
	group.POST("", channel.Create)
	group.PUT("/:id", channel.Update)
	group.DELETE("/:id", channel.Delete)
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -732,7 +732,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
modelsListCacheTTL: time.Minute,
}
result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
......@@ -754,7 +754,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0))
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
......@@ -776,7 +776,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999))
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77))
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
......
......@@ -41,6 +41,8 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil,
nil,
nil,
nil,
nil,
)
}
......
This diff is collapsed.
......@@ -2692,12 +2692,27 @@ func extractGeminiUsage(data []byte) *ClaudeUsage {
cand := int(usage.Get("candidatesTokenCount").Int())
cached := int(usage.Get("cachedContentTokenCount").Int())
thoughts := int(usage.Get("thoughtsTokenCount").Int())
// 从 candidatesTokensDetails 提取 IMAGE 模态 token 数
imageTokens := 0
candidateDetails := usage.Get("candidatesTokensDetails")
if candidateDetails.Exists() {
candidateDetails.ForEach(func(_, detail gjson.Result) bool {
if detail.Get("modality").String() == "IMAGE" {
imageTokens = int(detail.Get("tokenCount").Int())
return false
}
return true
})
}
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
return &ClaudeUsage{
InputTokens: prompt - cached,
OutputTokens: cand + thoughts,
CacheReadInputTokens: cached,
ImageOutputTokens: imageTokens,
}
}
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -10,8 +10,8 @@ import (
const compatPromptCacheKeyPrefix = "compat_cc_"
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
switch resolveOpenAIUpstreamModel(strings.TrimSpace(model)) {
case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark":
switch normalizeCodexModel(strings.TrimSpace(model)) {
case "gpt-5.4", "gpt-5.3-codex":
return true
default:
return false
......@@ -23,9 +23,9 @@ func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedMod
return ""
}
normalizedModel := resolveOpenAIUpstreamModel(strings.TrimSpace(mappedModel))
normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel))
if normalizedModel == "" {
normalizedModel = resolveOpenAIUpstreamModel(strings.TrimSpace(req.Model))
normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model))
}
if normalizedModel == "" {
normalizedModel = strings.TrimSpace(req.Model)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment