Unverified Commit bf455811 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1455 from touwaeriol/feat/channel-management

feat(channel): add channel management with multi-mode pricing and billing integration
parents b384570d e88b2890
...@@ -184,6 +184,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -184,6 +184,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
setOpsRequestContext(c, modelName, stream, body) setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
reqModel := modelName // 保存映射前的原始模型名
if channelMapping.Mapped {
modelName = channelMapping.MappedModel
}
// Get subscription (may be nil) // Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c) subscription, _ := middleware.GetSubscriptionFromContext(c)
...@@ -353,7 +360,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -353,7 +360,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil { if err != nil {
if len(fs.FailedAccountIDs) == 0 { if len(fs.FailedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
...@@ -523,6 +530,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -523,6 +530,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
LongContextMultiplier: 2.0, // 超出部分双倍计费 LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fs.ForceCacheBilling, ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.gemini_v1beta.models"), zap.String("component", "handler.gemini_v1beta.models"),
......
...@@ -30,6 +30,7 @@ type AdminHandlers struct { ...@@ -30,6 +30,7 @@ type AdminHandlers struct {
TLSFingerprintProfile *admin.TLSFingerprintProfileHandler TLSFingerprintProfile *admin.TLSFingerprintProfileHandler
APIKey *admin.AdminAPIKeyHandler APIKey *admin.AdminAPIKeyHandler
ScheduledTest *admin.ScheduledTestHandler ScheduledTest *admin.ScheduledTestHandler
Channel *admin.ChannelHandler
} }
// Handlers contains all HTTP handlers // Handlers contains all HTTP handlers
......
...@@ -79,6 +79,9 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -79,6 +79,9 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if h.errorPassthroughService != nil { if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService) service.BindErrorPassthroughService(c, h.errorPassthroughService)
} }
...@@ -183,7 +186,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -183,7 +186,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
forwardStart := time.Now() forwardStart := time.Now()
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model")) defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds() forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
...@@ -267,6 +274,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -267,6 +274,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.openai_gateway.chat_completions"), zap.String("component", "handler.openai_gateway.chat_completions"),
......
...@@ -185,6 +185,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -185,6 +185,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if !h.validateFunctionCallOutputRequest(c, body, reqLog) { if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
return return
...@@ -284,7 +287,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -284,7 +287,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Forward request // Forward request
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now() forwardStart := time.Now()
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) // 应用渠道模型映射到请求体
forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
forwardDurationMs := time.Since(forwardStart).Milliseconds() forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
...@@ -379,6 +387,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -379,6 +387,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.openai_gateway.responses"), zap.String("component", "handler.openai_gateway.responses"),
...@@ -549,6 +558,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -549,6 +558,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil { if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService) service.BindErrorPassthroughService(c, h.errorPassthroughService)
...@@ -673,7 +685,12 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -673,7 +685,12 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
// Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的 // Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
// Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。 // Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model")) defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) // 应用渠道模型映射到请求体
forwardBody := body
if channelMappingMsg.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel)
}
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds() forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
...@@ -759,6 +776,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -759,6 +776,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMappingMsg.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.openai_gateway.messages"), zap.String("component", "handler.openai_gateway.messages"),
...@@ -1101,6 +1119,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { ...@@ -1101,6 +1119,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext(c, reqModel, true, firstMessage) setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
// 解析渠道级模型映射
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
var currentUserRelease func() var currentUserRelease func()
var currentAccountRelease func() var currentAccountRelease func()
releaseTurnSlots := func() { releaseTurnSlots := func() {
...@@ -1259,6 +1280,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { ...@@ -1259,6 +1280,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
reqLog.Error("openai.websocket_record_usage_failed", reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
...@@ -1270,7 +1292,13 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { ...@@ -1270,7 +1292,13 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
}, },
} }
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil { // 应用渠道模型映射到 WebSocket 首条消息
wsFirstMessage := firstMessage
if channelMappingWS.Mapped {
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
}
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
closeStatus, closeReason := summarizeWSCloseErrorForLog(err) closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_proxy_failed", reqLog.Warn("openai.websocket_proxy_failed",
......
...@@ -2225,6 +2225,7 @@ func newMinimalGatewayService(accountRepo service.AccountRepository) *service.Ga ...@@ -2225,6 +2225,7 @@ func newMinimalGatewayService(accountRepo service.AccountRepository) *service.Ga
return service.NewGatewayService( return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil,
) )
} }
......
...@@ -30,6 +30,8 @@ import ( ...@@ -30,6 +30,8 @@ import (
) )
// SoraGatewayHandler handles Sora chat completions requests // SoraGatewayHandler handles Sora chat completions requests
//
// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。
type SoraGatewayHandler struct { type SoraGatewayHandler struct {
gatewayService *service.GatewayService gatewayService *service.GatewayService
soraGatewayService *service.SoraGatewayService soraGatewayService *service.SoraGatewayService
...@@ -226,7 +228,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -226,7 +228,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
var lastFailoverHeaders http.Header var lastFailoverHeaders http.Header
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "", int64(0))
if err != nil { if err != nil {
reqLog.Warn("sora.account_select_failed", reqLog.Warn("sora.account_select_failed",
zap.Error(err), zap.Error(err),
......
...@@ -465,6 +465,8 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { ...@@ -465,6 +465,8 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
nil, // digestStore nil, // digestStore
nil, // settingService nil, // settingService
nil, // tlsFPProfileService nil, // tlsFPProfileService
nil, // channelService
nil, // resolver
) )
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}} soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
......
...@@ -33,6 +33,7 @@ func ProvideAdminHandlers( ...@@ -33,6 +33,7 @@ func ProvideAdminHandlers(
tlsFingerprintProfileHandler *admin.TLSFingerprintProfileHandler, tlsFingerprintProfileHandler *admin.TLSFingerprintProfileHandler,
apiKeyHandler *admin.AdminAPIKeyHandler, apiKeyHandler *admin.AdminAPIKeyHandler,
scheduledTestHandler *admin.ScheduledTestHandler, scheduledTestHandler *admin.ScheduledTestHandler,
channelHandler *admin.ChannelHandler,
) *AdminHandlers { ) *AdminHandlers {
return &AdminHandlers{ return &AdminHandlers{
Dashboard: dashboardHandler, Dashboard: dashboardHandler,
...@@ -59,6 +60,7 @@ func ProvideAdminHandlers( ...@@ -59,6 +60,7 @@ func ProvideAdminHandlers(
TLSFingerprintProfile: tlsFingerprintProfileHandler, TLSFingerprintProfile: tlsFingerprintProfileHandler,
APIKey: apiKeyHandler, APIKey: apiKeyHandler,
ScheduledTest: scheduledTestHandler, ScheduledTest: scheduledTestHandler,
Channel: channelHandler,
} }
} }
...@@ -150,6 +152,7 @@ var ProviderSet = wire.NewSet( ...@@ -150,6 +152,7 @@ var ProviderSet = wire.NewSet(
admin.NewTLSFingerprintProfileHandler, admin.NewTLSFingerprintProfileHandler,
admin.NewAdminAPIKeyHandler, admin.NewAdminAPIKeyHandler,
admin.NewScheduledTestHandler, admin.NewScheduledTestHandler,
admin.NewChannelHandler,
// AdminHandlers and Handlers constructors // AdminHandlers and Handlers constructors
ProvideAdminHandlers, ProvideAdminHandlers,
......
...@@ -125,6 +125,7 @@ type ClaudeUsage struct { ...@@ -125,6 +125,7 @@ type ClaudeUsage struct {
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
} }
// ClaudeError Claude 错误响应 // ClaudeError Claude 错误响应
......
...@@ -149,6 +149,12 @@ type GeminiCandidate struct { ...@@ -149,6 +149,12 @@ type GeminiCandidate struct {
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"` GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
} }
// GeminiTokenDetail Gemini token 详情(按模态分类)
type GeminiTokenDetail struct {
Modality string `json:"modality"`
TokenCount int `json:"tokenCount"`
}
// GeminiUsageMetadata Gemini 用量元数据 // GeminiUsageMetadata Gemini 用量元数据
type GeminiUsageMetadata struct { type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount,omitempty"` PromptTokenCount int `json:"promptTokenCount,omitempty"`
...@@ -156,6 +162,18 @@ type GeminiUsageMetadata struct { ...@@ -156,6 +162,18 @@ type GeminiUsageMetadata struct {
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"` TotalTokenCount int `json:"totalTokenCount,omitempty"`
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费) ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费)
CandidatesTokensDetails []GeminiTokenDetail `json:"candidatesTokensDetails,omitempty"`
PromptTokensDetails []GeminiTokenDetail `json:"promptTokensDetails,omitempty"`
}
// ImageOutputTokens 从 CandidatesTokensDetails 中提取 IMAGE 模态的 token 数
func (m *GeminiUsageMetadata) ImageOutputTokens() int {
for _, d := range m.CandidatesTokensDetails {
if d.Modality == "IMAGE" {
return d.TokenCount
}
}
return 0
} }
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search) // GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
......
...@@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon ...@@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached usage.CacheReadInputTokens = cached
usage.ImageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
} }
// 生成响应 ID // 生成响应 ID
......
...@@ -35,6 +35,7 @@ type StreamingProcessor struct { ...@@ -35,6 +35,7 @@ type StreamingProcessor struct {
inputTokens int inputTokens int
outputTokens int outputTokens int
cacheReadTokens int cacheReadTokens int
imageOutputTokens int
} }
// NewStreamingProcessor 创建流式响应处理器 // NewStreamingProcessor 创建流式响应处理器
...@@ -87,6 +88,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { ...@@ -87,6 +88,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
p.cacheReadTokens = cached p.cacheReadTokens = cached
p.imageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
} }
// 处理 parts // 处理 parts
...@@ -127,6 +129,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) { ...@@ -127,6 +129,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
InputTokens: p.inputTokens, InputTokens: p.inputTokens,
OutputTokens: p.outputTokens, OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens, CacheReadInputTokens: p.cacheReadTokens,
ImageOutputTokens: p.imageOutputTokens,
} }
if !p.messageStartSent { if !p.messageStartSent {
...@@ -158,6 +161,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte ...@@ -158,6 +161,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached usage.CacheReadInputTokens = cached
usage.ImageOutputTokens = v1Resp.Response.UsageMetadata.ImageOutputTokens()
} }
responseID := v1Resp.ResponseID responseID := v1Resp.ResponseID
...@@ -485,6 +489,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { ...@@ -485,6 +489,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
InputTokens: p.inputTokens, InputTokens: p.inputTokens,
OutputTokens: p.outputTokens, OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens, CacheReadInputTokens: p.cacheReadTokens,
ImageOutputTokens: p.imageOutputTokens,
} }
deltaEvent := map[string]any{ deltaEvent := map[string]any{
......
...@@ -175,6 +175,13 @@ type UserBreakdownDimension struct { ...@@ -175,6 +175,13 @@ type UserBreakdownDimension struct {
ModelType string // "requested", "upstream", or "mapping" ModelType string // "requested", "upstream", or "mapping"
Endpoint string // filter by endpoint value (non-empty to enable) Endpoint string // filter by endpoint value (non-empty to enable)
EndpointType string // "inbound", "upstream", or "path" EndpointType string // "inbound", "upstream", or "path"
// Additional filter conditions
UserID int64 // filter by user_id (>0 to enable)
APIKeyID int64 // filter by api_key_id (>0 to enable)
AccountID int64 // filter by account_id (>0 to enable)
RequestType *int16 // filter by request_type (non-nil to enable)
Stream *bool // filter by stream flag (non-nil to enable)
BillingType *int8 // filter by billing_type (non-nil to enable)
} }
// APIKeyUsageTrendPoint represents API key usage trend data point // APIKeyUsageTrendPoint represents API key usage trend data point
...@@ -230,6 +237,7 @@ type UsageLogFilters struct { ...@@ -230,6 +237,7 @@ type UsageLogFilters struct {
RequestType *int16 RequestType *int16
Stream *bool Stream *bool
BillingType *int8 BillingType *int8
BillingMode string
StartTime *time.Time StartTime *time.Time
EndTime *time.Time EndTime *time.Time
// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging. // ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.
......
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type channelRepository struct {
db *sql.DB
}
// NewChannelRepository 创建渠道数据访问实例
func NewChannelRepository(db *sql.DB) service.ChannelRepository {
return &channelRepository{db: db}
}
// runInTx 在事务中执行 fn,成功 commit,失败 rollback。
func (r *channelRepository) runInTx(ctx context.Context, fn func(tx *sql.Tx) error) error {
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
if err := fn(tx); err != nil {
return err
}
return tx.Commit()
}
func (r *channelRepository) Create(ctx context.Context, channel *service.Channel) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
modelMappingJSON, err := marshalModelMapping(channel.ModelMapping)
if err != nil {
return err
}
err = tx.QueryRowContext(ctx,
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, created_at, updated_at`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil {
if isUniqueViolation(err) {
return service.ErrChannelExists
}
return fmt.Errorf("insert channel: %w", err)
}
// 设置分组关联
if len(channel.GroupIDs) > 0 {
if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
return err
}
}
// 设置模型定价
if len(channel.ModelPricing) > 0 {
if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
return err
}
}
return nil
})
}
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
ch := &service.Channel{}
var modelMappingJSON []byte
err := r.db.QueryRowContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at
FROM channels WHERE id = $1`, id,
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound
}
if err != nil {
return nil, fmt.Errorf("get channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
groupIDs, err := r.GetGroupIDs(ctx, id)
if err != nil {
return nil, err
}
ch.GroupIDs = groupIDs
pricing, err := r.ListModelPricing(ctx, id)
if err != nil {
return nil, err
}
ch.ModelPricing = pricing
return ch, nil
}
func (r *channelRepository) Update(ctx context.Context, channel *service.Channel) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
modelMappingJSON, err := marshalModelMapping(channel.ModelMapping)
if err != nil {
return err
}
result, err := tx.ExecContext(ctx,
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW()
WHERE id = $7`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID,
)
if err != nil {
if isUniqueViolation(err) {
return service.ErrChannelExists
}
return fmt.Errorf("update channel: %w", err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return service.ErrChannelNotFound
}
// 更新分组关联
if channel.GroupIDs != nil {
if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
return err
}
}
// 更新模型定价
if channel.ModelPricing != nil {
if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
return err
}
}
return nil
})
}
func (r *channelRepository) Delete(ctx context.Context, id int64) error {
result, err := r.db.ExecContext(ctx, `DELETE FROM channels WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("delete channel: %w", err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return service.ErrChannelNotFound
}
return nil
}
func (r *channelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.Channel, *pagination.PaginationResult, error) {
where := []string{"1=1"}
args := []any{}
argIdx := 1
if status != "" {
where = append(where, fmt.Sprintf("c.status = $%d", argIdx))
args = append(args, status)
argIdx++
}
if search != "" {
where = append(where, fmt.Sprintf("(c.name ILIKE $%d OR c.description ILIKE $%d)", argIdx, argIdx))
args = append(args, "%"+escapeLike(search)+"%")
argIdx++
}
whereClause := strings.Join(where, " AND ")
// 计数
var total int64
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM channels c WHERE %s", whereClause)
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
return nil, nil, fmt.Errorf("count channels: %w", err)
}
pageSize := params.Limit() // 约束在 [1, 100]
page := params.Page
if page < 1 {
page = 1
}
offset := (page - 1) * pageSize
// 查询 channel 列表
dataQuery := fmt.Sprintf(
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`,
whereClause, argIdx, argIdx+1,
)
args = append(args, pageSize, offset)
rows, err := r.db.QueryContext(ctx, dataQuery, args...)
if err != nil {
return nil, nil, fmt.Errorf("query channels: %w", err)
}
defer func() { _ = rows.Close() }()
var channels []service.Channel
var channelIDs []int64
for rows.Next() {
var ch service.Channel
var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
if err := rows.Err(); err != nil {
return nil, nil, fmt.Errorf("iterate channels: %w", err)
}
// 批量加载分组 ID 和模型定价(避免 N+1)
if len(channelIDs) > 0 {
groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
if err != nil {
return nil, nil, err
}
pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
if err != nil {
return nil, nil, err
}
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
}
}
pages := 0
if total > 0 {
pages = int((total + int64(pageSize) - 1) / int64(pageSize))
}
paginationResult := &pagination.PaginationResult{
Total: total,
Page: page,
PageSize: pageSize,
Pages: pages,
}
return channels, paginationResult, nil
}
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`,
)
if err != nil {
return nil, fmt.Errorf("query all channels: %w", err)
}
defer func() { _ = rows.Close() }()
var channels []service.Channel
var channelIDs []int64
for rows.Next() {
var ch service.Channel
var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate channels: %w", err)
}
if len(channelIDs) == 0 {
return channels, nil
}
// 批量加载分组 ID
groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
if err != nil {
return nil, err
}
// 批量加载模型定价
pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
if err != nil {
return nil, err
}
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
}
return channels, nil
}
// --- 批量加载辅助方法 ---
// batchLoadGroupIDs 批量加载多个渠道的分组 ID
func (r *channelRepository) batchLoadGroupIDs(ctx context.Context, channelIDs []int64) (map[int64][]int64, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT channel_id, group_id FROM channel_groups
WHERE channel_id = ANY($1) ORDER BY channel_id, group_id`,
pq.Array(channelIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load group ids: %w", err)
}
defer func() { _ = rows.Close() }()
groupMap := make(map[int64][]int64, len(channelIDs))
for rows.Next() {
var channelID, groupID int64
if err := rows.Scan(&channelID, &groupID); err != nil {
return nil, fmt.Errorf("scan group id: %w", err)
}
groupMap[channelID] = append(groupMap[channelID], groupID)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate group ids: %w", err)
}
return groupMap, nil
}
func (r *channelRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var exists bool
err := r.db.QueryRowContext(ctx,
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1)`, name,
).Scan(&exists)
return exists, err
}
func (r *channelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) {
var exists bool
err := r.db.QueryRowContext(ctx,
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1 AND id != $2)`, name, excludeID,
).Scan(&exists)
return exists, err
}
// --- 分组关联 ---
func (r *channelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT group_id FROM channel_groups WHERE channel_id = $1 ORDER BY group_id`, channelID,
)
if err != nil {
return nil, fmt.Errorf("get group ids: %w", err)
}
defer func() { _ = rows.Close() }()
var ids []int64
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, fmt.Errorf("scan group id: %w", err)
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate group ids: %w", err)
}
return ids, nil
}
func (r *channelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error {
return setGroupIDsTx(ctx, r.db, channelID, groupIDs)
}
func (r *channelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
var channelID int64
err := r.db.QueryRowContext(ctx,
`SELECT channel_id FROM channel_groups WHERE group_id = $1`, groupID,
).Scan(&channelID)
if err == sql.ErrNoRows {
return 0, nil
}
return channelID, err
}
func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) {
if len(groupIDs) == 0 {
return nil, nil
}
rows, err := r.db.QueryContext(ctx,
`SELECT group_id FROM channel_groups WHERE group_id = ANY($1) AND channel_id != $2`,
pq.Array(groupIDs), channelID,
)
if err != nil {
return nil, fmt.Errorf("get groups in other channels: %w", err)
}
defer func() { _ = rows.Close() }()
var conflicting []int64
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, fmt.Errorf("scan conflicting group id: %w", err)
}
conflicting = append(conflicting, id)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate conflicting group ids: %w", err)
}
return conflicting, nil
}
// marshalModelMapping 将 model mapping 序列化为嵌套 JSON 字节
// 格式:{"platform": {"src": "dst"}, ...}
func marshalModelMapping(m map[string]map[string]string) ([]byte, error) {
if len(m) == 0 {
return []byte("{}"), nil
}
data, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("marshal model_mapping: %w", err)
}
return data, nil
}
// unmarshalModelMapping 将 JSON 字节反序列化为嵌套 model mapping
func unmarshalModelMapping(data []byte) map[string]map[string]string {
if len(data) == 0 {
return nil
}
var m map[string]map[string]string
if err := json.Unmarshal(data, &m); err != nil {
return nil
}
return m
}
// GetGroupPlatforms 批量查询分组 ID 对应的平台
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
if len(groupIDs) == 0 {
return make(map[int64]string), nil
}
rows, err := r.db.QueryContext(ctx,
`SELECT id, platform FROM groups WHERE id = ANY($1)`,
pq.Array(groupIDs),
)
if err != nil {
return nil, fmt.Errorf("get group platforms: %w", err)
}
defer rows.Close() //nolint:errcheck
result := make(map[int64]string, len(groupIDs))
for rows.Next() {
var id int64
var platform string
if err := rows.Scan(&id, &platform); err != nil {
return nil, fmt.Errorf("scan group platform: %w", err)
}
result[id] = platform
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate group platforms: %w", err)
}
return result, nil
}
package repository
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
// --- 模型定价 ---
func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID,
)
if err != nil {
return nil, fmt.Errorf("list model pricing: %w", err)
}
defer func() { _ = rows.Close() }()
result, pricingIDs, err := scanModelPricingRows(rows)
if err != nil {
return nil, err
}
if len(pricingIDs) > 0 {
intervalMap, err := r.batchLoadIntervals(ctx, pricingIDs)
if err != nil {
return nil, err
}
for i := range result {
result[i].Intervals = intervalMap[result[i].ID]
}
}
return result, nil
}
func (r *channelRepository) CreateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
return createModelPricingExec(ctx, r.db, pricing)
}
func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
modelsJSON, err := json.Marshal(pricing.Models)
if err != nil {
return fmt.Errorf("marshal models: %w", err)
}
billingMode := pricing.BillingMode
if billingMode == "" {
billingMode = service.BillingModeToken
}
result, err := r.db.ExecContext(ctx,
`UPDATE channel_model_pricing
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, platform = $9, updated_at = NOW()
WHERE id = $10`,
modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.Platform, pricing.ID,
)
if err != nil {
return fmt.Errorf("update model pricing: %w", err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return fmt.Errorf("pricing entry not found: %d", pricing.ID)
}
return nil
}
func (r *channelRepository) DeleteModelPricing(ctx context.Context, id int64) error {
_, err := r.db.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("delete model pricing: %w", err)
}
return nil
}
func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []service.ChannelModelPricing) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
return replaceModelPricingTx(ctx, tx, channelID, pricingList)
})
}
// --- 批量加载辅助方法 ---
// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`,
pq.Array(channelIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load model pricing: %w", err)
}
defer func() { _ = rows.Close() }()
allPricing, allPricingIDs, err := scanModelPricingRows(rows)
if err != nil {
return nil, err
}
// 按 channelID 分组
pricingMap := make(map[int64][]service.ChannelModelPricing, len(channelIDs))
for _, p := range allPricing {
pricingMap[p.ChannelID] = append(pricingMap[p.ChannelID], p)
}
// 批量加载所有区间
if len(allPricingIDs) > 0 {
intervalMap, err := r.batchLoadIntervals(ctx, allPricingIDs)
if err != nil {
return nil, err
}
for chID := range pricingMap {
for i := range pricingMap[chID] {
pricingMap[chID][i].Intervals = intervalMap[pricingMap[chID][i].ID]
}
}
}
return pricingMap, nil
}
// batchLoadIntervals 批量加载多个定价条目的区间
func (r *channelRepository) batchLoadIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
input_price, output_price, cache_write_price, cache_read_price,
per_request_price, sort_order, created_at, updated_at
FROM channel_pricing_intervals
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
pq.Array(pricingIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load intervals: %w", err)
}
defer func() { _ = rows.Close() }()
intervalMap := make(map[int64][]service.PricingInterval, len(pricingIDs))
for rows.Next() {
var iv service.PricingInterval
if err := rows.Scan(
&iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
&iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
&iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan interval: %w", err)
}
intervalMap[iv.PricingID] = append(intervalMap[iv.PricingID], iv)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate intervals: %w", err)
}
return intervalMap, nil
}
// --- 共享 scan 辅助 ---
// scanModelPricingRows 扫描 model pricing 行,返回结果列表和 ID 列表
func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int64, error) {
var result []service.ChannelModelPricing
var pricingIDs []int64
for rows.Next() {
var p service.ChannelModelPricing
var modelsJSON []byte
if err := rows.Scan(
&p.ID, &p.ChannelID, &p.Platform, &modelsJSON, &p.BillingMode,
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
); err != nil {
return nil, nil, fmt.Errorf("scan model pricing: %w", err)
}
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
p.Models = []string{}
}
pricingIDs = append(pricingIDs, p.ID)
result = append(result, p)
}
if err := rows.Err(); err != nil {
return nil, nil, fmt.Errorf("iterate model pricing: %w", err)
}
return result, pricingIDs, nil
}
// --- 事务内辅助方法 ---
// dbExec 是 *sql.DB 和 *sql.Tx 共享的最小 SQL 执行接口
type dbExec interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
func setGroupIDsTx(ctx context.Context, exec dbExec, channelID int64, groupIDs []int64) error {
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_groups WHERE channel_id = $1`, channelID); err != nil {
return fmt.Errorf("delete old group associations: %w", err)
}
if len(groupIDs) == 0 {
return nil
}
_, err := exec.ExecContext(ctx,
`INSERT INTO channel_groups (channel_id, group_id)
SELECT $1, unnest($2::bigint[])`,
channelID, pq.Array(groupIDs),
)
if err != nil {
return fmt.Errorf("insert group associations: %w", err)
}
return nil
}
func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.ChannelModelPricing) error {
modelsJSON, err := json.Marshal(pricing.Models)
if err != nil {
return fmt.Errorf("marshal models: %w", err)
}
billingMode := pricing.BillingMode
if billingMode == "" {
billingMode = service.BillingModeToken
}
platform := pricing.Platform
if platform == "" {
platform = "anthropic"
}
err = exec.QueryRowContext(ctx,
`INSERT INTO channel_model_pricing (channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
pricing.ChannelID, platform, modelsJSON, billingMode,
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice, pricing.PerRequestPrice,
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
if err != nil {
return fmt.Errorf("insert model pricing: %w", err)
}
for i := range pricing.Intervals {
pricing.Intervals[i].PricingID = pricing.ID
if err := createIntervalExec(ctx, exec, &pricing.Intervals[i]); err != nil {
return err
}
}
return nil
}
func createIntervalExec(ctx context.Context, exec dbExec, iv *service.PricingInterval) error {
return exec.QueryRowContext(ctx,
`INSERT INTO channel_pricing_intervals
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
iv.PerRequestPrice, iv.SortOrder,
).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
}
func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pricingList []service.ChannelModelPricing) error {
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE channel_id = $1`, channelID); err != nil {
return fmt.Errorf("delete old model pricing: %w", err)
}
for i := range pricingList {
pricingList[i].ChannelID = channelID
if err := createModelPricingExec(ctx, exec, &pricingList[i]); err != nil {
return fmt.Errorf("insert model pricing: %w", err)
}
}
return nil
}
// isUniqueViolation 检查 pq 唯一约束违反错误
func isUniqueViolation(err error) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr != nil {
return pqErr.Code == "23505"
}
return false
}
// escapeLike 转义 LIKE/ILIKE 模式中的特殊字符
func escapeLike(s string) string {
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `%`, `\%`)
s = strings.ReplaceAll(s, `_`, `\_`)
return s
}
//go:build unit
package repository
import (
"encoding/json"
"errors"
"fmt"
"testing"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
)
// --- marshalModelMapping ---
func TestMarshalModelMapping(t *testing.T) {
tests := []struct {
name string
input map[string]map[string]string
wantJSON string // expected JSON output (exact match)
}{
{
name: "empty map",
input: map[string]map[string]string{},
wantJSON: "{}",
},
{
name: "nil map",
input: nil,
wantJSON: "{}",
},
{
name: "populated map",
input: map[string]map[string]string{
"openai": {"gpt-4": "gpt-4-turbo"},
},
},
{
name: "nested values",
input: map[string]map[string]string{
"openai": {"*": "gpt-5.4"},
"anthropic": {"claude-old": "claude-new"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := marshalModelMapping(tt.input)
require.NoError(t, err)
if tt.wantJSON != "" {
require.Equal(t, []byte(tt.wantJSON), result)
} else {
// round-trip: unmarshal and compare with input
var parsed map[string]map[string]string
require.NoError(t, json.Unmarshal(result, &parsed))
require.Equal(t, tt.input, parsed)
}
})
}
}
// --- unmarshalModelMapping ---
func TestUnmarshalModelMapping(t *testing.T) {
tests := []struct {
name string
input []byte
wantNil bool
want map[string]map[string]string
}{
{
name: "nil data",
input: nil,
wantNil: true,
},
{
name: "empty data",
input: []byte{},
wantNil: true,
},
{
name: "invalid JSON",
input: []byte("not-json"),
wantNil: true,
},
{
name: "type error - number",
input: []byte("42"),
wantNil: true,
},
{
name: "type error - array",
input: []byte("[1,2,3]"),
wantNil: true,
},
{
name: "valid JSON",
input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`),
want: map[string]map[string]string{
"openai": {"gpt-4": "gpt-4-turbo"},
"anthropic": {"old": "new"},
},
},
{
name: "empty object",
input: []byte("{}"),
want: map[string]map[string]string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := unmarshalModelMapping(tt.input)
if tt.wantNil {
require.Nil(t, result)
} else {
require.NotNil(t, result)
require.Equal(t, tt.want, result)
}
})
}
}
// --- escapeLike ---
func TestEscapeLike(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "no special chars",
input: "hello",
want: "hello",
},
{
name: "backslash",
input: `a\b`,
want: `a\\b`,
},
{
name: "percent",
input: "50%",
want: `50\%`,
},
{
name: "underscore",
input: "a_b",
want: `a\_b`,
},
{
name: "all special chars",
input: `a\b%c_d`,
want: `a\\b\%c\_d`,
},
{
name: "empty string",
input: "",
want: "",
},
{
name: "consecutive special chars",
input: "%_%",
want: `\%\_\%`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, escapeLike(tt.input))
})
}
}
// --- isUniqueViolation ---
func TestIsUniqueViolation(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{
name: "unique violation code 23505",
err: &pq.Error{Code: "23505"},
want: true,
},
{
name: "different pq error code",
err: &pq.Error{Code: "23503"},
want: false,
},
{
name: "non-pq error",
err: errors.New("some generic error"),
want: false,
},
{
name: "typed nil pq.Error",
err: func() error {
var pqErr *pq.Error
return pqErr
}(),
want: false,
},
{
name: "bare nil",
err: nil,
want: false,
},
{
name: "wrapped pq error with 23505",
err: fmt.Errorf("wrapped: %w", &pq.Error{Code: "23505"}),
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, isUniqueViolation(tt.err))
})
}
}
...@@ -28,7 +28,7 @@ import ( ...@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
// usageLogInsertArgTypes must stay in the same order as: // usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args // 1. prepareUsageLogInsert().args
...@@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{ ...@@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{
"integer", // cache_read_tokens "integer", // cache_read_tokens
"integer", // cache_creation_5m_tokens "integer", // cache_creation_5m_tokens
"integer", // cache_creation_1h_tokens "integer", // cache_creation_1h_tokens
"integer", // image_output_tokens
"numeric", // image_output_cost
"numeric", // input_cost "numeric", // input_cost
"numeric", // output_cost "numeric", // output_cost
"numeric", // cache_creation_cost "numeric", // cache_creation_cost
...@@ -77,6 +79,10 @@ var usageLogInsertArgTypes = [...]string{ ...@@ -77,6 +79,10 @@ var usageLogInsertArgTypes = [...]string{
"text", // inbound_endpoint "text", // inbound_endpoint
"text", // upstream_endpoint "text", // upstream_endpoint
"boolean", // cache_ttl_overridden "boolean", // cache_ttl_overridden
"bigint", // channel_id
"text", // model_mapping_chain
"text", // billing_tier
"text", // billing_mode
"timestamptz", // created_at "timestamptz", // created_at
} }
...@@ -326,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -326,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -350,14 +358,18 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -350,14 +358,18 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $7, $1, $2, $3, $4, $5, $6, $7,
$8, $9, $8, $9,
$10, $11, $12, $13, $10, $11, $12, $13,
$14, $15, $14, $15, $16, $17,
$16, $17, $18, $19, $20, $21, $18, $19, $20, $21, $22, $23,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -758,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -758,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -782,10 +796,14 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -782,10 +796,14 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(keys)*39) args := make([]any, 0, len(keys)*47)
argPos := 1 argPos := 1
for idx, key := range keys { for idx, key := range keys {
if idx > 0 { if idx > 0 {
...@@ -829,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -829,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -853,6 +873,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -853,6 +873,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
) )
SELECT SELECT
...@@ -871,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -871,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -895,6 +921,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -895,6 +921,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
FROM input FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
...@@ -953,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -953,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -977,10 +1009,14 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -977,10 +1009,14 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(preparedList)*40) args := make([]any, 0, len(preparedList)*46)
argPos := 1 argPos := 1
for idx, prepared := range preparedList { for idx, prepared := range preparedList {
if idx > 0 { if idx > 0 {
...@@ -1021,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1021,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -1045,6 +1083,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1045,6 +1083,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
) )
SELECT SELECT
...@@ -1063,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1063,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -1087,6 +1131,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1087,6 +1131,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
FROM input FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
...@@ -1113,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1113,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
...@@ -1137,14 +1187,18 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1137,14 +1187,18 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $7, $1, $2, $3, $4, $5, $6, $7,
$8, $9, $8, $9,
$10, $11, $12, $13, $10, $11, $12, $13,
$14, $15, $14, $15, $16, $17,
$16, $17, $18, $19, $20, $21, $18, $19, $20, $21, $22, $23,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...) `, prepared.args...)
...@@ -1176,6 +1230,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1176,6 +1230,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
reasoningEffort := nullString(log.ReasoningEffort) reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint) inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint)
channelID := nullInt64(log.ChannelID)
modelMappingChain := nullString(log.ModelMappingChain)
billingTier := nullString(log.BillingTier)
billingMode := nullString(log.BillingMode)
requestedModel := strings.TrimSpace(log.RequestedModel) requestedModel := strings.TrimSpace(log.RequestedModel)
if requestedModel == "" { if requestedModel == "" {
requestedModel = strings.TrimSpace(log.Model) requestedModel = strings.TrimSpace(log.Model)
...@@ -1208,6 +1266,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1208,6 +1266,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.CacheReadTokens, log.CacheReadTokens,
log.CacheCreation5mTokens, log.CacheCreation5mTokens,
log.CacheCreation1hTokens, log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost, log.InputCost,
log.OutputCost, log.OutputCost,
log.CacheCreationCost, log.CacheCreationCost,
...@@ -1232,6 +1292,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1232,6 +1292,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
inboundEndpoint, inboundEndpoint,
upstreamEndpoint, upstreamEndpoint,
log.CacheTTLOverridden, log.CacheTTLOverridden,
channelID,
modelMappingChain,
billingTier,
billingMode,
createdAt, createdAt,
}, },
} }
...@@ -2564,8 +2628,8 @@ type UsageLogFilters = usagestats.UsageLogFilters ...@@ -2564,8 +2628,8 @@ type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin) // ListWithFilters lists usage logs with optional filters (for admin)
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
conditions := make([]string, 0, 8) conditions := make([]string, 0, 9)
args := make([]any, 0, 8) args := make([]any, 0, 9)
if filters.UserID > 0 { if filters.UserID > 0 {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
...@@ -2589,6 +2653,10 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat ...@@ -2589,6 +2653,10 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType)) args = append(args, int16(*filters.BillingType))
} }
if filters.BillingMode != "" {
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
args = append(args, filters.BillingMode)
}
if filters.StartTime != nil { if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime) args = append(args, *filters.StartTime)
...@@ -3096,6 +3164,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim ...@@ -3096,6 +3164,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1) query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
args = append(args, dim.Endpoint) args = append(args, dim.Endpoint)
} }
if dim.UserID > 0 {
query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1)
args = append(args, dim.UserID)
}
if dim.APIKeyID > 0 {
query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1)
args = append(args, dim.APIKeyID)
}
if dim.AccountID > 0 {
query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1)
args = append(args, dim.AccountID)
}
if dim.RequestType != nil {
query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1)
args = append(args, *dim.RequestType)
}
if dim.Stream != nil {
query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1)
args = append(args, *dim.Stream)
}
if dim.BillingType != nil {
query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1)
args = append(args, *dim.BillingType)
}
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC" query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
if limit > 0 { if limit > 0 {
...@@ -3256,6 +3348,10 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us ...@@ -3256,6 +3348,10 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType)) args = append(args, int16(*filters.BillingType))
} }
if filters.BillingMode != "" {
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
args = append(args, filters.BillingMode)
}
if filters.StartTime != nil { if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime) args = append(args, *filters.StartTime)
...@@ -3935,6 +4031,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -3935,6 +4031,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
cacheReadTokens int cacheReadTokens int
cacheCreation5m int cacheCreation5m int
cacheCreation1h int cacheCreation1h int
imageOutputTokens int
imageOutputCost float64
inputCost float64 inputCost float64
outputCost float64 outputCost float64
cacheCreationCost float64 cacheCreationCost float64
...@@ -3959,6 +4057,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -3959,6 +4057,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
inboundEndpoint sql.NullString inboundEndpoint sql.NullString
upstreamEndpoint sql.NullString upstreamEndpoint sql.NullString
cacheTTLOverridden bool cacheTTLOverridden bool
channelID sql.NullInt64
modelMappingChain sql.NullString
billingTier sql.NullString
billingMode sql.NullString
createdAt time.Time createdAt time.Time
) )
...@@ -3979,6 +4081,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -3979,6 +4081,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&cacheReadTokens, &cacheReadTokens,
&cacheCreation5m, &cacheCreation5m,
&cacheCreation1h, &cacheCreation1h,
&imageOutputTokens,
&imageOutputCost,
&inputCost, &inputCost,
&outputCost, &outputCost,
&cacheCreationCost, &cacheCreationCost,
...@@ -4003,6 +4107,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -4003,6 +4107,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&inboundEndpoint, &inboundEndpoint,
&upstreamEndpoint, &upstreamEndpoint,
&cacheTTLOverridden, &cacheTTLOverridden,
&channelID,
&modelMappingChain,
&billingTier,
&billingMode,
&createdAt, &createdAt,
); err != nil { ); err != nil {
return nil, err return nil, err
...@@ -4021,6 +4129,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -4021,6 +4129,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
CacheReadTokens: cacheReadTokens, CacheReadTokens: cacheReadTokens,
CacheCreation5mTokens: cacheCreation5m, CacheCreation5mTokens: cacheCreation5m,
CacheCreation1hTokens: cacheCreation1h, CacheCreation1hTokens: cacheCreation1h,
ImageOutputTokens: imageOutputTokens,
ImageOutputCost: imageOutputCost,
InputCost: inputCost, InputCost: inputCost,
OutputCost: outputCost, OutputCost: outputCost,
CacheCreationCost: cacheCreationCost, CacheCreationCost: cacheCreationCost,
...@@ -4087,6 +4197,19 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -4087,6 +4197,19 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if upstreamModel.Valid { if upstreamModel.Valid {
log.UpstreamModel = &upstreamModel.String log.UpstreamModel = &upstreamModel.String
} }
if channelID.Valid {
value := channelID.Int64
log.ChannelID = &value
}
if modelMappingChain.Valid {
log.ModelMappingChain = &modelMappingChain.String
}
if billingTier.Valid {
log.BillingTier = &billingTier.String
}
if billingMode.Valid {
log.BillingMode = &billingMode.String
}
return log, nil return log, nil
} }
......
...@@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { ...@@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.CacheReadTokens, log.CacheReadTokens,
log.CacheCreation5mTokens, log.CacheCreation5mTokens,
log.CacheCreation1hTokens, log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost, log.InputCost,
log.OutputCost, log.OutputCost,
log.CacheCreationCost, log.CacheCreationCost,
...@@ -80,6 +82,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { ...@@ -80,6 +82,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // inbound_endpoint sqlmock.AnyArg(), // inbound_endpoint
sqlmock.AnyArg(), // upstream_endpoint sqlmock.AnyArg(), // upstream_endpoint
log.CacheTTLOverridden, log.CacheTTLOverridden,
sqlmock.AnyArg(), // channel_id
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
createdAt, createdAt,
). ).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt)) WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
...@@ -129,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { ...@@ -129,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.CacheReadTokens, log.CacheReadTokens,
log.CacheCreation5mTokens, log.CacheCreation5mTokens,
log.CacheCreation1hTokens, log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost, log.InputCost,
log.OutputCost, log.OutputCost,
log.CacheCreationCost, log.CacheCreationCost,
...@@ -153,6 +161,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { ...@@ -153,6 +161,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(), sqlmock.AnyArg(),
log.CacheTTLOverridden, log.CacheTTLOverridden,
sqlmock.AnyArg(), // channel_id
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
createdAt, createdAt,
). ).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt)) WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
...@@ -439,6 +451,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -439,6 +451,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
4, // cache_read_tokens 4, // cache_read_tokens
5, // cache_creation_5m_tokens 5, // cache_creation_5m_tokens
6, // cache_creation_1h_tokens 6, // cache_creation_1h_tokens
0, // image_output_tokens
0.0, // image_output_cost
0.1, // input_cost 0.1, // input_cost
0.2, // output_cost 0.2, // output_cost
0.3, // cache_creation_cost 0.3, // cache_creation_cost
...@@ -463,6 +477,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -463,6 +477,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
false, false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now, now,
}}) }})
require.NoError(t, err) require.NoError(t, err)
...@@ -487,6 +505,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -487,6 +505,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9, 0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0, 1.0,
sql.NullFloat64{}, sql.NullFloat64{},
...@@ -506,6 +525,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -506,6 +525,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
false, false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now, now,
}}) }})
require.NoError(t, err) require.NoError(t, err)
...@@ -530,6 +553,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -530,6 +553,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9, 0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0, 1.0,
sql.NullFloat64{}, sql.NullFloat64{},
...@@ -549,6 +573,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -549,6 +573,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
false, false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now, now,
}}) }})
require.NoError(t, err) require.NoError(t, err)
......
...@@ -74,6 +74,7 @@ var ProviderSet = wire.NewSet( ...@@ -74,6 +74,7 @@ var ProviderSet = wire.NewSet(
NewUserGroupRateRepository, NewUserGroupRateRepository,
NewErrorPassthroughRepository, NewErrorPassthroughRepository,
NewTLSFingerprintProfileRepository, NewTLSFingerprintProfileRepository,
NewChannelRepository,
// Cache implementations // Cache implementations
NewGatewayCache, NewGatewayCache,
......
...@@ -87,6 +87,9 @@ func RegisterAdminRoutes( ...@@ -87,6 +87,9 @@ func RegisterAdminRoutes(
// 定时测试计划 // 定时测试计划
registerScheduledTestRoutes(admin, h) registerScheduledTestRoutes(admin, h)
// 渠道管理
registerChannelRoutes(admin, h)
} }
} }
...@@ -567,3 +570,15 @@ func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Hand ...@@ -567,3 +570,15 @@ func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Hand
profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete) profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete)
} }
} }
func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
channels := admin.Group("/channels")
{
channels.GET("", h.Admin.Channel.List)
channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing)
channels.GET("/:id", h.Admin.Channel.GetByID)
channels.POST("", h.Admin.Channel.Create)
channels.PUT("/:id", h.Admin.Channel.Update)
channels.DELETE("/:id", h.Admin.Channel.Delete)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment