Commit 08c4e514 authored by InCerry's avatar InCerry
Browse files

Merge branch 'main' of github.com:InCerryGit/sub2api

# Conflicts:
#	backend/internal/service/billing_service.go
parents 73708da6 995bee14
......@@ -46,9 +46,10 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
}
// Determine reasoning effort: only output_config.effort controls the
// level; thinking.type is ignored. Default is xhigh when unset.
// Anthropic levels map to OpenAI: low→low, medium→high, high→xhigh.
effort := "high" // default → maps to xhigh
// level; thinking.type is ignored. Default is high when unset (both
// Anthropic and OpenAI default to high).
// Anthropic levels map 1:1 to OpenAI: low→low, medium→medium, high→high, max→xhigh.
effort := "high" // default → both sides' default
if req.OutputConfig != nil && req.OutputConfig.Effort != "" {
effort = req.OutputConfig.Effort
}
......@@ -380,18 +381,19 @@ func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string {
// mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to
// OpenAI Responses API effort levels.
//
// Both APIs default to "high". The mapping is 1:1 for shared levels;
// only Anthropic's "max" (Opus 4.6 exclusive) maps to OpenAI's "xhigh"
// (GPT-5.2+ exclusive) as both represent the highest reasoning tier.
//
//	low    → low
//	medium → medium
//	high   → high
//	max    → xhigh
func mapAnthropicEffortToResponses(effort string) string {
	if effort == "max" {
		return "xhigh"
	}
	// low/medium/high and any unknown values pass through unchanged.
	return effort
}
// convertAnthropicToolsToResponses maps Anthropic tool definitions to
......
......@@ -181,6 +181,35 @@ func TestChatCompletionsToResponses_ImageURL(t *testing.T) {
assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL)
}
// TestChatCompletionsToResponses_SystemArrayContent verifies that a system
// message whose content is a typed part array converts to an input_text part,
// and that multimodal user parts in the same request survive conversion.
func TestChatCompletionsToResponses_SystemArrayContent(t *testing.T) {
	systemContent := json.RawMessage(`[{"type":"text","text":"You are a careful visual assistant."}]`)
	userContent := json.RawMessage(`[{"type":"text","text":"Describe this image"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]`)

	resp, err := ChatCompletionsToResponses(&ChatCompletionsRequest{
		Model: "gpt-4o",
		Messages: []ChatMessage{
			{Role: "system", Content: systemContent},
			{Role: "user", Content: userContent},
		},
	})
	require.NoError(t, err)

	var items []ResponsesInputItem
	require.NoError(t, json.Unmarshal(resp.Input, &items))
	require.Len(t, items, 2)

	// System message: single text part mapped to input_text.
	var sysParts []ResponsesContentPart
	require.NoError(t, json.Unmarshal(items[0].Content, &sysParts))
	require.Len(t, sysParts, 1)
	assert.Equal(t, "input_text", sysParts[0].Type)
	assert.Equal(t, "You are a careful visual assistant.", sysParts[0].Text)

	// User message: image part preserved as input_image.
	var usrParts []ResponsesContentPart
	require.NoError(t, json.Unmarshal(items[1].Content, &usrParts))
	require.Len(t, usrParts, 2)
	assert.Equal(t, "input_image", usrParts[1].Type)
	assert.Equal(t, "data:image/png;base64,abc123", usrParts[1].ImageURL)
}
func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
......@@ -398,6 +427,45 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
}
// TestChatCompletionsToResponses_ToolArrayContent verifies that a tool-role
// message carrying an array of typed content parts is flattened to plain text
// for the function_call_output item: text parts are concatenated in order and
// the image_url part is dropped.
func TestChatCompletionsToResponses_ToolArrayContent(t *testing.T) {
	req := &ChatCompletionsRequest{
		Model: "gpt-4o",
		Messages: []ChatMessage{
			{Role: "user", Content: json.RawMessage(`"Use the tool"`)},
			{
				Role: "assistant",
				ToolCalls: []ChatToolCall{
					{
						ID:   "call_1",
						Type: "function",
						Function: ChatFunctionCall{
							Name:      "inspect_image",
							Arguments: `{}`,
						},
					},
				},
			},
			{
				Role:       "tool",
				ToolCallID: "call_1",
				// Mixed parts: two text parts around one image part.
				Content: json.RawMessage(
					`[{"type":"text","text":"image width: 100"},{"type":"image_url","image_url":{"url":"data:image/png;base64,ignored"}},{"type":"text","text":"; image height: 200"}]`,
				),
			},
		},
	}
	resp, err := ChatCompletionsToResponses(req)
	require.NoError(t, err)
	var items []ResponsesInputItem
	require.NoError(t, json.Unmarshal(resp.Input, &items))
	require.Len(t, items, 3)
	// The tool result becomes the third input item.
	assert.Equal(t, "function_call_output", items[2].Type)
	assert.Equal(t, "call_1", items[2].CallID)
	// Text parts joined with no separator; the image part is ignored.
	assert.Equal(t, "image width: 100; image height: 200", items[2].Output)
}
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_inc",
......
......@@ -6,6 +6,11 @@ import (
"strings"
)
// chatMessageContent is the decoded form of a ChatMessage Content field.
// At most one side is set: Text when the content was a bare JSON string
// (nil-ness distinguishes "string content" from "parts content"), Parts when
// it was an array of typed content parts.
type chatMessageContent struct {
	Text  *string
	Parts []ChatContentPart
}
// ChatCompletionsToResponses converts a Chat Completions request into a
// Responses API request. The upstream always streams, so Stream is forced to
// true. store is always false and reasoning.encrypted_content is always
......@@ -113,11 +118,11 @@ func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) {
// chatSystemToResponses converts a system message.
func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
text, err := parseChatContent(m.Content)
parsed, err := parseChatMessageContent(m.Content)
if err != nil {
return nil, err
}
content, err := json.Marshal(text)
content, err := marshalChatInputContent(parsed)
if err != nil {
return nil, err
}
......@@ -127,39 +132,11 @@ func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
// chatUserToResponses converts a user message, handling both plain strings and
// multi-modal content arrays.
func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
// Try plain string first.
var s string
if err := json.Unmarshal(m.Content, &s); err == nil {
content, _ := json.Marshal(s)
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
}
var parts []ChatContentPart
if err := json.Unmarshal(m.Content, &parts); err != nil {
parsed, err := parseChatMessageContent(m.Content)
if err != nil {
return nil, fmt.Errorf("parse user content: %w", err)
}
var responseParts []ResponsesContentPart
for _, p := range parts {
switch p.Type {
case "text":
if p.Text != "" {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_text",
Text: p.Text,
})
}
case "image_url":
if p.ImageURL != nil && p.ImageURL.URL != "" {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_image",
ImageURL: p.ImageURL.URL,
})
}
}
}
content, err := json.Marshal(responseParts)
content, err := marshalChatInputContent(parsed)
if err != nil {
return nil, err
}
......@@ -312,16 +289,79 @@ func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
}
// parseChatContent returns the string value of a ChatMessage Content field.
// Content can be a JSON string or an array of typed parts. Array content is
// flattened to text by concatenating text parts and ignoring non-text parts.
// Returns "" if content is null or empty.
func parseChatContent(raw json.RawMessage) (string, error) {
	parsed, err := parseChatMessageContent(raw)
	if err != nil {
		return "", err
	}
	// String-form content is returned verbatim.
	if parsed.Text != nil {
		return *parsed.Text, nil
	}
	// Parts-form content is flattened to a single text string.
	return flattenChatContentParts(parsed.Parts), nil
}
func parseChatMessageContent(raw json.RawMessage) (chatMessageContent, error) {
if len(raw) == 0 {
return "", nil
return chatMessageContent{Text: stringPtr("")}, nil
}
var s string
if err := json.Unmarshal(raw, &s); err != nil {
return "", fmt.Errorf("parse content as string: %w", err)
if err := json.Unmarshal(raw, &s); err == nil {
return chatMessageContent{Text: &s}, nil
}
return s, nil
var parts []ChatContentPart
if err := json.Unmarshal(raw, &parts); err == nil {
return chatMessageContent{Parts: parts}, nil
}
return chatMessageContent{}, fmt.Errorf("parse content as string or parts array")
}
// marshalChatInputContent serializes parsed chat content back to JSON for a
// Responses input item: string content marshals as a JSON string, parts
// content marshals as an array of converted Responses content parts.
func marshalChatInputContent(content chatMessageContent) (json.RawMessage, error) {
	if text := content.Text; text != nil {
		return json.Marshal(*text)
	}
	converted := convertChatContentPartsToResponses(content.Parts)
	return json.Marshal(converted)
}
// convertChatContentPartsToResponses maps Chat Completions content parts to
// Responses API content parts: "text" → "input_text", "image_url" →
// "input_image". Empty parts and unrecognized types are skipped. A nil result
// is returned when nothing converts, preserving JSON null semantics for callers.
func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesContentPart {
	var converted []ResponsesContentPart
	for _, part := range parts {
		switch {
		case part.Type == "text" && part.Text != "":
			converted = append(converted, ResponsesContentPart{
				Type: "input_text",
				Text: part.Text,
			})
		case part.Type == "image_url" && part.ImageURL != nil && part.ImageURL.URL != "":
			converted = append(converted, ResponsesContentPart{
				Type:     "input_image",
				ImageURL: part.ImageURL.URL,
			})
		}
	}
	return converted
}
// flattenChatContentParts concatenates the text of all non-empty "text" parts
// with no separator; non-text parts are ignored.
func flattenChatContentParts(parts []ChatContentPart) string {
	var b strings.Builder
	for _, part := range parts {
		if part.Type != "text" || part.Text == "" {
			continue
		}
		b.WriteString(part.Text)
	}
	return b.String()
}
// stringPtr returns a pointer to a copy of s.
func stringPtr(s string) *string {
	v := s
	return &v
}
// convertChatToolsToResponses maps Chat Completions tool definitions and legacy
......
......@@ -28,50 +28,64 @@ import (
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
// 2. every INSERT/CTE VALUES column list in this file
// 3. execUsageLogInsertNoResult placeholder positions
// 4. scanUsageLog selected column order (via usageLogSelectColumns)
//
// When adding a usage_logs column, update all of those call sites together.
var usageLogInsertArgTypes = [...]string{
	"bigint",      // user_id
	"bigint",      // api_key_id
	"bigint",      // account_id
	"text",        // request_id
	"text",        // model
	"text",        // requested_model
	"text",        // upstream_model
	"bigint",      // group_id
	"bigint",      // subscription_id
	"integer",     // input_tokens
	"integer",     // output_tokens
	"integer",     // cache_creation_tokens
	"integer",     // cache_read_tokens
	"integer",     // cache_creation_5m_tokens
	"integer",     // cache_creation_1h_tokens
	"numeric",     // input_cost
	"numeric",     // output_cost
	"numeric",     // cache_creation_cost
	"numeric",     // cache_read_cost
	"numeric",     // total_cost
	"numeric",     // actual_cost
	"numeric",     // rate_multiplier
	"numeric",     // account_rate_multiplier
	"smallint",    // billing_type
	"smallint",    // request_type
	"boolean",     // stream
	"boolean",     // openai_ws_mode
	"integer",     // duration_ms
	"integer",     // first_token_ms
	"text",        // user_agent
	"text",        // ip_address
	"integer",     // image_count
	"text",        // image_size
	"text",        // media_type
	"text",        // service_tier
	"text",        // reasoning_effort
	"text",        // inbound_endpoint
	"text",        // upstream_endpoint
	"boolean",     // cache_ttl_overridden
	"timestamptz", // created_at
}
// rawUsageLogModelColumn preserves the exact stored usage_logs.model semantics for direct filters.
// Historical rows may contain upstream/billing model values, while newer rows store requested_model.
// Requested/upstream/mapping analytics must use resolveModelDimensionExpression instead.
const rawUsageLogModelColumn = "model"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{
"hour": "YYYY-MM-DD HH24:00",
......@@ -88,6 +102,30 @@ func safeDateFormat(granularity string) string {
return "YYYY-MM-DD"
}
// appendRawUsageLogModelWhereCondition keeps direct model filters on the raw
// model column for backward compatibility with historical rows.
// Requested/upstream analytics must use resolveModelDimensionExpression
// instead. A blank model leaves the condition and argument slices untouched.
func appendRawUsageLogModelWhereCondition(conditions []string, args []any, model string) ([]string, []any) {
	if strings.TrimSpace(model) == "" {
		return conditions, args
	}
	// Append the arg first so its 1-based position is simply len(args).
	args = append(args, model)
	conditions = append(conditions, fmt.Sprintf("%s = $%d", rawUsageLogModelColumn, len(args)))
	return conditions, args
}
// appendRawUsageLogModelQueryFilter keeps direct model filters on the raw
// model column for backward compatibility with historical rows.
// Requested/upstream analytics must use resolveModelDimensionExpression
// instead. A blank model returns the query and arguments unchanged.
func appendRawUsageLogModelQueryFilter(query string, args []any, model string) (string, []any) {
	if strings.TrimSpace(model) == "" {
		return query, args
	}
	// Append the arg first so its 1-based position is simply len(args).
	args = append(args, model)
	query += fmt.Sprintf(" AND %s = $%d", rawUsageLogModelColumn, len(args))
	return query, args
}
type usageLogRepository struct {
client *dbent.Client
sql sqlExecutor
......@@ -278,6 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
account_id,
request_id,
model,
requested_model,
upstream_model,
group_id,
subscription_id,
......@@ -313,12 +352,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6,
$7, $8,
$9, $10, $11, $12,
$13, $14,
$15, $16, $17, $18, $19, $20,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
$1, $2, $3, $4, $5, $6, $7,
$8, $9,
$10, $11, $12, $13,
$14, $15,
$16, $17, $18, $19, $20, $21,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
......@@ -709,6 +748,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
requested_model,
upstream_model,
group_id,
subscription_id,
......@@ -779,6 +819,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
requested_model,
upstream_model,
group_id,
subscription_id,
......@@ -820,6 +861,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
requested_model,
upstream_model,
group_id,
subscription_id,
......@@ -901,6 +943,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
requested_model,
upstream_model,
group_id,
subscription_id,
......@@ -937,7 +980,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*39)
args := make([]any, 0, len(preparedList)*40)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
......@@ -968,6 +1011,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
requested_model,
upstream_model,
group_id,
subscription_id,
......@@ -1009,6 +1053,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
requested_model,
upstream_model,
group_id,
subscription_id,
......@@ -1058,6 +1103,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
account_id,
request_id,
model,
requested_model,
upstream_model,
group_id,
subscription_id,
......@@ -1093,12 +1139,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6,
$7, $8,
$9, $10, $11, $12,
$13, $14,
$15, $16, $17, $18, $19, $20,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
$1, $2, $3, $4, $5, $6, $7,
$8, $9,
$10, $11, $12, $13,
$14, $15,
$16, $17, $18, $19, $20, $21,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
......@@ -1130,6 +1176,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint)
requestedModel := strings.TrimSpace(log.RequestedModel)
if requestedModel == "" {
requestedModel = strings.TrimSpace(log.Model)
}
upstreamModel := nullString(log.UpstreamModel)
var requestIDArg any
......@@ -1148,6 +1198,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.AccountID,
requestIDArg,
log.Model,
nullString(&requestedModel),
upstreamModel,
groupID,
subscriptionID,
......@@ -1702,7 +1753,7 @@ func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, acco
// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据
// 性能优化:数据库层聚合计算,避免应用层循环统计
func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
query := fmt.Sprintf(`
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
......@@ -1712,8 +1763,8 @@ func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelN
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE model = $1 AND created_at >= $2 AND created_at < $3
`
WHERE %s = $1 AND created_at >= $2 AND created_at < $3
`, rawUsageLogModelColumn)
var stats usagestats.UsageStats
if err := scanSingleRow(
......@@ -1837,7 +1888,7 @@ func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
}
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
query := fmt.Sprintf("SELECT %s FROM usage_logs WHERE %s = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000", usageLogSelectColumns, rawUsageLogModelColumn)
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
return logs, nil, err
}
......@@ -2532,10 +2583,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
args = append(args, filters.GroupID)
}
if filters.Model != "" {
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model)
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
......@@ -2768,10 +2816,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if model != "" {
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
query, args = appendRawUsageLogModelQueryFilter(query, args, model)
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
......@@ -3126,13 +3171,14 @@ func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayS
// resolveModelDimensionExpression maps model source type to a safe SQL expression.
// Every branch is built purely from fixed literals (never from the caller's
// string), so the result is safe to splice into SQL. The requested dimension
// falls back to the legacy model column for historical rows where
// requested_model is empty.
func resolveModelDimensionExpression(modelType string) string {
	requestedExpr := "COALESCE(NULLIF(TRIM(requested_model), ''), model)"
	switch usagestats.NormalizeModelSource(modelType) {
	case usagestats.ModelSourceUpstream:
		return fmt.Sprintf("COALESCE(NULLIF(TRIM(upstream_model), ''), %s)", requestedExpr)
	case usagestats.ModelSourceMapping:
		return fmt.Sprintf("(%s || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), %s))", requestedExpr, requestedExpr)
	default:
		// Requested source and any unknown/empty value resolve to requested_model.
		return requestedExpr
	}
}
......@@ -3204,10 +3250,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
args = append(args, filters.GroupID)
}
if filters.Model != "" {
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model)
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
......@@ -3336,10 +3379,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if model != "" {
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
query, args = appendRawUsageLogModelQueryFilter(query, args, model)
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
......@@ -3410,10 +3450,7 @@ func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if model != "" {
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
query, args = appendRawUsageLogModelQueryFilter(query, args, model)
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
......@@ -3888,6 +3925,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
accountID int64
requestID sql.NullString
model string
requestedModel sql.NullString
upstreamModel sql.NullString
groupID sql.NullInt64
subscriptionID sql.NullInt64
......@@ -3931,6 +3969,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&accountID,
&requestID,
&model,
&requestedModel,
&upstreamModel,
&groupID,
&subscriptionID,
......@@ -3975,6 +4014,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
APIKeyID: apiKeyID,
AccountID: accountID,
Model: model,
RequestedModel: coalesceTrimmedString(requestedModel, model),
InputTokens: inputTokens,
OutputTokens: outputTokens,
CacheCreationTokens: cacheCreationTokens,
......@@ -4181,6 +4221,13 @@ func nullString(v *string) sql.NullString {
return sql.NullString{String: *v, Valid: true}
}
func coalesceTrimmedString(v sql.NullString, fallback string) string {
if v.Valid && strings.TrimSpace(v.String) != "" {
return v.String
}
return fallback
}
func setToSlice(set map[int64]struct{}) []int64 {
out := make([]int64, 0, len(set))
for id := range set {
......
......@@ -34,11 +34,11 @@ func TestResolveModelDimensionExpression(t *testing.T) {
modelType string
want string
}{
{usagestats.ModelSourceRequested, "model"},
{usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"},
{usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"},
{"", "model"},
{"invalid", "model"},
{usagestats.ModelSourceRequested, "COALESCE(NULLIF(TRIM(requested_model), ''), model)"},
{usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), COALESCE(NULLIF(TRIM(requested_model), ''), model))"},
{usagestats.ModelSourceMapping, "(COALESCE(NULLIF(TRIM(requested_model), ''), model) || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), COALESCE(NULLIF(TRIM(requested_model), ''), model)))"},
{"", "COALESCE(NULLIF(TRIM(requested_model), ''), model)"},
{"invalid", "COALESCE(NULLIF(TRIM(requested_model), ''), model)"},
}
for _, tc := range tests {
......
......@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"testing"
......@@ -26,6 +27,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
AccountID: 3,
RequestID: "req-1",
Model: "gpt-5",
RequestedModel: "gpt-5",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1,
......@@ -44,6 +46,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.AccountID,
log.RequestID,
log.Model,
log.RequestedModel,
sqlmock.AnyArg(), // upstream_model
sqlmock.AnyArg(), // group_id
sqlmock.AnyArg(), // subscription_id
......@@ -104,6 +107,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
AccountID: 3,
RequestID: "req-service-tier",
Model: "gpt-5.4",
RequestedModel: "gpt-5.4",
ServiceTier: &serviceTier,
CreatedAt: createdAt,
}
......@@ -115,6 +119,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.AccountID,
log.RequestID,
log.Model,
log.RequestedModel,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
sqlmock.AnyArg(),
......@@ -158,6 +163,75 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
require.NoError(t, mock.ExpectationsWereMet())
}
// TestBuildUsageLogBestEffortInsertQuery_IncludesRequestedModelColumn verifies
// that the best-effort batch INSERT lists the requested_model column between
// model and upstream_model in the generated SQL, and that the prepared
// argument slice is passed through unchanged.
func TestBuildUsageLogBestEffortInsertQuery_IncludesRequestedModelColumn(t *testing.T) {
	prepared := prepareUsageLogInsert(&service.UsageLog{
		UserID:         1,
		APIKeyID:       2,
		AccountID:      3,
		RequestID:      "req-best-effort-query",
		Model:          "gpt-5",
		RequestedModel: "gpt-5",
		CreatedAt:      time.Date(2025, 1, 3, 12, 0, 0, 0, time.UTC),
	})
	query, args := buildUsageLogBestEffortInsertQuery([]usageLogInsertPrepared{prepared})
	require.Contains(t, query, "INSERT INTO usage_logs (")
	// Tab-indented column lists: requested_model must sit between model and upstream_model.
	require.Contains(t, query, "\n\t\t\tmodel,\n\t\t\trequested_model,\n\t\t\tupstream_model,")
	require.Contains(t, query, "\n\t\t\trequest_id,\n\t\t\tmodel,\n\t\t\trequested_model,\n\t\t\tupstream_model,")
	require.Len(t, args, len(prepared.args))
	// args[5] is requested_model (user_id, api_key_id, account_id, request_id, model precede it).
	require.Equal(t, prepared.args[5], args[5])
}
// TestExecUsageLogInsertNoResult_PersistsRequestedModel verifies that the
// no-result insert path forwards every prepared argument — including the
// requested_model value — to the INSERT statement, via sqlmock argument
// matching.
func TestExecUsageLogInsertNoResult_PersistsRequestedModel(t *testing.T) {
	db, mock := newSQLMock(t)
	prepared := prepareUsageLogInsert(&service.UsageLog{
		UserID:         1,
		APIKeyID:       2,
		AccountID:      3,
		RequestID:      "req-best-effort-exec",
		Model:          "gpt-5",
		RequestedModel: "gpt-5",
		CreatedAt:      time.Date(2025, 1, 4, 12, 0, 0, 0, time.UTC),
	})
	// Expect the exact prepared args, in order, as driver values.
	mock.ExpectExec("INSERT INTO usage_logs").
		WithArgs(anySliceToDriverValues(prepared.args)...).
		WillReturnResult(sqlmock.NewResult(0, 1))
	err := execUsageLogInsertNoResult(context.Background(), db, prepared)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestPrepareUsageLogInsert_ArgCountMatchesTypes guards the invariant that
// prepareUsageLogInsert emits exactly one argument per entry in
// usageLogInsertArgTypes, so the two stay in lockstep when columns are added.
func TestPrepareUsageLogInsert_ArgCountMatchesTypes(t *testing.T) {
	log := &service.UsageLog{
		UserID:         1,
		APIKeyID:       2,
		AccountID:      3,
		RequestID:      "req-arg-count",
		Model:          "gpt-5",
		RequestedModel: "gpt-5",
		CreatedAt:      time.Date(2025, 1, 5, 12, 0, 0, 0, time.UTC),
	}
	prepared := prepareUsageLogInsert(log)
	require.Len(t, prepared.args, len(usageLogInsertArgTypes))
}
// TestCoalesceTrimmedString covers the three cases: invalid → fallback,
// blank (whitespace-only) → fallback, and a real value → value.
func TestCoalesceTrimmedString(t *testing.T) {
	cases := []struct {
		in   sql.NullString
		want string
	}{
		{sql.NullString{}, "fallback"},
		{sql.NullString{Valid: true, String: " "}, "fallback"},
		{sql.NullString{Valid: true, String: "value"}, "value"},
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, coalesceTrimmedString(tc.in, "fallback"))
	}
}
// anySliceToDriverValues converts a []any argument list into the
// []driver.Value form that sqlmock's WithArgs expects, preserving order.
func anySliceToDriverValues(values []any) []driver.Value {
	out := make([]driver.Value, len(values))
	for i, v := range values {
		out[i] = v
	}
	return out
}
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
......@@ -355,6 +429,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(30), // account_id
sql.NullString{Valid: true, String: "req-1"},
"gpt-5", // model
sql.NullString{Valid: true, String: "gpt-5"}, // requested_model
sql.NullString{}, // upstream_model
sql.NullInt64{}, // group_id
sql.NullInt64{}, // subscription_id
......@@ -407,6 +482,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(31),
sql.NullString{Valid: true, String: "req-2"},
"gpt-5",
sql.NullString{Valid: true, String: "gpt-5"},
sql.NullString{},
sql.NullInt64{},
sql.NullInt64{},
......@@ -449,6 +525,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(32),
sql.NullString{Valid: true, String: "req-3"},
"gpt-5.4",
sql.NullString{Valid: true, String: "gpt-5.4"},
sql.NullString{},
sql.NullInt64{},
sql.NullInt64{},
......
......@@ -540,7 +540,8 @@ func TestAPIContracts(t *testing.T) {
"max_claude_code_version": "",
"allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"custom_menu_items": []
"custom_menu_items": [],
"custom_endpoints": []
}
}`,
},
......
......@@ -1543,6 +1543,24 @@ func isPeriodExpired(periodStart time.Time, dur time.Duration) bool {
return time.Since(periodStart) >= dur
}
// IsDailyQuotaPeriodExpired reports whether the daily quota period has lapsed
// (used by the display layer to decide whether "used" should render as zero).
// Fixed reset mode delegates to the fixed-window check; otherwise a rolling
// 24-hour window applies.
func (a *Account) IsDailyQuotaPeriodExpired() bool {
	start := a.getExtraTime("quota_daily_start")
	if a.GetQuotaDailyResetMode() != "fixed" {
		return isPeriodExpired(start, 24*time.Hour)
	}
	return a.isFixedDailyPeriodExpired(start)
}
// IsWeeklyQuotaPeriodExpired reports whether the weekly quota period has
// lapsed (used by the display layer to decide whether "used" should render as
// zero). Fixed reset mode delegates to the fixed-window check; otherwise a
// rolling 7-day window applies.
func (a *Account) IsWeeklyQuotaPeriodExpired() bool {
	start := a.getExtraTime("quota_weekly_start")
	if a.GetQuotaWeeklyResetMode() != "fixed" {
		return isPeriodExpired(start, 7*24*time.Hour)
	}
	return a.isFixedWeeklyPeriodExpired(start)
}
// IsQuotaExceeded 检查 API Key 账号配额是否已超限(任一维度超限即返回 true)
func (a *Account) IsQuotaExceeded() bool {
// 总额度
......
......@@ -1742,7 +1742,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: billingModel, // 使用映射模型用于计费和日志
Model: originalModel,
UpstreamModel: billingModel,
Stream: claudeReq.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
......@@ -2435,7 +2436,8 @@ handleSuccess:
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: billingModel,
Model: originalModel,
UpstreamModel: billingModel,
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
......
......@@ -542,7 +542,8 @@ func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
result, err := svc.Forward(context.Background(), c, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model)
require.Equal(t, "claude-sonnet-4-5", result.Model)
require.Equal(t, mappedModel, result.UpstreamModel)
}
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
......@@ -594,7 +595,8 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model)
require.Equal(t, "gemini-2.5-flash", result.Model)
require.Equal(t, mappedModel, result.UpstreamModel)
}
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
......@@ -664,7 +666,8 @@ func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignatur
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model)
require.Equal(t, originalModel, result.Model)
require.Equal(t, mappedModel, result.UpstreamModel)
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
firstReq := string(upstream.requestBodies[0])
......
......@@ -119,6 +119,7 @@ const (
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL(作为 iframe src)
SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组)
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组)
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
......
......@@ -162,6 +162,32 @@ func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
}
// TestGatewayServiceRecordUsage_PreservesRequestedAndUpstreamModels verifies
// that the usage log keeps the client-requested model in Model/RequestedModel
// while the mapped upstream model is persisted separately in UpstreamModel.
func TestGatewayServiceRecordUsage_PreservesRequestedAndUpstreamModels(t *testing.T) {
	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	service := newGatewayRecordUsageServiceForTest(logRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})

	const upstreamModel = "claude-sonnet-4-20250514"
	input := &RecordUsageInput{
		Result: &ForwardResult{
			RequestID:     "gateway_models_split",
			Usage:         ClaudeUsage{InputTokens: 10, OutputTokens: 6},
			Model:         "claude-sonnet-4",
			UpstreamModel: upstreamModel,
			Duration:      time.Second,
		},
		APIKey:  &APIKey{ID: 501, Quota: 100},
		User:    &User{ID: 601},
		Account: &Account{ID: 701},
	}
	require.NoError(t, service.RecordUsage(context.Background(), input))

	logged := logRepo.lastLog
	require.NotNil(t, logged)
	require.Equal(t, "claude-sonnet-4", logged.Model)
	require.Equal(t, "claude-sonnet-4", logged.RequestedModel)
	require.NotNil(t, logged.UpstreamModel)
	require.Equal(t, upstreamModel, *logged.UpstreamModel)
}
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
userRepo := &openAIRecordUsageUserRepoStub{}
......
......@@ -485,7 +485,9 @@ type ForwardResult struct {
RequestID string
Usage ClaudeUsage
Model string
UpstreamModel string // Actual upstream model after mapping (empty = no mapping)
// UpstreamModel is the actual upstream model after mapping.
// Prefer empty when it is identical to Model; persistence normalizes equal values away as no-op mappings.
UpstreamModel string
Stream bool
Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求)
......@@ -4197,7 +4199,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
logger.LegacyPrintf("service.gateway", "[warn] Account %d: thinking blocks have invalid signature, retrying with filtered blocks", account.ID)
// Conservative two-stage fallback:
// 1) Disable thinking + thinking->text (preserve content)
......@@ -4212,7 +4214,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil {
if retryResp.StatusCode < 400 {
logger.LegacyPrintf("service.gateway", "Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
logger.LegacyPrintf("service.gateway", "Account %d: thinking block retry succeeded (blocks downgraded)", account.ID)
resp = retryResp
break
}
......@@ -6102,13 +6104,9 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
return false
}
// Log for debugging
logger.LegacyPrintf("service.gateway", "[SignatureCheck] Checking error message: %s", msg)
// 检测signature相关的错误(更宽松的匹配)
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
if strings.Contains(msg, "signature") {
logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected signature error")
return true
}
......@@ -7516,6 +7514,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
var cost *CostBreakdown
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
// 根据请求类型选择计费方式
if result.MediaType == "image" || result.MediaType == "video" {
......@@ -7531,7 +7530,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.MediaType == "image" {
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
} else {
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier)
cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
}
} else if result.MediaType == "prompt" {
cost = &CostBreakdown{}
......@@ -7545,7 +7544,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
Price4K: apiKey.Group.ImagePrice4K,
}
}
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
} else {
// Token 计费
tokens := UsageTokens{
......@@ -7557,7 +7556,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0}
......@@ -7589,6 +7588,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
......@@ -7719,6 +7719,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
}
var cost *CostBreakdown
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
// 根据请求类型选择计费方式
if result.ImageCount > 0 {
......@@ -7731,7 +7732,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
Price4K: apiKey.Group.ImagePrice4K,
}
}
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
} else {
// Token 计费(使用长上下文计费方法)
tokens := UsageTokens{
......@@ -7743,7 +7744,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0}
......@@ -7771,6 +7772,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
......
......@@ -1031,6 +1031,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
RequestID: requestID,
Usage: *usage,
Model: originalModel,
UpstreamModel: mappedModel,
Stream: req.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
......@@ -1244,6 +1245,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
UpstreamModel: mappedModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
......@@ -1313,6 +1315,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
UpstreamModel: mappedModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
......@@ -1353,6 +1356,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
RequestID: requestID,
Usage: ClaudeUsage{},
Model: originalModel,
UpstreamModel: mappedModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
......@@ -1530,6 +1534,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
RequestID: requestID,
Usage: *usage,
Model: originalModel,
UpstreamModel: mappedModel,
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
......
package service
import (
"context"
"encoding/json"
"fmt"
"io"
......@@ -15,6 +16,30 @@ import (
"github.com/stretchr/testify/require"
)
// geminiCompatHTTPUpstreamStub is a test double for the HTTP upstream used by
// GeminiMessagesCompatService. It records how it was called and replays a
// canned response or error.
type geminiCompatHTTPUpstreamStub struct {
	response *http.Response // canned response; a shallow copy is returned per call
	err      error          // when set, Do fails with this error
	calls    int            // total number of Do invocations (DoWithTLS delegates to Do)
	lastReq  *http.Request  // most recent request passed to Do
}

// Do records the call and the request, then returns either the configured
// error or a shallow copy of the canned response so each caller gets an
// independent *http.Response value.
func (s *geminiCompatHTTPUpstreamStub) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
	s.calls++
	s.lastReq = req
	switch {
	case s.err != nil:
		return nil, s.err
	case s.response == nil:
		return nil, fmt.Errorf("missing stub response")
	}
	respCopy := *s.response
	return &respCopy, nil
}

// DoWithTLS delegates to Do, ignoring the TLS-fingerprint flag.
func (s *geminiCompatHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
	return s.Do(req, proxyURL, accountID, accountConcurrency)
}
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
tests := []struct {
......@@ -170,6 +195,42 @@ func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLo
require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志")
}
// TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpstreamModel
// verifies that account model mapping only affects the upstream URL and
// ForwardResult.UpstreamModel, while ForwardResult.Model keeps the model the
// client originally requested.
func TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpstreamModel(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	upstream := &geminiCompatHTTPUpstreamStub{
		response: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"x-request-id": []string{"gemini-req-1"}},
			Body:       io.NopCloser(strings.NewReader(`{"candidates":[{"content":{"parts":[{"text":"hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}`)),
		},
	}
	service := &GeminiMessagesCompatService{httpUpstream: upstream, cfg: &config.Config{}}
	acct := &Account{
		ID:   1,
		Type: AccountTypeAPIKey,
		Credentials: map[string]any{
			"api_key": "test-key",
			"model_mapping": map[string]any{
				"claude-sonnet-4": "claude-sonnet-4-20250514",
			},
		},
	}

	reqBody := []byte(`{"model":"claude-sonnet-4","max_tokens":16,"messages":[{"role":"user","content":"hello"}]}`)
	res, err := service.Forward(context.Background(), ginCtx, acct, reqBody)

	require.NoError(t, err)
	require.NotNil(t, res)
	require.Equal(t, "claude-sonnet-4", res.Model)
	require.Equal(t, "claude-sonnet-4-20250514", res.UpstreamModel)
	require.Equal(t, 1, upstream.calls)
	require.NotNil(t, upstream.lastReq)
	require.Contains(t, upstream.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:")
}
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
claudeReq := map[string]any{
"model": "claude-haiku-4-5-20251001",
......
......@@ -879,6 +879,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
require.Equal(t, "gpt-5.1", usageRepo.lastLog.RequestedModel)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel)
require.NotNil(t, usageRepo.lastLog.ServiceTier)
......@@ -894,6 +895,40 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
require.Equal(t, 1, userRepo.deductCalls)
}
// TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingUpstreamModelFallback
// verifies that when a request was mapped to a different upstream model, cost
// calculation falls back to the upstream model while the usage log still
// records the requested model, and the user balance is deducted by that cost.
func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingUpstreamModelFallback(t *testing.T) {
	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	users := &openAIRecordUsageUserRepoStub{}
	subs := &openAIRecordUsageSubRepoStub{}
	service := newOpenAIRecordUsageServiceForTest(logRepo, users, subs, nil)

	// Expected cost is computed against the mapped upstream model.
	wantCost, err := service.billingService.CalculateCost("gpt-5.1-codex", UsageTokens{
		InputTokens:  20,
		OutputTokens: 10,
	}, 1.1)
	require.NoError(t, err)

	err = service.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID:     "resp_upstream_model_billing_fallback",
			Model:         "gpt-5.1",
			UpstreamModel: "gpt-5.1-codex",
			Usage:         OpenAIUsage{InputTokens: 20, OutputTokens: 10},
			Duration:      time.Second,
		},
		APIKey:  &APIKey{ID: 10},
		User:    &User{ID: 20},
		Account: &Account{ID: 30},
	})
	require.NoError(t, err)

	require.NotNil(t, logRepo.lastLog)
	require.Equal(t, "gpt-5.1", logRepo.lastLog.Model)
	require.Equal(t, wantCost.ActualCost, logRepo.lastLog.ActualCost)
	require.Equal(t, wantCost.TotalCost, logRepo.lastLog.TotalCost)
	require.Equal(t, wantCost.ActualCost, users.lastAmount)
}
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
......
......@@ -4110,9 +4110,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
}
billingModel := result.Model
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if result.BillingModel != "" {
billingModel = result.BillingModel
billingModel = strings.TrimSpace(result.BillingModel)
}
serviceTier := ""
if result.ServiceTier != nil {
......@@ -4140,6 +4140,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort,
......
......@@ -68,3 +68,19 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
})
}
}
// TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51
// verifies that resolving a Claude model name on an account with no mapping
// normalizes to "gpt-5.1" when no default is set, and to the configured
// default model when one is provided.
func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *testing.T) {
	acct := &Account{Credentials: map[string]any{}}
	cases := []struct {
		defaultModel string
		want         string
	}{
		{defaultModel: "", want: "gpt-5.1"},
		{defaultModel: "gpt-5.4", want: "gpt-5.4"},
	}
	for _, tc := range cases {
		resolved := resolveOpenAIForwardModel(acct, "claude-opus-4-6", tc.defaultModel)
		if got := normalizeCodexModel(resolved); got != tc.want {
			t.Fatalf("normalizeCodexModel(%q) = %q, want %q", resolved, got, tc.want)
		}
	}
}
......@@ -2328,6 +2328,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
RequestID: responseID,
Usage: *usage,
Model: originalModel,
UpstreamModel: mappedModel,
ServiceTier: extractOpenAIServiceTier(reqBody),
ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
Stream: reqStream,
......@@ -2945,6 +2946,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
RequestID: responseID,
Usage: usage,
Model: originalModel,
UpstreamModel: mappedModel,
ServiceTier: extractOpenAIServiceTierFromBody(payload),
ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
Stream: reqStream,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment