Unverified Commit 241023f3 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1097 from Ethan0x0000/pr/upstream-model-tracking

feat(usage): 新增 upstream_model 追踪,支持按模型来源统计与展示
parents 1292c44b cfaac12a
...@@ -5,6 +5,7 @@ package repository ...@@ -5,6 +5,7 @@ package repository
import ( import (
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -27,3 +28,23 @@ func TestResolveEndpointColumn(t *testing.T) { ...@@ -27,3 +28,23 @@ func TestResolveEndpointColumn(t *testing.T) {
}) })
} }
} }
// TestResolveModelDimensionExpression verifies that each recognized model
// source maps to its SQL dimension expression and that empty or unknown
// sources fall back to the plain "model" column.
func TestResolveModelDimensionExpression(t *testing.T) {
	cases := []struct {
		source string
		expr   string
	}{
		{source: usagestats.ModelSourceRequested, expr: "model"},
		{source: usagestats.ModelSourceUpstream, expr: "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"},
		{source: usagestats.ModelSourceMapping, expr: "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"},
		{source: "", expr: "model"},
		{source: "invalid", expr: "model"},
	}
	for _, tc := range cases {
		t.Run(tc.source, func(t *testing.T) {
			require.Equal(t, tc.expr, resolveModelDimensionExpression(tc.source))
		})
	}
}
...@@ -44,6 +44,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { ...@@ -44,6 +44,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.AccountID, log.AccountID,
log.RequestID, log.RequestID,
log.Model, log.Model,
sqlmock.AnyArg(), // upstream_model
sqlmock.AnyArg(), // group_id sqlmock.AnyArg(), // group_id
sqlmock.AnyArg(), // subscription_id sqlmock.AnyArg(), // subscription_id
log.InputTokens, log.InputTokens,
...@@ -116,6 +117,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { ...@@ -116,6 +117,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.Model, log.Model,
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.InputTokens, log.InputTokens,
log.OutputTokens, log.OutputTokens,
log.CacheCreationTokens, log.CacheCreationTokens,
...@@ -353,6 +355,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -353,6 +355,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(30), // account_id int64(30), // account_id
sql.NullString{Valid: true, String: "req-1"}, sql.NullString{Valid: true, String: "req-1"},
"gpt-5", // model "gpt-5", // model
sql.NullString{}, // upstream_model
sql.NullInt64{}, // group_id sql.NullInt64{}, // group_id
sql.NullInt64{}, // subscription_id sql.NullInt64{}, // subscription_id
1, // input_tokens 1, // input_tokens
...@@ -404,6 +407,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -404,6 +407,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(31), int64(31),
sql.NullString{Valid: true, String: "req-2"}, sql.NullString{Valid: true, String: "req-2"},
"gpt-5", "gpt-5",
sql.NullString{},
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
...@@ -445,6 +449,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -445,6 +449,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(32), int64(32),
sql.NullString{Valid: true, String: "req-3"}, sql.NullString{Valid: true, String: "req-3"},
"gpt-5.4", "gpt-5.4",
sql.NullString{},
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
......
...@@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi ...@@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil return stats, nil
} }
// GetModelStatsWithFiltersBySource returns model usage statistics aggregated
// by the requested model dimension (requested model, upstream model, or the
// request->upstream mapping). It falls back to GetModelStatsWithFilters when
// the normalized source is the default "requested" dimension or when the
// underlying repository does not support source-aware aggregation.
func (s *DashboardService) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, modelSource string) ([]usagestats.ModelStat, error) {
	source := usagestats.NormalizeModelSource(modelSource)
	if source == usagestats.ModelSourceRequested {
		return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
	}
	// Optional capability: not every usage repository implementation can
	// aggregate by the upstream/mapping dimensions, so probe for it.
	type sourceAwareRepo interface {
		GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) ([]usagestats.ModelStat, error)
	}
	repo, ok := s.usageRepo.(sourceAwareRepo)
	if !ok {
		return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
	}
	stats, err := repo.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, source)
	if err != nil {
		return nil, fmt.Errorf("get model stats with filters by source: %w", err)
	}
	return stats, nil
}
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil { if err != nil {
......
...@@ -788,7 +788,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuc ...@@ -788,7 +788,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuc
rateLimitService: &RateLimitService{}, rateLimitService: &RateLimitService{},
} }
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now())
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, 12, result.Usage.InputTokens) require.Equal(t, 12, result.Usage.InputTokens)
...@@ -815,7 +815,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenTyp ...@@ -815,7 +815,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenTyp
} }
svc := &GatewayService{} svc := &GatewayService{}
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now())
require.Nil(t, result) require.Nil(t, result)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "requires apikey token") require.Contains(t, err.Error(), "requires apikey token")
...@@ -840,7 +840,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest ...@@ -840,7 +840,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest
} }
account := newAnthropicAPIKeyAccountForTest() account := newAnthropicAPIKeyAccountForTest()
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", "x", false, time.Now())
require.Nil(t, result) require.Nil(t, result)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "upstream request failed") require.Contains(t, err.Error(), "upstream request failed")
...@@ -873,7 +873,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBo ...@@ -873,7 +873,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBo
httpUpstream: upstream, httpUpstream: upstream,
} }
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", "x", false, time.Now())
require.Nil(t, result) require.Nil(t, result)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "empty response") require.Contains(t, err.Error(), "empty response")
......
...@@ -490,6 +490,7 @@ type ForwardResult struct { ...@@ -490,6 +490,7 @@ type ForwardResult struct {
RequestID string RequestID string
Usage ClaudeUsage Usage ClaudeUsage
Model string Model string
UpstreamModel string // Actual upstream model after mapping (empty = no mapping)
Stream bool Stream bool
Duration time.Duration Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求) FirstTokenMs *int // 首字时间(流式请求)
...@@ -3988,7 +3989,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -3988,7 +3989,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
passthroughModel = mappedModel passthroughModel = mappedModel
} }
} }
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime) return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{
Body: passthroughBody,
RequestModel: passthroughModel,
OriginalModel: parsed.Model,
RequestStream: parsed.Stream,
StartTime: startTime,
})
} }
if account != nil && account.IsBedrock() { if account != nil && account.IsBedrock() {
...@@ -4512,6 +4519,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4512,6 +4519,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志 Model: originalModel, // 使用原始模型用于计费和日志
UpstreamModel: mappedModel,
Stream: reqStream, Stream: reqStream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
...@@ -4519,14 +4527,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4519,14 +4527,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}, nil }, nil
} }
// anthropicPassthroughForwardInput bundles the arguments for the Anthropic
// API key passthrough forwarding path, so request/original model names can
// be carried separately for billing vs. upstream tracking.
type anthropicPassthroughForwardInput struct {
	Body          []byte    // raw request body forwarded upstream
	RequestModel  string    // model actually sent upstream (after any mapping)
	OriginalModel string    // model as requested by the client (used for billing/logs)
	RequestStream bool      // whether the client requested a streaming response
	StartTime     time.Time // request start time, used for duration accounting
}
func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
ctx context.Context, ctx context.Context,
c *gin.Context, c *gin.Context,
account *Account, account *Account,
body []byte, body []byte,
reqModel string, reqModel string,
originalModel string,
reqStream bool, reqStream bool,
startTime time.Time, startTime time.Time,
) (*ForwardResult, error) {
return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{
Body: body,
RequestModel: reqModel,
OriginalModel: originalModel,
RequestStream: reqStream,
StartTime: startTime,
})
}
func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
ctx context.Context,
c *gin.Context,
account *Account,
input anthropicPassthroughForwardInput,
) (*ForwardResult, error) { ) (*ForwardResult, error) {
token, tokenType, err := s.GetAccessToken(ctx, account) token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil { if err != nil {
...@@ -4542,19 +4574,19 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( ...@@ -4542,19 +4574,19 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
} }
logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v", logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v",
account.ID, account.Name, reqModel, reqStream) account.ID, account.Name, input.RequestModel, input.RequestStream)
if c != nil { if c != nil {
c.Set("anthropic_passthrough", true) c.Set("anthropic_passthrough", true)
} }
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
setOpsUpstreamRequestBody(c, body) setOpsUpstreamRequestBody(c, input.Body)
var resp *http.Response var resp *http.Response
retryStart := time.Now() retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ { for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, input.RequestStream)
upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token) upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, input.Body, token)
releaseUpstreamCtx() releaseUpstreamCtx()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -4712,8 +4744,8 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( ...@@ -4712,8 +4744,8 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
var usage *ClaudeUsage var usage *ClaudeUsage
var firstTokenMs *int var firstTokenMs *int
var clientDisconnect bool var clientDisconnect bool
if reqStream { if input.RequestStream {
streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime, reqModel) streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, input.StartTime, input.RequestModel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -4733,9 +4765,10 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( ...@@ -4733,9 +4765,10 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
return &ForwardResult{ return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: reqModel, Model: input.OriginalModel,
Stream: reqStream, UpstreamModel: input.RequestModel,
Duration: time.Since(startTime), Stream: input.RequestStream,
Duration: time.Since(input.StartTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect, ClientDisconnect: clientDisconnect,
}, nil }, nil
...@@ -5240,6 +5273,7 @@ func (s *GatewayService) forwardBedrock( ...@@ -5240,6 +5273,7 @@ func (s *GatewayService) forwardBedrock(
RequestID: resp.Header.Get("x-amzn-requestid"), RequestID: resp.Header.Get("x-amzn-requestid"),
Usage: *usage, Usage: *usage,
Model: reqModel, Model: reqModel,
UpstreamModel: mappedModel,
Stream: reqStream, Stream: reqStream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
...@@ -7530,6 +7564,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7530,6 +7564,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
...@@ -7711,6 +7746,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -7711,6 +7746,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
......
...@@ -281,6 +281,7 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( ...@@ -281,6 +281,7 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: mappedModel,
UpstreamModel: mappedModel,
Stream: false, Stream: false,
Duration: time.Since(startTime), Duration: time.Since(startTime),
}, nil }, nil
...@@ -328,6 +329,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( ...@@ -328,6 +329,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: mappedModel,
UpstreamModel: mappedModel,
Stream: true, Stream: true,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
......
...@@ -303,6 +303,7 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( ...@@ -303,6 +303,7 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: mappedModel,
UpstreamModel: mappedModel,
Stream: false, Stream: false,
Duration: time.Since(startTime), Duration: time.Since(startTime),
}, nil }, nil
...@@ -351,6 +352,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( ...@@ -351,6 +352,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: mappedModel,
UpstreamModel: mappedModel,
Stream: true, Stream: true,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
......
...@@ -846,7 +846,7 @@ func TestExtractOpenAIServiceTierFromBody(t *testing.T) { ...@@ -846,7 +846,7 @@ func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
require.Nil(t, extractOpenAIServiceTierFromBody(nil)) require.Nil(t, extractOpenAIServiceTierFromBody(nil))
} }
func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) { func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetadataFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{} userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{} subRepo := &openAIRecordUsageSubRepoStub{}
...@@ -859,6 +859,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te ...@@ -859,6 +859,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te
RequestID: "resp_billing_model_override", RequestID: "resp_billing_model_override",
BillingModel: "gpt-5.1-codex", BillingModel: "gpt-5.1-codex",
Model: "gpt-5.1", Model: "gpt-5.1",
UpstreamModel: "gpt-5.1-codex",
ServiceTier: &serviceTier, ServiceTier: &serviceTier,
ReasoningEffort: &reasoning, ReasoningEffort: &reasoning,
Usage: OpenAIUsage{ Usage: OpenAIUsage{
...@@ -877,7 +878,9 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te ...@@ -877,7 +878,9 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog) require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model) require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel)
require.NotNil(t, usageRepo.lastLog.ServiceTier) require.NotNil(t, usageRepo.lastLog.ServiceTier)
require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier) require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier)
require.NotNil(t, usageRepo.lastLog.ReasoningEffort) require.NotNil(t, usageRepo.lastLog.ReasoningEffort)
......
...@@ -216,6 +216,9 @@ type OpenAIForwardResult struct { ...@@ -216,6 +216,9 @@ type OpenAIForwardResult struct {
// This is set by the Anthropic Messages conversion path where // This is set by the Anthropic Messages conversion path where
// the mapped upstream model differs from the client-facing model. // the mapped upstream model differs from the client-facing model.
BillingModel string BillingModel string
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Empty when no mapping was applied (requested model was used as-is).
UpstreamModel string
// ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex". // ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex".
// Nil means the request did not specify a recognized tier. // Nil means the request did not specify a recognized tier.
ServiceTier *string ServiceTier *string
...@@ -2128,6 +2131,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -2128,6 +2131,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
firstTokenMs, firstTokenMs,
wsAttempts, wsAttempts,
) )
wsResult.UpstreamModel = mappedModel
return wsResult, nil return wsResult, nil
} }
s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
...@@ -2263,6 +2267,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -2263,6 +2267,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: originalModel, Model: originalModel,
UpstreamModel: mappedModel,
ServiceTier: serviceTier, ServiceTier: serviceTier,
ReasoningEffort: reasoningEffort, ReasoningEffort: reasoningEffort,
Stream: reqStream, Stream: reqStream,
...@@ -4134,7 +4139,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -4134,7 +4139,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: billingModel, Model: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ServiceTier: result.ServiceTier, ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
...@@ -4700,11 +4706,3 @@ func normalizeOpenAIReasoningEffort(raw string) string { ...@@ -4700,11 +4706,3 @@ func normalizeOpenAIReasoningEffort(raw string) string {
return "" return ""
} }
} }
func optionalTrimmedStringPtr(raw string) *string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return nil
}
return &trimmed
}
...@@ -98,6 +98,9 @@ type UsageLog struct { ...@@ -98,6 +98,9 @@ type UsageLog struct {
AccountID int64 AccountID int64
RequestID string RequestID string
Model string Model string
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Nil means no mapping was applied (requested model was used as-is).
UpstreamModel *string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string ServiceTier *string
// ReasoningEffort is the request's reasoning effort level. // ReasoningEffort is the request's reasoning effort level.
......
package service
import "strings"
// optionalTrimmedStringPtr trims surrounding whitespace from raw and returns
// a pointer to the result, or nil when nothing remains. This lets optional
// string columns be persisted as NULL instead of empty strings.
func optionalTrimmedStringPtr(raw string) *string {
	if s := strings.TrimSpace(raw); s != "" {
		return &s
	}
	return nil
}
// optionalNonEqualStringPtr returns a pointer to the trimmed value when it is
// non-empty and differs from compare; otherwise nil. Used to store
// upstream_model only when it differs from the requested model, so the common
// "no mapping" case is persisted as NULL.
//
// Trimming mirrors optionalTrimmedStringPtr and the SQL aggregation side
// (which applies TRIM(upstream_model)), so whitespace-only or padded values
// never leak into the column.
func optionalNonEqualStringPtr(value, compare string) *string {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" || trimmed == compare {
		return nil
	}
	return &trimmed
}
-- Add upstream_model field to usage_logs.
-- Stores the actual upstream model name when it differs from the requested model
-- (i.e., when model mapping is applied). NULL means no mapping was applied.
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100);
-- Support upstream_model / mapping model distribution aggregations with time-range filters.
-- NOTE(review): CREATE INDEX CONCURRENTLY cannot run inside a transaction;
-- per the migrations README, CONCURRENTLY statements must live in a file with
-- the *_notx.sql suffix so the runner executes them without a transaction —
-- confirm this migration's filename uses that suffix.
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_logs_created_model_upstream_model
    ON usage_logs (created_at, model, upstream_model);
...@@ -34,18 +34,18 @@ Example: `017_add_gemini_tier_id.sql` ...@@ -34,18 +34,18 @@ Example: `017_add_gemini_tier_id.sql`
## Migration File Structure ## Migration File Structure
This project uses a custom migration runner (`internal/repository/migrations_runner.go`) that executes the full SQL file content as-is.
- Regular migrations (`*.sql`): executed in a transaction.
- Non-transactional migrations (`*_notx.sql`): split by statement and executed without transaction (for `CONCURRENTLY`).
```sql ```sql
-- +goose Up -- Forward-only migration (recommended)
-- +goose StatementBegin ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS example_column VARCHAR(100);
-- Your forward migration SQL here
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
-- Your rollback migration SQL here
-- +goose StatementEnd
``` ```
> ⚠️ Do **not** place executable "Down" SQL in the same file. The runner does not parse goose Up/Down sections and will execute all SQL statements in the file.
## Important Rules ## Important Rules
### ⚠️ Immutability Principle ### ⚠️ Immutability Principle
...@@ -66,9 +66,9 @@ Why? ...@@ -66,9 +66,9 @@ Why?
touch migrations/018_your_change.sql touch migrations/018_your_change.sql
``` ```
2. **Write Up and Down migrations** 2. **Write forward-only migration SQL**
- Up: Apply the change - Put only the intended schema change in the file
- Down: Revert the change (should be symmetric with Up) - If rollback is needed, create a new migration file to revert
3. **Test locally** 3. **Test locally**
```bash ```bash
...@@ -144,8 +144,6 @@ touch migrations/018_your_new_change.sql ...@@ -144,8 +144,6 @@ touch migrations/018_your_new_change.sql
## Example Migration ## Example Migration
```sql ```sql
-- +goose Up
-- +goose StatementBegin
-- Add tier_id field to Gemini OAuth accounts for quota tracking -- Add tier_id field to Gemini OAuth accounts for quota tracking
UPDATE accounts UPDATE accounts
SET credentials = jsonb_set( SET credentials = jsonb_set(
...@@ -157,17 +155,6 @@ SET credentials = jsonb_set( ...@@ -157,17 +155,6 @@ SET credentials = jsonb_set(
WHERE platform = 'gemini' WHERE platform = 'gemini'
AND type = 'oauth' AND type = 'oauth'
AND credentials->>'tier_id' IS NULL; AND credentials->>'tier_id' IS NULL;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
-- Remove tier_id field
UPDATE accounts
SET credentials = credentials - 'tier_id'
WHERE platform = 'gemini'
AND type = 'oauth'
AND credentials->>'tier_id' = 'LEGACY';
-- +goose StatementEnd
``` ```
## Troubleshooting ## Troubleshooting
...@@ -194,5 +181,4 @@ VALUES ('NNN_migration.sql', 'calculated_checksum', NOW()); ...@@ -194,5 +181,4 @@ VALUES ('NNN_migration.sql', 'calculated_checksum', NOW());
## References ## References
- Migration runner: `internal/repository/migrations_runner.go` - Migration runner: `internal/repository/migrations_runner.go`
- Goose syntax: https://github.com/pressly/goose
- PostgreSQL docs: https://www.postgresql.org/docs/ - PostgreSQL docs: https://www.postgresql.org/docs/
...@@ -81,6 +81,7 @@ export interface ModelStatsParams { ...@@ -81,6 +81,7 @@ export interface ModelStatsParams {
user_id?: number user_id?: number
api_key_id?: number api_key_id?: number
model?: string model?: string
model_source?: 'requested' | 'upstream' | 'mapping'
account_id?: number account_id?: number
group_id?: number group_id?: number
request_type?: UsageRequestType request_type?: UsageRequestType
...@@ -162,6 +163,7 @@ export interface UserBreakdownParams { ...@@ -162,6 +163,7 @@ export interface UserBreakdownParams {
end_date?: string end_date?: string
group_id?: number group_id?: number
model?: string model?: string
model_source?: 'requested' | 'upstream' | 'mapping'
endpoint?: string endpoint?: string
endpoint_type?: 'inbound' | 'upstream' | 'path' endpoint_type?: 'inbound' | 'upstream' | 'path'
limit?: number limit?: number
......
...@@ -25,8 +25,16 @@ ...@@ -25,8 +25,16 @@
<span class="text-sm text-gray-900 dark:text-white">{{ row.account?.name || '-' }}</span> <span class="text-sm text-gray-900 dark:text-white">{{ row.account?.name || '-' }}</span>
</template> </template>
<template #cell-model="{ value }"> <template #cell-model="{ row }">
<span class="font-medium text-gray-900 dark:text-white">{{ value }}</span> <div v-if="row.upstream_model && row.upstream_model !== row.model" class="space-y-0.5 text-xs">
<div class="break-all font-medium text-gray-900 dark:text-white">
{{ row.model }}
</div>
<div class="break-all text-gray-500 dark:text-gray-400">
<span class="mr-0.5"></span>{{ row.upstream_model }}
</div>
</div>
<span v-else class="font-medium text-gray-900 dark:text-white">{{ row.model }}</span>
</template> </template>
<template #cell-reasoning_effort="{ row }"> <template #cell-reasoning_effort="{ row }">
......
<template> <template>
<div class="card p-4"> <div class="card p-4">
<div class="mb-4 flex items-start justify-between gap-3"> <div class="mb-4 flex items-center justify-between gap-3">
<h3 class="text-sm font-semibold text-gray-900 dark:text-white"> <h3 class="text-sm font-semibold text-gray-900 dark:text-white">
{{ title || t('usage.endpointDistribution') }} {{ title || t('usage.endpointDistribution') }}
</h3> </h3>
<div class="flex flex-col items-end gap-2"> <div class="flex flex-wrap items-center justify-end gap-2">
<div <div
v-if="showSourceToggle" v-if="showSourceToggle"
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800" class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
......
...@@ -6,7 +6,42 @@ ...@@ -6,7 +6,42 @@
? t('admin.dashboard.modelDistribution') ? t('admin.dashboard.modelDistribution')
: t('admin.dashboard.spendingRankingTitle') }} : t('admin.dashboard.spendingRankingTitle') }}
</h3> </h3>
<div class="flex items-center gap-2"> <div class="flex flex-wrap items-center justify-end gap-2">
<div
v-if="showSourceToggle"
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
>
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="source === 'requested'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
@click="emit('update:source', 'requested')"
>
{{ t('usage.requestedModel') }}
</button>
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="source === 'upstream'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
@click="emit('update:source', 'upstream')"
>
{{ t('usage.upstreamModel') }}
</button>
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="source === 'mapping'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
@click="emit('update:source', 'mapping')"
>
{{ t('usage.mapping') }}
</button>
</div>
<div <div
v-if="showMetricToggle" v-if="showMetricToggle"
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800" class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
...@@ -215,9 +250,13 @@ ChartJS.register(ArcElement, Tooltip, Legend) ...@@ -215,9 +250,13 @@ ChartJS.register(ArcElement, Tooltip, Legend)
const { t } = useI18n() const { t } = useI18n()
type DistributionMetric = 'tokens' | 'actual_cost' type DistributionMetric = 'tokens' | 'actual_cost'
type ModelSource = 'requested' | 'upstream' | 'mapping'
type RankingDisplayItem = UserSpendingRankingItem & { isOther?: boolean } type RankingDisplayItem = UserSpendingRankingItem & { isOther?: boolean }
const props = withDefaults(defineProps<{ const props = withDefaults(defineProps<{
modelStats: ModelStat[] modelStats: ModelStat[]
upstreamModelStats?: ModelStat[]
mappingModelStats?: ModelStat[]
source?: ModelSource
enableRankingView?: boolean enableRankingView?: boolean
rankingItems?: UserSpendingRankingItem[] rankingItems?: UserSpendingRankingItem[]
rankingTotalActualCost?: number rankingTotalActualCost?: number
...@@ -225,12 +264,16 @@ const props = withDefaults(defineProps<{ ...@@ -225,12 +264,16 @@ const props = withDefaults(defineProps<{
rankingTotalTokens?: number rankingTotalTokens?: number
loading?: boolean loading?: boolean
metric?: DistributionMetric metric?: DistributionMetric
showSourceToggle?: boolean
showMetricToggle?: boolean showMetricToggle?: boolean
rankingLoading?: boolean rankingLoading?: boolean
rankingError?: boolean rankingError?: boolean
startDate?: string startDate?: string
endDate?: string endDate?: string
}>(), { }>(), {
upstreamModelStats: () => [],
mappingModelStats: () => [],
source: 'requested',
enableRankingView: false, enableRankingView: false,
rankingItems: () => [], rankingItems: () => [],
rankingTotalActualCost: 0, rankingTotalActualCost: 0,
...@@ -238,6 +281,7 @@ const props = withDefaults(defineProps<{ ...@@ -238,6 +281,7 @@ const props = withDefaults(defineProps<{
rankingTotalTokens: 0, rankingTotalTokens: 0,
loading: false, loading: false,
metric: 'tokens', metric: 'tokens',
showSourceToggle: false,
showMetricToggle: false, showMetricToggle: false,
rankingLoading: false, rankingLoading: false,
rankingError: false rankingError: false
...@@ -261,6 +305,7 @@ const toggleBreakdown = async (type: string, id: string) => { ...@@ -261,6 +305,7 @@ const toggleBreakdown = async (type: string, id: string) => {
start_date: props.startDate, start_date: props.startDate,
end_date: props.endDate, end_date: props.endDate,
model: id, model: id,
model_source: props.source,
}) })
breakdownItems.value = res.users || [] breakdownItems.value = res.users || []
} catch { } catch {
...@@ -272,6 +317,7 @@ const toggleBreakdown = async (type: string, id: string) => { ...@@ -272,6 +317,7 @@ const toggleBreakdown = async (type: string, id: string) => {
const emit = defineEmits<{ const emit = defineEmits<{
'update:metric': [value: DistributionMetric] 'update:metric': [value: DistributionMetric]
'update:source': [value: ModelSource]
'ranking-click': [item: UserSpendingRankingItem] 'ranking-click': [item: UserSpendingRankingItem]
}>() }>()
...@@ -294,14 +340,19 @@ const chartColors = [ ...@@ -294,14 +340,19 @@ const chartColors = [
] ]
const displayModelStats = computed(() => { const displayModelStats = computed(() => {
if (!props.modelStats?.length) return [] const sourceStats = props.source === 'upstream'
? props.upstreamModelStats
: props.source === 'mapping'
? props.mappingModelStats
: props.modelStats
if (!sourceStats?.length) return []
const metricKey = props.metric === 'actual_cost' ? 'actual_cost' : 'total_tokens' const metricKey = props.metric === 'actual_cost' ? 'actual_cost' : 'total_tokens'
return [...props.modelStats].sort((a, b) => b[metricKey] - a[metricKey]) return [...sourceStats].sort((a, b) => b[metricKey] - a[metricKey])
}) })
const chartData = computed(() => { const chartData = computed(() => {
if (!props.modelStats?.length) return null if (!displayModelStats.value.length) return null
return { return {
labels: displayModelStats.value.map((m) => m.model), labels: displayModelStats.value.map((m) => m.model),
......
...@@ -718,11 +718,14 @@ export default { ...@@ -718,11 +718,14 @@ export default {
exporting: 'Exporting...', exporting: 'Exporting...',
preparingExport: 'Preparing export...', preparingExport: 'Preparing export...',
model: 'Model', model: 'Model',
requestedModel: 'Requested',
upstreamModel: 'Upstream',
reasoningEffort: 'Reasoning Effort', reasoningEffort: 'Reasoning Effort',
endpoint: 'Endpoint', endpoint: 'Endpoint',
endpointDistribution: 'Endpoint Distribution', endpointDistribution: 'Endpoint Distribution',
inbound: 'Inbound', inbound: 'Inbound',
upstream: 'Upstream', upstream: 'Upstream',
mapping: 'Mapping',
path: 'Path', path: 'Path',
inboundEndpoint: 'Inbound Endpoint', inboundEndpoint: 'Inbound Endpoint',
upstreamEndpoint: 'Upstream Endpoint', upstreamEndpoint: 'Upstream Endpoint',
......
...@@ -723,11 +723,14 @@ export default { ...@@ -723,11 +723,14 @@ export default {
exporting: '导出中...', exporting: '导出中...',
preparingExport: '正在准备导出...', preparingExport: '正在准备导出...',
model: '模型', model: '模型',
requestedModel: '请求',
upstreamModel: '上游',
reasoningEffort: '推理强度', reasoningEffort: '推理强度',
endpoint: '端点', endpoint: '端点',
endpointDistribution: '端点分布', endpointDistribution: '端点分布',
inbound: '入站', inbound: '入站',
upstream: '上游', upstream: '上游',
mapping: '映射',
path: '路径', path: '路径',
inboundEndpoint: '入站端点', inboundEndpoint: '入站端点',
upstreamEndpoint: '上游端点', upstreamEndpoint: '上游端点',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment