Unverified commit 3084330d, authored by Wesley Liddick, committed by GitHub
Browse files

Merge pull request #1019 from Ethan0x0000/feat/usage-endpoint-distribution

feat: add endpoint metadata and usage endpoint distribution insights
parents b566649e cf924775
...@@ -523,6 +523,8 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { ...@@ -523,6 +523,8 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
Model: l.Model, Model: l.Model,
ServiceTier: l.ServiceTier, ServiceTier: l.ServiceTier,
ReasoningEffort: l.ReasoningEffort, ReasoningEffort: l.ReasoningEffort,
InboundEndpoint: l.InboundEndpoint,
UpstreamEndpoint: l.UpstreamEndpoint,
GroupID: l.GroupID, GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID, SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens, InputTokens: l.InputTokens,
......
...@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) { ...@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
t.Parallel() t.Parallel()
serviceTier := "priority" serviceTier := "priority"
inboundEndpoint := "/v1/chat/completions"
upstreamEndpoint := "/v1/responses"
log := &service.UsageLog{ log := &service.UsageLog{
RequestID: "req_3", RequestID: "req_3",
Model: "gpt-5.4", Model: "gpt-5.4",
ServiceTier: &serviceTier, ServiceTier: &serviceTier,
InboundEndpoint: &inboundEndpoint,
UpstreamEndpoint: &upstreamEndpoint,
AccountRateMultiplier: f64Ptr(1.5), AccountRateMultiplier: f64Ptr(1.5),
} }
...@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) { ...@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
require.NotNil(t, userDTO.ServiceTier) require.NotNil(t, userDTO.ServiceTier)
require.Equal(t, serviceTier, *userDTO.ServiceTier) require.Equal(t, serviceTier, *userDTO.ServiceTier)
require.NotNil(t, userDTO.InboundEndpoint)
require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint)
require.NotNil(t, userDTO.UpstreamEndpoint)
require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint)
require.NotNil(t, adminDTO.ServiceTier) require.NotNil(t, adminDTO.ServiceTier)
require.Equal(t, serviceTier, *adminDTO.ServiceTier) require.Equal(t, serviceTier, *adminDTO.ServiceTier)
require.NotNil(t, adminDTO.InboundEndpoint)
require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint)
require.NotNil(t, adminDTO.UpstreamEndpoint)
require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint)
require.NotNil(t, adminDTO.AccountRateMultiplier) require.NotNil(t, adminDTO.AccountRateMultiplier)
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12) require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
} }
......
...@@ -337,6 +337,10 @@ type UsageLog struct { ...@@ -337,6 +337,10 @@ type UsageLog struct {
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API). // ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
// nil means not provided / not applicable. // nil means not provided / not applicable.
ReasoningEffort *string `json:"reasoning_effort,omitempty"` ReasoningEffort *string `json:"reasoning_effort,omitempty"`
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
InboundEndpoint *string `json:"inbound_endpoint,omitempty"`
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"` SubscriptionID *int64 `json:"subscription_id"`
......
...@@ -261,6 +261,8 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -261,6 +261,8 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointChatCompletions),
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
......
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestNormalizedOpenAIUpstreamEndpoint checks that the upstream endpoint label
// keeps any sub-path that follows "/responses" in the request path and falls
// back to the supplied default when the request path has no such segment.
func TestNormalizedOpenAIUpstreamEndpoint(t *testing.T) {
	gin.SetMode(gin.TestMode)

	type endpointCase struct {
		name     string
		path     string
		fallback string
		want     string
	}

	cases := []endpointCase{
		{"responses root maps to responses upstream", "/v1/responses", openAIUpstreamEndpointResponses, "/v1/responses"},
		{"responses compact keeps compact suffix", "/openai/v1/responses/compact", openAIUpstreamEndpointResponses, "/v1/responses/compact"},
		{"responses nested suffix preserved", "/openai/v1/responses/compact/detail", openAIUpstreamEndpointResponses, "/v1/responses/compact/detail"},
		{"non responses path uses fallback", "/v1/messages", openAIUpstreamEndpointResponses, "/v1/responses"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Build a bare gin context that only carries the request path under test.
			ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
			ginCtx.Request = httptest.NewRequest(http.MethodPost, tc.path, nil)

			require.Equal(t, tc.want, normalizedOpenAIUpstreamEndpoint(ginCtx, tc.fallback))
		})
	}
}
...@@ -37,6 +37,13 @@ type OpenAIGatewayHandler struct { ...@@ -37,6 +37,13 @@ type OpenAIGatewayHandler struct {
cfg *config.Config cfg *config.Config
} }
// Canonical endpoint labels attached to usage records by the OpenAI gateway handlers.
const (
// Inbound (client-facing) endpoint labels. normalizedOpenAIInboundEndpoint
// matches request paths against these values and returns them verbatim.
openAIInboundEndpointResponses = "/v1/responses"
openAIInboundEndpointMessages = "/v1/messages"
openAIInboundEndpointChatCompletions = "/v1/chat/completions"
// Upstream endpoint label; normalizedOpenAIUpstreamEndpoint uses it as the
// default base when the supplied fallback is blank.
openAIUpstreamEndpointResponses = "/v1/responses"
)
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler( func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService, gatewayService *service.OpenAIGatewayService,
...@@ -362,6 +369,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -362,6 +369,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointResponses),
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
...@@ -738,6 +747,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -738,6 +747,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointMessages),
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
...@@ -1235,6 +1246,8 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { ...@@ -1235,6 +1246,8 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointResponses),
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
...@@ -1530,6 +1543,62 @@ func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) ...@@ -1530,6 +1543,62 @@ func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64)
return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID) return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
} }
// normalizedOpenAIInboundEndpoint returns the canonical client-facing endpoint
// label for the current request.
//
// The candidate path is chosen in this order: the matched gin route pattern
// (c.FullPath), then the raw request URL path, then the trimmed fallback. If
// the candidate contains one of the known inbound endpoint constants, that
// constant is returned; otherwise the candidate itself is returned unchanged.
func normalizedOpenAIInboundEndpoint(c *gin.Context, fallback string) string {
	candidate := strings.TrimSpace(fallback)

	if c != nil {
		routePattern := strings.TrimSpace(c.FullPath())
		switch {
		case routePattern != "":
			candidate = routePattern
		case c.Request != nil && c.Request.URL != nil:
			if rawPath := strings.TrimSpace(c.Request.URL.Path); rawPath != "" {
				candidate = rawPath
			}
		}
	}

	// Order matters: the first label found inside the candidate wins.
	canonicalLabels := [...]string{
		openAIInboundEndpointChatCompletions,
		openAIInboundEndpointMessages,
		openAIInboundEndpointResponses,
	}
	for _, label := range canonicalLabels {
		if strings.Contains(candidate, label) {
			return label
		}
	}
	return candidate
}
// normalizedOpenAIUpstreamEndpoint returns the upstream endpoint label for the
// current request.
//
// The base label is the trimmed fallback (or openAIUpstreamEndpointResponses
// when the fallback is blank) with trailing slashes removed. If the request
// path contains "/responses" followed by a further "/..." sub-path, that
// sub-path (taken after the last "/responses" occurrence) is appended to the
// base; in every other case the base is returned as is.
func normalizedOpenAIUpstreamEndpoint(c *gin.Context, fallback string) string {
	const marker = "/responses"

	upstream := strings.TrimSpace(fallback)
	if upstream == "" {
		upstream = openAIUpstreamEndpointResponses
	}
	upstream = strings.TrimRight(upstream, "/")

	hasURL := c != nil && c.Request != nil && c.Request.URL != nil
	if !hasURL {
		return upstream
	}

	requestPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")

	// An empty request path also yields -1 here, so it falls back to the base.
	markerAt := strings.LastIndex(requestPath, marker)
	if markerAt < 0 {
		return upstream
	}

	// Only a real sub-path ("/x...") counts; "", "/" and things like "-foo" do not.
	tail := strings.TrimSpace(requestPath[markerAt+len(marker):])
	if strings.HasPrefix(tail, "/") && tail != "/" {
		return upstream + tail
	}
	return upstream
}
func isOpenAIWSUpgradeRequest(r *http.Request) bool { func isOpenAIWSUpgradeRequest(r *http.Request) bool {
if r == nil { if r == nil {
return false return false
......
...@@ -334,6 +334,14 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi ...@@ -334,6 +334,14 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, nil return nil, nil
} }
// GetEndpointStatsWithFilters is a stub implementation: it ignores every filter
// and always reports an empty, non-nil list of endpoint statistics.
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
	noStats := make([]usagestats.EndpointStat, 0)
	return noStats, nil
}
// GetUpstreamEndpointStatsWithFilters is a stub implementation: it ignores every
// filter and always reports an empty, non-nil list of endpoint statistics.
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
	noStats := make([]usagestats.EndpointStat, 0)
	return noStats, nil
}
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, nil return nil, nil
} }
......
...@@ -81,6 +81,15 @@ type ModelStat struct { ...@@ -81,6 +81,15 @@ type ModelStat struct {
ActualCost float64 `json:"actual_cost"` // 实际扣除 ActualCost float64 `json:"actual_cost"` // 实际扣除
} }
// EndpointStat represents usage statistics for a single request endpoint.
// Each value is one aggregated row: an endpoint label together with the
// request count, token total and cost totals attributed to it.
type EndpointStat struct {
// Endpoint is the endpoint label the row is grouped by.
Endpoint string `json:"endpoint"`
// Requests is the number of usage records for the endpoint.
Requests int64 `json:"requests"`
// TotalTokens is the summed token count for the endpoint.
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // standard billed cost
ActualCost float64 `json:"actual_cost"` // amount actually deducted
}
// GroupStat represents usage statistics for a single group // GroupStat represents usage statistics for a single group
type GroupStat struct { type GroupStat struct {
GroupID int64 `json:"group_id"` GroupID int64 `json:"group_id"`
...@@ -188,6 +197,9 @@ type UsageStats struct { ...@@ -188,6 +197,9 @@ type UsageStats struct {
TotalActualCost float64 `json:"total_actual_cost"` TotalActualCost float64 `json:"total_actual_cost"`
TotalAccountCost *float64 `json:"total_account_cost,omitempty"` TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
AverageDurationMs float64 `json:"average_duration_ms"` AverageDurationMs float64 `json:"average_duration_ms"`
Endpoints []EndpointStat `json:"endpoints,omitempty"`
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints,omitempty"`
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
} }
// BatchUserUsageStats represents usage stats for a single user // BatchUserUsageStats represents usage stats for a single user
...@@ -257,4 +269,6 @@ type AccountUsageStatsResponse struct { ...@@ -257,4 +269,6 @@ type AccountUsageStatsResponse struct {
History []AccountUsageHistory `json:"history"` History []AccountUsageHistory `json:"history"`
Summary AccountUsageSummary `json:"summary"` Summary AccountUsageSummary `json:"summary"`
Models []ModelStat `json:"models"` Models []ModelStat `json:"models"`
Endpoints []EndpointStat `json:"endpoints"`
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints"`
} }
...@@ -28,7 +28,7 @@ import ( ...@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
var usageLogInsertArgTypes = [...]string{ var usageLogInsertArgTypes = [...]string{
"bigint", "bigint",
...@@ -65,6 +65,8 @@ var usageLogInsertArgTypes = [...]string{ ...@@ -65,6 +65,8 @@ var usageLogInsertArgTypes = [...]string{
"text", "text",
"text", "text",
"text", "text",
"text",
"text",
"boolean", "boolean",
"timestamptz", "timestamptz",
} }
...@@ -304,6 +306,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -304,6 +306,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
media_type, media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
) VALUES ( ) VALUES (
...@@ -312,7 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -312,7 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$8, $9, $10, $11, $8, $9, $10, $11,
$12, $13, $12, $13,
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36 $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -732,11 +736,13 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -732,11 +736,13 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
media_type, media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(keys)*37) args := make([]any, 0, len(keys)*38)
argPos := 1 argPos := 1
for idx, key := range keys { for idx, key := range keys {
if idx > 0 { if idx > 0 {
...@@ -799,6 +805,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -799,6 +805,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
media_type, media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
) )
...@@ -837,6 +845,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -837,6 +845,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
media_type, media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
FROM input FROM input
...@@ -915,11 +925,13 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -915,11 +925,13 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
media_type, media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(preparedList)*36) args := make([]any, 0, len(preparedList)*38)
argPos := 1 argPos := 1
for idx, prepared := range preparedList { for idx, prepared := range preparedList {
if idx > 0 { if idx > 0 {
...@@ -979,6 +991,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -979,6 +991,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
media_type, media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
) )
...@@ -1017,6 +1031,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1017,6 +1031,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
media_type, media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
FROM input FROM input
...@@ -1063,6 +1079,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1063,6 +1079,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
media_type, media_type,
service_tier, service_tier,
reasoning_effort, reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
) VALUES ( ) VALUES (
...@@ -1071,7 +1089,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1071,7 +1089,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$8, $9, $10, $11, $8, $9, $10, $11,
$12, $13, $12, $13,
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36 $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...) `, prepared.args...)
...@@ -1101,6 +1119,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1101,6 +1119,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
mediaType := nullString(log.MediaType) mediaType := nullString(log.MediaType)
serviceTier := nullString(log.ServiceTier) serviceTier := nullString(log.ServiceTier)
reasoningEffort := nullString(log.ReasoningEffort) reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint)
var requestIDArg any var requestIDArg any
if requestID != "" { if requestID != "" {
...@@ -1147,6 +1167,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1147,6 +1167,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
mediaType, mediaType,
serviceTier, serviceTier,
reasoningEffort, reasoningEffort,
inboundEndpoint,
upstreamEndpoint,
log.CacheTTLOverridden, log.CacheTTLOverridden,
createdAt, createdAt,
}, },
...@@ -2505,7 +2527,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat ...@@ -2505,7 +2527,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
args = append(args, *filters.StartTime) args = append(args, *filters.StartTime)
} }
if filters.EndTime != nil { if filters.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1))
args = append(args, *filters.EndTime) args = append(args, *filters.EndTime)
} }
...@@ -3040,7 +3062,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us ...@@ -3040,7 +3062,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
args = append(args, *filters.StartTime) args = append(args, *filters.StartTime)
} }
if filters.EndTime != nil { if filters.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1))
args = append(args, *filters.EndTime) args = append(args, *filters.EndTime)
} }
...@@ -3080,6 +3102,35 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us ...@@ -3080,6 +3102,35 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
stats.TotalAccountCost = &totalAccountCost stats.TotalAccountCost = &totalAccountCost
} }
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
start := time.Unix(0, 0).UTC()
if filters.StartTime != nil {
start = *filters.StartTime
}
end := time.Now().UTC()
if filters.EndTime != nil {
end = *filters.EndTime
}
endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
if endpointErr != nil {
logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetStatsWithFilters: %v", endpointErr)
endpoints = []EndpointStat{}
}
upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
if upstreamEndpointErr != nil {
logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetStatsWithFilters: %v", upstreamEndpointErr)
upstreamEndpoints = []EndpointStat{}
}
endpointPaths, endpointPathErr := r.getEndpointPathStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
if endpointPathErr != nil {
logger.LegacyPrintf("repository.usage_log", "getEndpointPathStatsWithFilters failed in GetStatsWithFilters: %v", endpointPathErr)
endpointPaths = []EndpointStat{}
}
stats.Endpoints = endpoints
stats.UpstreamEndpoints = upstreamEndpoints
stats.EndpointPaths = endpointPaths
return stats, nil return stats, nil
} }
...@@ -3092,6 +3143,163 @@ type AccountUsageSummary = usagestats.AccountUsageSummary ...@@ -3092,6 +3143,163 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
// AccountUsageStatsResponse represents the full usage statistics response for an account // AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
// EndpointStat is an alias for usagestats.EndpointStat: one row of per-endpoint
// usage statistics.
type EndpointStat = usagestats.EndpointStat
// getEndpointStatsByColumnWithFilters aggregates usage_logs rows by the given
// endpoint column within the half-open window [startTime, endTime).
//
// NULL or blank endpoint values are reported under the label "unknown". The
// optional filters (userID, apiKeyID, accountID, groupID > 0; non-empty model;
// request type / stream; billingType) narrow the rows considered. Rows are
// ordered by request count, highest first.
//
// endpointColumn is interpolated directly into the SQL text, so callers must
// pass a trusted column name and never user input.
//
// When only an account filter is set (no user or API key), actual_cost is
// computed as total_cost scaled by account_rate_multiplier (treated as 1 when
// NULL) instead of the stored actual_cost.
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
	var actualCostExpr string
	if accountOnly := accountID > 0 && userID == 0 && apiKeyID == 0; accountOnly {
		actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
	} else {
		actualCostExpr = "COALESCE(SUM(actual_cost), 0) as actual_cost"
	}

	query := fmt.Sprintf(`
SELECT
COALESCE(NULLIF(TRIM(%s), ''), 'unknown') AS endpoint,
COUNT(*) AS requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`, endpointColumn, actualCostExpr)
	args := []any{startTime, endTime}

	// andEquals appends " AND <column> = $N" using the next placeholder index.
	andEquals := func(column string, value any) {
		query += fmt.Sprintf(" AND %s = $%d", column, len(args)+1)
		args = append(args, value)
	}

	idFilters := []struct {
		column string
		id     int64
	}{
		{"user_id", userID},
		{"api_key_id", apiKeyID},
		{"account_id", accountID},
		{"group_id", groupID},
	}
	for _, f := range idFilters {
		if f.id > 0 {
			andEquals(f.column, f.id)
		}
	}
	if model != "" {
		andEquals("model", model)
	}
	query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
	if billingType != nil {
		andEquals("billing_type", int16(*billingType))
	}
	query += " GROUP BY endpoint ORDER BY requests DESC"

	rows, err := r.sql.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	// A Close failure is only surfaced when nothing else has already failed.
	defer func() {
		closeErr := rows.Close()
		if closeErr == nil || err != nil {
			return
		}
		err = closeErr
		results = nil
	}()

	results = []EndpointStat{}
	for rows.Next() {
		var stat EndpointStat
		scanErr := rows.Scan(&stat.Endpoint, &stat.Requests, &stat.TotalTokens, &stat.Cost, &stat.ActualCost)
		if scanErr != nil {
			return nil, scanErr
		}
		results = append(results, stat)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return results, nil
}
// getEndpointPathStatsWithFilters aggregates usage_logs rows by the pair
// "<inbound endpoint> -> <upstream endpoint>" within the half-open window
// [startTime, endTime).
//
// NULL or blank values on either side of the pair are shown as "unknown". The
// optional filters (userID, apiKeyID, accountID, groupID > 0; non-empty model;
// request type / stream; billingType) narrow the rows considered. Rows are
// ordered by request count, highest first.
//
// When only an account filter is set (no user or API key), actual_cost is
// computed as total_cost scaled by account_rate_multiplier (treated as 1 when
// NULL) instead of the stored actual_cost.
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
	var actualCostExpr string
	if accountOnly := accountID > 0 && userID == 0 && apiKeyID == 0; accountOnly {
		actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
	} else {
		actualCostExpr = "COALESCE(SUM(actual_cost), 0) as actual_cost"
	}

	query := fmt.Sprintf(`
SELECT
CONCAT(
COALESCE(NULLIF(TRIM(inbound_endpoint), ''), 'unknown'),
' -> ',
COALESCE(NULLIF(TRIM(upstream_endpoint), ''), 'unknown')
) AS endpoint,
COUNT(*) AS requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`, actualCostExpr)
	args := []any{startTime, endTime}

	// andEquals appends " AND <column> = $N" using the next placeholder index.
	andEquals := func(column string, value any) {
		query += fmt.Sprintf(" AND %s = $%d", column, len(args)+1)
		args = append(args, value)
	}

	idFilters := []struct {
		column string
		id     int64
	}{
		{"user_id", userID},
		{"api_key_id", apiKeyID},
		{"account_id", accountID},
		{"group_id", groupID},
	}
	for _, f := range idFilters {
		if f.id > 0 {
			andEquals(f.column, f.id)
		}
	}
	if model != "" {
		andEquals("model", model)
	}
	query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
	if billingType != nil {
		andEquals("billing_type", int16(*billingType))
	}
	query += " GROUP BY endpoint ORDER BY requests DESC"

	rows, err := r.sql.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	// A Close failure is only surfaced when nothing else has already failed.
	defer func() {
		closeErr := rows.Close()
		if closeErr == nil || err != nil {
			return
		}
		err = closeErr
		results = nil
	}()

	results = []EndpointStat{}
	for rows.Next() {
		var stat EndpointStat
		scanErr := rows.Scan(&stat.Endpoint, &stat.Requests, &stat.TotalTokens, &stat.Cost, &stat.ActualCost)
		if scanErr != nil {
			return nil, scanErr
		}
		results = append(results, stat)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return results, nil
}
// GetEndpointStatsWithFilters reports usage statistics grouped by the inbound
// (client-facing) endpoint, applying whichever optional filters are set.
func (r *usageLogRepository) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) {
	const groupColumn = "inbound_endpoint"
	return r.getEndpointStatsByColumnWithFilters(
		ctx, groupColumn,
		startTime, endTime,
		userID, apiKeyID, accountID, groupID,
		model, requestType, stream, billingType,
	)
}
// GetUpstreamEndpointStatsWithFilters reports usage statistics grouped by the
// upstream endpoint, applying whichever optional filters are set.
func (r *usageLogRepository) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) {
	const groupColumn = "upstream_endpoint"
	return r.getEndpointStatsByColumnWithFilters(
		ctx, groupColumn,
		startTime, endTime,
		userID, apiKeyID, accountID, groupID,
		model, requestType, stream, billingType,
	)
}
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range // GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) { func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1 daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
...@@ -3254,11 +3462,23 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ...@@ -3254,11 +3462,23 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
if err != nil { if err != nil {
models = []ModelStat{} models = []ModelStat{}
} }
endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil)
if endpointErr != nil {
logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetAccountUsageStats: %v", endpointErr)
endpoints = []EndpointStat{}
}
upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil)
if upstreamEndpointErr != nil {
logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetAccountUsageStats: %v", upstreamEndpointErr)
upstreamEndpoints = []EndpointStat{}
}
resp = &AccountUsageStatsResponse{ resp = &AccountUsageStatsResponse{
History: history, History: history,
Summary: summary, Summary: summary,
Models: models, Models: models,
Endpoints: endpoints,
UpstreamEndpoints: upstreamEndpoints,
} }
return resp, nil return resp, nil
} }
...@@ -3541,6 +3761,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -3541,6 +3761,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
mediaType sql.NullString mediaType sql.NullString
serviceTier sql.NullString serviceTier sql.NullString
reasoningEffort sql.NullString reasoningEffort sql.NullString
inboundEndpoint sql.NullString
upstreamEndpoint sql.NullString
cacheTTLOverridden bool cacheTTLOverridden bool
createdAt time.Time createdAt time.Time
) )
...@@ -3581,6 +3803,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -3581,6 +3803,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&mediaType, &mediaType,
&serviceTier, &serviceTier,
&reasoningEffort, &reasoningEffort,
&inboundEndpoint,
&upstreamEndpoint,
&cacheTTLOverridden, &cacheTTLOverridden,
&createdAt, &createdAt,
); err != nil { ); err != nil {
...@@ -3656,6 +3880,12 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -3656,6 +3880,12 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if reasoningEffort.Valid { if reasoningEffort.Valid {
log.ReasoningEffort = &reasoningEffort.String log.ReasoningEffort = &reasoningEffort.String
} }
if inboundEndpoint.Valid {
log.InboundEndpoint = &inboundEndpoint.String
}
if upstreamEndpoint.Valid {
log.UpstreamEndpoint = &upstreamEndpoint.String
}
return log, nil return log, nil
} }
......
...@@ -73,6 +73,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { ...@@ -73,6 +73,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // media_type sqlmock.AnyArg(), // media_type
sqlmock.AnyArg(), // service_tier sqlmock.AnyArg(), // service_tier
sqlmock.AnyArg(), // reasoning_effort sqlmock.AnyArg(), // reasoning_effort
sqlmock.AnyArg(), // inbound_endpoint
sqlmock.AnyArg(), // upstream_endpoint
log.CacheTTLOverridden, log.CacheTTLOverridden,
createdAt, createdAt,
). ).
...@@ -141,6 +143,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { ...@@ -141,6 +143,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(), sqlmock.AnyArg(),
serviceTier, serviceTier,
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.CacheTTLOverridden, log.CacheTTLOverridden,
createdAt, createdAt,
). ).
...@@ -376,6 +380,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -376,6 +380,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{Valid: true, String: "priority"}, sql.NullString{Valid: true, String: "priority"},
sql.NullString{}, sql.NullString{},
sql.NullString{},
sql.NullString{},
false, false,
now, now,
}}) }})
...@@ -415,6 +421,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -415,6 +421,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{Valid: true, String: "flex"}, sql.NullString{Valid: true, String: "flex"},
sql.NullString{}, sql.NullString{},
sql.NullString{},
sql.NullString{},
false, false,
now, now,
}}) }})
...@@ -454,6 +462,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -454,6 +462,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{Valid: true, String: "priority"}, sql.NullString{Valid: true, String: "priority"},
sql.NullString{}, sql.NullString{},
sql.NullString{},
sql.NullString{},
false, false,
now, now,
}}) }})
......
...@@ -1624,6 +1624,14 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi ...@@ -1624,6 +1624,14 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
// GetEndpointStatsWithFilters is an unimplemented stub method: it ignores all
// of its arguments and always returns a nil slice together with a
// "not implemented" error.
func (r *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
	var noStats []usagestats.EndpointStat // stays nil; the stub produces no data
	return noStats, errors.New("not implemented")
}
// GetUpstreamEndpointStatsWithFilters is an unimplemented stub method: it
// ignores all of its arguments and always returns a nil slice together with a
// "not implemented" error.
func (r *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
	var noStats []usagestats.EndpointStat // stays nil; the stub produces no data
	return noStats, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
......
...@@ -45,6 +45,8 @@ type UsageLogRepository interface { ...@@ -45,6 +45,8 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
......
...@@ -226,6 +226,41 @@ func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) ...@@ -226,6 +226,41 @@ func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T)
require.Equal(t, 1, userRepo.deductCalls) require.Equal(t, 1, userRepo.deductCalls)
} }
// TestOpenAIGatewayServiceRecordUsage_IncludesEndpointMetadata verifies that
// the inbound and upstream endpoint strings supplied on the RecordUsage input
// end up on the usage log handed to the repository stub, with the surrounding
// whitespace removed.
func TestOpenAIGatewayServiceRecordUsage_IncludesEndpointMetadata(t *testing.T) {
	// Expected values are the padded inputs below without their padding.
	const (
		wantInbound  = "/v1/chat/completions"
		wantUpstream = "/v1/responses"
	)

	usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
	svc := newOpenAIRecordUsageServiceForTest(
		usageRepo,
		&openAIRecordUsageUserRepoStub{},
		&openAIRecordUsageSubRepoStub{},
		&openAIUserGroupRateRepoStub{},
	)

	input := &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID: "resp_endpoint_metadata",
			Usage: OpenAIUsage{
				InputTokens:  8,
				OutputTokens: 2,
			},
			Model:    "gpt-5.1",
			Duration: time.Second,
		},
		APIKey: &APIKey{
			ID:    1002,
			Group: &Group{RateMultiplier: 1},
		},
		User:    &User{ID: 2002},
		Account: &Account{ID: 3002},
		// Deliberately padded so the assertions below prove trimming happened.
		InboundEndpoint:  " /v1/chat/completions ",
		UpstreamEndpoint: " /v1/responses ",
	}

	require.NoError(t, svc.RecordUsage(context.Background(), input))

	require.NotNil(t, usageRepo.lastLog)
	recorded := usageRepo.lastLog

	require.NotNil(t, recorded.InboundEndpoint)
	require.Equal(t, wantInbound, *recorded.InboundEndpoint)
	require.NotNil(t, recorded.UpstreamEndpoint)
	require.Equal(t, wantUpstream, *recorded.UpstreamEndpoint)
}
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError(t *testing.T) { func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError(t *testing.T) {
groupID := int64(12) groupID := int64(12)
groupRate := 1.6 groupRate := 1.6
......
...@@ -4028,6 +4028,8 @@ type OpenAIRecordUsageInput struct { ...@@ -4028,6 +4028,8 @@ type OpenAIRecordUsageInput struct {
User *User User *User
Account *Account Account *Account
Subscription *UserSubscription Subscription *UserSubscription
InboundEndpoint string
UpstreamEndpoint string
UserAgent string // 请求的 User-Agent UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址 IPAddress string // 请求的客户端 IP 地址
RequestPayloadHash string RequestPayloadHash string
...@@ -4106,6 +4108,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -4106,6 +4108,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
Model: billingModel, Model: billingModel,
ServiceTier: result.ServiceTier, ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
InputTokens: actualInputTokens, InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
...@@ -4125,7 +4129,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -4125,7 +4129,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
FirstTokenMs: result.FirstTokenMs, FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
// 添加 UserAgent // 添加 UserAgent
if input.UserAgent != "" { if input.UserAgent != "" {
usageLog.UserAgent = &input.UserAgent usageLog.UserAgent = &input.UserAgent
...@@ -4668,3 +4671,11 @@ func normalizeOpenAIReasoningEffort(raw string) string { ...@@ -4668,3 +4671,11 @@ func normalizeOpenAIReasoningEffort(raw string) string {
return "" return ""
} }
} }
// optionalTrimmedStringPtr strips leading and trailing whitespace from raw and
// returns a pointer to the trimmed copy. It returns nil when raw is empty or
// contains only whitespace, so callers can store "not provided" as a nil
// pointer instead of an empty string.
func optionalTrimmedStringPtr(raw string) *string {
	if value := strings.TrimSpace(raw); value != "" {
		return &value
	}
	return nil
}
...@@ -103,6 +103,10 @@ type UsageLog struct { ...@@ -103,6 +103,10 @@ type UsageLog struct {
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API), // ReasoningEffort is the request's reasoning effort level (OpenAI Responses API),
// e.g. "low" / "medium" / "high" / "xhigh". Nil means not provided / not applicable. // e.g. "low" / "medium" / "high" / "xhigh". Nil means not provided / not applicable.
ReasoningEffort *string ReasoningEffort *string
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
InboundEndpoint *string
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
UpstreamEndpoint *string
GroupID *int64 GroupID *int64
SubscriptionID *int64 SubscriptionID *int64
......
-- Add endpoint tracking fields to usage_logs.
-- inbound_endpoint: client-facing API route (e.g. /v1/chat/completions, /v1/messages, /v1/responses)
-- upstream_endpoint: normalized upstream route (e.g. /v1/responses)
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS inbound_endpoint VARCHAR(128);
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS upstream_endpoint VARCHAR(128);
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import { apiClient } from '../client' import { apiClient } from '../client'
import type { AdminUsageLog, UsageQueryParams, PaginatedResponse, UsageRequestType } from '@/types' import type { AdminUsageLog, UsageQueryParams, PaginatedResponse, UsageRequestType } from '@/types'
import type { EndpointStat } from '@/types'
// ==================== Types ==================== // ==================== Types ====================
...@@ -18,6 +19,9 @@ export interface AdminUsageStatsResponse { ...@@ -18,6 +19,9 @@ export interface AdminUsageStatsResponse {
total_actual_cost: number total_actual_cost: number
total_account_cost?: number total_account_cost?: number
average_duration_ms: number average_duration_ms: number
endpoints?: EndpointStat[]
upstream_endpoints?: EndpointStat[]
endpoint_paths?: EndpointStat[]
} }
export interface SimpleUser { export interface SimpleUser {
......
...@@ -446,6 +446,18 @@ ...@@ -446,6 +446,18 @@
<!-- Model Distribution --> <!-- Model Distribution -->
<ModelDistributionChart :model-stats="stats.models" :loading="false" /> <ModelDistributionChart :model-stats="stats.models" :loading="false" />
<EndpointDistributionChart
:endpoint-stats="stats.endpoints || []"
:loading="false"
:title="t('usage.inboundEndpoint')"
/>
<EndpointDistributionChart
:endpoint-stats="stats.upstream_endpoints || []"
:loading="false"
:title="t('usage.upstreamEndpoint')"
/>
</template> </template>
<!-- No Data State --> <!-- No Data State -->
...@@ -489,6 +501,7 @@ import { Line } from 'vue-chartjs' ...@@ -489,6 +501,7 @@ import { Line } from 'vue-chartjs'
import BaseDialog from '@/components/common/BaseDialog.vue' import BaseDialog from '@/components/common/BaseDialog.vue'
import LoadingSpinner from '@/components/common/LoadingSpinner.vue' import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue' import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'
import EndpointDistributionChart from '@/components/charts/EndpointDistributionChart.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import { adminAPI } from '@/api/admin' import { adminAPI } from '@/api/admin'
import type { Account, AccountUsageStatsResponse } from '@/types' import type { Account, AccountUsageStatsResponse } from '@/types'
......
...@@ -410,6 +410,18 @@ ...@@ -410,6 +410,18 @@
<!-- Model Distribution --> <!-- Model Distribution -->
<ModelDistributionChart :model-stats="stats.models" :loading="false" /> <ModelDistributionChart :model-stats="stats.models" :loading="false" />
<EndpointDistributionChart
:endpoint-stats="stats.endpoints || []"
:loading="false"
:title="t('usage.inboundEndpoint')"
/>
<EndpointDistributionChart
:endpoint-stats="stats.upstream_endpoints || []"
:loading="false"
:title="t('usage.upstreamEndpoint')"
/>
</template> </template>
<!-- No Data State --> <!-- No Data State -->
...@@ -453,6 +465,7 @@ import { Line } from 'vue-chartjs' ...@@ -453,6 +465,7 @@ import { Line } from 'vue-chartjs'
import BaseDialog from '@/components/common/BaseDialog.vue' import BaseDialog from '@/components/common/BaseDialog.vue'
import LoadingSpinner from '@/components/common/LoadingSpinner.vue' import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue' import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'
import EndpointDistributionChart from '@/components/charts/EndpointDistributionChart.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import { adminAPI } from '@/api/admin' import { adminAPI } from '@/api/admin'
import type { Account, AccountUsageStatsResponse } from '@/types' import type { Account, AccountUsageStatsResponse } from '@/types'
......
...@@ -35,6 +35,19 @@ ...@@ -35,6 +35,19 @@
</span> </span>
</template> </template>
<template #cell-endpoint="{ row }">
<div class="max-w-[320px] space-y-1 text-xs">
<div class="break-all text-gray-700 dark:text-gray-300">
<span class="font-medium text-gray-500 dark:text-gray-400">{{ t('usage.inbound') }}:</span>
<span class="ml-1">{{ row.inbound_endpoint?.trim() || '-' }}</span>
</div>
<div class="break-all text-gray-700 dark:text-gray-300">
<span class="font-medium text-gray-500 dark:text-gray-400">{{ t('usage.upstream') }}:</span>
<span class="ml-1">{{ row.upstream_endpoint?.trim() || '-' }}</span>
</div>
</div>
</template>
<template #cell-group="{ row }"> <template #cell-group="{ row }">
<span v-if="row.group" class="inline-flex items-center rounded px-2 py-0.5 text-xs font-medium bg-indigo-100 text-indigo-800 dark:bg-indigo-900 dark:text-indigo-200"> <span v-if="row.group" class="inline-flex items-center rounded px-2 py-0.5 text-xs font-medium bg-indigo-100 text-indigo-800 dark:bg-indigo-900 dark:text-indigo-200">
{{ row.group.name }} {{ row.group.name }}
...@@ -328,6 +341,7 @@ const getRequestTypeBadgeClass = (row: AdminUsageLog): string => { ...@@ -328,6 +341,7 @@ const getRequestTypeBadgeClass = (row: AdminUsageLog): string => {
if (requestType === 'sync') return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-200' if (requestType === 'sync') return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-200'
return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200' return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200'
} }
const formatCacheTokens = (tokens: number): string => { const formatCacheTokens = (tokens: number): string => {
if (tokens >= 1000000) return `${(tokens / 1000000).toFixed(1)}M` if (tokens >= 1000000) return `${(tokens / 1000000).toFixed(1)}M`
if (tokens >= 1000) return `${(tokens / 1000).toFixed(1)}K` if (tokens >= 1000) return `${(tokens / 1000).toFixed(1)}K`
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment