Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
7b83d6e7
Commit
7b83d6e7
authored
Apr 05, 2026
by
陈曦
Browse files
Merge remote-tracking branch 'upstream/main'
parents
daa2e6df
dbb248df
Changes
106
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/gateway_handler.go
View file @
7b83d6e7
...
...
@@ -158,6 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqStream
:=
parsedReq
.
Stream
reqLog
=
reqLog
.
With
(
zap
.
String
(
"model"
,
reqModel
),
zap
.
Bool
(
"stream"
,
reqStream
))
// 解析渠道级模型映射
channelMapping
,
_
:=
h
.
gatewayService
.
ResolveChannelMappingAndRestrict
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
reqModel
)
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
if
isMaxTokensOneHaikuRequest
(
reqModel
,
parsedReq
.
MaxTokens
,
reqStream
)
{
...
...
@@ -292,7 +295,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
fs
.
FailedAccountIDs
,
""
)
// Gemini 不使用会话限制
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
fs
.
FailedAccountIDs
,
""
,
int64
(
0
)
)
// Gemini 不使用会话限制
if
err
!=
nil
{
if
len
(
fs
.
FailedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
@@ -478,6 +481,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
RequestPayloadHash
:
requestPayloadHash
,
ForceCacheBilling
:
fs
.
ForceCacheBilling
,
APIKeyService
:
h
.
apiKeyService
,
ChannelUsageFields
:
channelMapping
.
ToUsageFields
(
reqModel
,
result
.
UpstreamModel
),
});
err
!=
nil
{
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.gateway.messages"
),
...
...
@@ -514,7 +518,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for
{
// 选择支持该模型的账号
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
currentAPIKey
.
GroupID
,
sessionKey
,
reqModel
,
fs
.
FailedAccountIDs
,
parsedReq
.
MetadataUserID
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
currentAPIKey
.
GroupID
,
sessionKey
,
reqModel
,
fs
.
FailedAccountIDs
,
parsedReq
.
MetadataUserID
,
int64
(
0
)
)
if
err
!=
nil
{
if
len
(
fs
.
FailedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
@@ -660,6 +664,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
parsedReq
.
OnUpstreamAccepted
=
queueRelease
// ===== 用户消息串行队列 END =====
// 应用渠道模型映射到请求
if
channelMapping
.
Mapped
{
parsedReq
.
Model
=
channelMapping
.
MappedModel
parsedReq
.
Body
=
h
.
gatewayService
.
ReplaceModelInBody
(
parsedReq
.
Body
,
channelMapping
.
MappedModel
)
body
=
h
.
gatewayService
.
ReplaceModelInBody
(
body
,
channelMapping
.
MappedModel
)
}
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
requestCtx
:=
c
.
Request
.
Context
()
...
...
@@ -810,6 +821,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
RequestPayloadHash
:
requestPayloadHash
,
ForceCacheBilling
:
fs
.
ForceCacheBilling
,
APIKeyService
:
h
.
apiKeyService
,
ChannelUsageFields
:
channelMapping
.
ToUsageFields
(
reqModel
,
result
.
UpstreamModel
),
});
err
!=
nil
{
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.gateway.messages"
),
...
...
backend/internal/handler/gateway_handler_chat_completions.go
View file @
7b83d6e7
...
...
@@ -80,6 +80,9 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext
(
c
,
reqModel
,
reqStream
,
body
)
setOpsEndpointContext
(
c
,
""
,
int16
(
service
.
RequestTypeFromLegacy
(
reqStream
,
false
)))
// 解析渠道级模型映射
channelMapping
,
_
:=
h
.
gatewayService
.
ResolveChannelMappingAndRestrict
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
reqModel
)
// Claude Code only restriction
if
apiKey
.
Group
!=
nil
&&
apiKey
.
Group
.
ClaudeCodeOnly
{
h
.
chatCompletionsErrorResponse
(
c
,
http
.
StatusForbidden
,
"permission_error"
,
...
...
@@ -154,7 +157,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
fs
:=
NewFailoverState
(
h
.
maxAccountSwitches
,
false
)
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
fs
.
FailedAccountIDs
,
""
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
fs
.
FailedAccountIDs
,
""
,
int64
(
0
)
)
if
err
!=
nil
{
if
len
(
fs
.
FailedAccountIDs
)
==
0
{
h
.
chatCompletionsErrorResponse
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
())
...
...
@@ -203,7 +206,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
// 5. Forward request
writerSizeBeforeForward
:=
c
.
Writer
.
Size
()
result
,
err
:=
h
.
gatewayService
.
ForwardAsChatCompletions
(
c
.
Request
.
Context
(),
c
,
account
,
body
,
parsedReq
)
forwardBody
:=
body
if
channelMapping
.
Mapped
{
forwardBody
=
h
.
gatewayService
.
ReplaceModelInBody
(
body
,
channelMapping
.
MappedModel
)
}
result
,
err
:=
h
.
gatewayService
.
ForwardAsChatCompletions
(
c
.
Request
.
Context
(),
c
,
account
,
forwardBody
,
parsedReq
)
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
...
...
@@ -255,6 +262,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
IPAddress
:
clientIP
,
RequestPayloadHash
:
requestPayloadHash
,
APIKeyService
:
h
.
apiKeyService
,
ChannelUsageFields
:
channelMapping
.
ToUsageFields
(
reqModel
,
result
.
UpstreamModel
),
});
err
!=
nil
{
reqLog
.
Error
(
"gateway.cc.record_usage_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
...
...
backend/internal/handler/gateway_handler_responses.go
View file @
7b83d6e7
...
...
@@ -80,6 +80,9 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext
(
c
,
reqModel
,
reqStream
,
body
)
setOpsEndpointContext
(
c
,
""
,
int16
(
service
.
RequestTypeFromLegacy
(
reqStream
,
false
)))
// 解析渠道级模型映射
channelMapping
,
_
:=
h
.
gatewayService
.
ResolveChannelMappingAndRestrict
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
reqModel
)
// Claude Code only restriction:
// /v1/responses is never a Claude Code endpoint.
// When claude_code_only is enabled, this endpoint is rejected.
...
...
@@ -159,7 +162,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
fs
:=
NewFailoverState
(
h
.
maxAccountSwitches
,
false
)
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
fs
.
FailedAccountIDs
,
""
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
fs
.
FailedAccountIDs
,
""
,
int64
(
0
)
)
if
err
!=
nil
{
if
len
(
fs
.
FailedAccountIDs
)
==
0
{
h
.
responsesErrorResponse
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
())
...
...
@@ -208,7 +211,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
// 5. Forward request
writerSizeBeforeForward
:=
c
.
Writer
.
Size
()
result
,
err
:=
h
.
gatewayService
.
ForwardAsResponses
(
c
.
Request
.
Context
(),
c
,
account
,
body
,
parsedReq
)
forwardBody
:=
body
if
channelMapping
.
Mapped
{
forwardBody
=
h
.
gatewayService
.
ReplaceModelInBody
(
body
,
channelMapping
.
MappedModel
)
}
result
,
err
:=
h
.
gatewayService
.
ForwardAsResponses
(
c
.
Request
.
Context
(),
c
,
account
,
forwardBody
,
parsedReq
)
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
...
...
@@ -261,6 +268,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
IPAddress
:
clientIP
,
RequestPayloadHash
:
requestPayloadHash
,
APIKeyService
:
h
.
apiKeyService
,
ChannelUsageFields
:
channelMapping
.
ToUsageFields
(
reqModel
,
result
.
UpstreamModel
),
});
err
!=
nil
{
reqLog
.
Error
(
"gateway.responses.record_usage_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
...
...
backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
View file @
7b83d6e7
...
...
@@ -161,6 +161,8 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil
,
// digestStore
nil
,
// settingService
nil
,
// tlsFPProfileService
nil
,
// channelService
nil
,
// resolver
)
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
7b83d6e7
...
...
@@ -184,6 +184,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
setOpsRequestContext
(
c
,
modelName
,
stream
,
body
)
setOpsEndpointContext
(
c
,
""
,
int16
(
service
.
RequestTypeFromLegacy
(
stream
,
false
)))
// 解析渠道级模型映射
channelMapping
,
_
:=
h
.
gatewayService
.
ResolveChannelMappingAndRestrict
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
modelName
)
reqModel
:=
modelName
// 保存映射前的原始模型名
if
channelMapping
.
Mapped
{
modelName
=
channelMapping
.
MappedModel
}
// Get subscription (may be nil)
subscription
,
_
:=
middleware
.
GetSubscriptionFromContext
(
c
)
...
...
@@ -353,7 +360,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
modelName
,
fs
.
FailedAccountIDs
,
""
)
// Gemini 不使用会话限制
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
modelName
,
fs
.
FailedAccountIDs
,
""
,
int64
(
0
)
)
// Gemini 不使用会话限制
if
err
!=
nil
{
if
len
(
fs
.
FailedAccountIDs
)
==
0
{
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
...
...
@@ -523,6 +530,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
LongContextMultiplier
:
2.0
,
// 超出部分双倍计费
ForceCacheBilling
:
fs
.
ForceCacheBilling
,
APIKeyService
:
h
.
apiKeyService
,
ChannelUsageFields
:
channelMapping
.
ToUsageFields
(
reqModel
,
result
.
UpstreamModel
),
});
err
!=
nil
{
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.gemini_v1beta.models"
),
...
...
backend/internal/handler/handler.go
View file @
7b83d6e7
...
...
@@ -30,6 +30,7 @@ type AdminHandlers struct {
TLSFingerprintProfile
*
admin
.
TLSFingerprintProfileHandler
APIKey
*
admin
.
AdminAPIKeyHandler
ScheduledTest
*
admin
.
ScheduledTestHandler
Channel
*
admin
.
ChannelHandler
}
// Handlers contains all HTTP handlers
...
...
backend/internal/handler/openai_chat_completions.go
View file @
7b83d6e7
...
...
@@ -79,6 +79,9 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext
(
c
,
reqModel
,
reqStream
,
body
)
setOpsEndpointContext
(
c
,
""
,
int16
(
service
.
RequestTypeFromLegacy
(
reqStream
,
false
)))
// 解析渠道级模型映射
channelMapping
,
_
:=
h
.
gatewayService
.
ResolveChannelMappingAndRestrict
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
reqModel
)
if
h
.
errorPassthroughService
!=
nil
{
service
.
BindErrorPassthroughService
(
c
,
h
.
errorPassthroughService
)
}
...
...
@@ -183,7 +186,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
forwardStart
:=
time
.
Now
()
defaultMappedModel
:=
resolveOpenAIForwardDefaultMappedModel
(
apiKey
,
c
.
GetString
(
"openai_chat_completions_fallback_model"
))
result
,
err
:=
h
.
gatewayService
.
ForwardAsChatCompletions
(
c
.
Request
.
Context
(),
c
,
account
,
body
,
promptCacheKey
,
defaultMappedModel
)
forwardBody
:=
body
if
channelMapping
.
Mapped
{
forwardBody
=
h
.
gatewayService
.
ReplaceModelInBody
(
body
,
channelMapping
.
MappedModel
)
}
result
,
err
:=
h
.
gatewayService
.
ForwardAsChatCompletions
(
c
.
Request
.
Context
(),
c
,
account
,
forwardBody
,
promptCacheKey
,
defaultMappedModel
)
forwardDurationMs
:=
time
.
Since
(
forwardStart
)
.
Milliseconds
()
if
accountReleaseFunc
!=
nil
{
...
...
@@ -257,16 +264,17 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
OpenAIRecordUsageInput
{
Result
:
result
,
APIKey
:
apiKey
,
User
:
apiKey
.
User
,
Account
:
account
,
Subscription
:
subscription
,
InboundEndpoint
:
GetInboundEndpoint
(
c
),
UpstreamEndpoint
:
GetUpstreamEndpoint
(
c
,
account
.
Platform
),
UserAgent
:
userAgent
,
IPAddress
:
clientIP
,
APIKeyService
:
h
.
apiKeyService
,
Result
:
result
,
APIKey
:
apiKey
,
User
:
apiKey
.
User
,
Account
:
account
,
Subscription
:
subscription
,
InboundEndpoint
:
GetInboundEndpoint
(
c
),
UpstreamEndpoint
:
GetUpstreamEndpoint
(
c
,
account
.
Platform
),
UserAgent
:
userAgent
,
IPAddress
:
clientIP
,
APIKeyService
:
h
.
apiKeyService
,
ChannelUsageFields
:
channelMapping
.
ToUsageFields
(
reqModel
,
result
.
UpstreamModel
),
});
err
!=
nil
{
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.openai_gateway.chat_completions"
),
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
7b83d6e7
...
...
@@ -185,6 +185,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext
(
c
,
reqModel
,
reqStream
,
body
)
setOpsEndpointContext
(
c
,
""
,
int16
(
service
.
RequestTypeFromLegacy
(
reqStream
,
false
)))
// 解析渠道级模型映射
channelMapping
,
_
:=
h
.
gatewayService
.
ResolveChannelMappingAndRestrict
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
reqModel
)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if
!
h
.
validateFunctionCallOutputRequest
(
c
,
body
,
reqLog
)
{
return
...
...
@@ -284,7 +287,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Forward request
service
.
SetOpsLatencyMs
(
c
,
service
.
OpsRoutingLatencyMsKey
,
time
.
Since
(
routingStart
)
.
Milliseconds
())
forwardStart
:=
time
.
Now
()
result
,
err
:=
h
.
gatewayService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
)
// 应用渠道模型映射到请求体
forwardBody
:=
body
if
channelMapping
.
Mapped
{
forwardBody
=
h
.
gatewayService
.
ReplaceModelInBody
(
body
,
channelMapping
.
MappedModel
)
}
result
,
err
:=
h
.
gatewayService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
forwardBody
)
forwardDurationMs
:=
time
.
Since
(
forwardStart
)
.
Milliseconds
()
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
...
...
@@ -379,6 +387,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
IPAddress
:
clientIP
,
RequestPayloadHash
:
requestPayloadHash
,
APIKeyService
:
h
.
apiKeyService
,
ChannelUsageFields
:
channelMapping
.
ToUsageFields
(
reqModel
,
result
.
UpstreamModel
),
});
err
!=
nil
{
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.openai_gateway.responses"
),
...
...
@@ -549,6 +558,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext
(
c
,
reqModel
,
reqStream
,
body
)
setOpsEndpointContext
(
c
,
""
,
int16
(
service
.
RequestTypeFromLegacy
(
reqStream
,
false
)))
// 解析渠道级模型映射
channelMappingMsg
,
_
:=
h
.
gatewayService
.
ResolveChannelMappingAndRestrict
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
reqModel
)
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if
h
.
errorPassthroughService
!=
nil
{
service
.
BindErrorPassthroughService
(
c
,
h
.
errorPassthroughService
)
...
...
@@ -673,7 +685,12 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
// Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
// Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
defaultMappedModel
:=
resolveOpenAIForwardDefaultMappedModel
(
apiKey
,
c
.
GetString
(
"openai_messages_fallback_model"
))
result
,
err
:=
h
.
gatewayService
.
ForwardAsAnthropic
(
c
.
Request
.
Context
(),
c
,
account
,
body
,
promptCacheKey
,
defaultMappedModel
)
// 应用渠道模型映射到请求体
forwardBody
:=
body
if
channelMappingMsg
.
Mapped
{
forwardBody
=
h
.
gatewayService
.
ReplaceModelInBody
(
body
,
channelMappingMsg
.
MappedModel
)
}
result
,
err
:=
h
.
gatewayService
.
ForwardAsAnthropic
(
c
.
Request
.
Context
(),
c
,
account
,
forwardBody
,
promptCacheKey
,
defaultMappedModel
)
forwardDurationMs
:=
time
.
Since
(
forwardStart
)
.
Milliseconds
()
if
accountReleaseFunc
!=
nil
{
...
...
@@ -759,6 +776,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
IPAddress
:
clientIP
,
RequestPayloadHash
:
requestPayloadHash
,
APIKeyService
:
h
.
apiKeyService
,
ChannelUsageFields
:
channelMappingMsg
.
ToUsageFields
(
reqModel
,
result
.
UpstreamModel
),
});
err
!=
nil
{
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.openai_gateway.messages"
),
...
...
@@ -1101,6 +1119,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext
(
c
,
reqModel
,
true
,
firstMessage
)
setOpsEndpointContext
(
c
,
""
,
int16
(
service
.
RequestTypeWSV2
))
// 解析渠道级模型映射
channelMappingWS
,
_
:=
h
.
gatewayService
.
ResolveChannelMappingAndRestrict
(
ctx
,
apiKey
.
GroupID
,
reqModel
)
var
currentUserRelease
func
()
var
currentAccountRelease
func
()
releaseTurnSlots
:=
func
()
{
...
...
@@ -1259,6 +1280,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
IPAddress
:
clientIP
,
RequestPayloadHash
:
service
.
HashUsageRequestPayload
(
firstMessage
),
APIKeyService
:
h
.
apiKeyService
,
ChannelUsageFields
:
channelMappingWS
.
ToUsageFields
(
reqModel
,
result
.
UpstreamModel
),
});
err
!=
nil
{
reqLog
.
Error
(
"openai.websocket_record_usage_failed"
,
zap
.
Int64
(
"account_id"
,
account
.
ID
),
...
...
@@ -1270,7 +1292,13 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
},
}
if
err
:=
h
.
gatewayService
.
ProxyResponsesWebSocketFromClient
(
ctx
,
c
,
wsConn
,
account
,
token
,
firstMessage
,
hooks
);
err
!=
nil
{
// 应用渠道模型映射到 WebSocket 首条消息
wsFirstMessage
:=
firstMessage
if
channelMappingWS
.
Mapped
{
wsFirstMessage
=
h
.
gatewayService
.
ReplaceModelInBody
(
firstMessage
,
channelMappingWS
.
MappedModel
)
}
if
err
:=
h
.
gatewayService
.
ProxyResponsesWebSocketFromClient
(
ctx
,
c
,
wsConn
,
account
,
token
,
wsFirstMessage
,
hooks
);
err
!=
nil
{
h
.
gatewayService
.
ReportOpenAIAccountScheduleResult
(
account
.
ID
,
false
,
nil
)
closeStatus
,
closeReason
:=
summarizeWSCloseErrorForLog
(
err
)
reqLog
.
Warn
(
"openai.websocket_proxy_failed"
,
...
...
backend/internal/handler/sora_client_handler_test.go
View file @
7b83d6e7
...
...
@@ -2225,6 +2225,7 @@ func newMinimalGatewayService(accountRepo service.AccountRepository) *service.Ga
return
service
.
NewGatewayService
(
accountRepo
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
)
}
...
...
backend/internal/handler/sora_gateway_handler.go
View file @
7b83d6e7
...
...
@@ -30,6 +30,8 @@ import (
)
// SoraGatewayHandler handles Sora chat completions requests
//
// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。
type
SoraGatewayHandler
struct
{
gatewayService
*
service
.
GatewayService
soraGatewayService
*
service
.
SoraGatewayService
...
...
@@ -226,7 +228,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
var
lastFailoverHeaders
http
.
Header
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
failedAccountIDs
,
""
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
failedAccountIDs
,
""
,
int64
(
0
)
)
if
err
!=
nil
{
reqLog
.
Warn
(
"sora.account_select_failed"
,
zap
.
Error
(
err
),
...
...
backend/internal/handler/sora_gateway_handler_test.go
View file @
7b83d6e7
...
...
@@ -465,6 +465,8 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
nil
,
// digestStore
nil
,
// settingService
nil
,
// tlsFPProfileService
nil
,
// channelService
nil
,
// resolver
)
soraClient
:=
&
stubSoraClient
{
imageURLs
:
[]
string
{
"https://example.com/a.png"
}}
...
...
backend/internal/handler/wire.go
View file @
7b83d6e7
...
...
@@ -33,6 +33,7 @@ func ProvideAdminHandlers(
tlsFingerprintProfileHandler
*
admin
.
TLSFingerprintProfileHandler
,
apiKeyHandler
*
admin
.
AdminAPIKeyHandler
,
scheduledTestHandler
*
admin
.
ScheduledTestHandler
,
channelHandler
*
admin
.
ChannelHandler
,
)
*
AdminHandlers
{
return
&
AdminHandlers
{
Dashboard
:
dashboardHandler
,
...
...
@@ -59,6 +60,7 @@ func ProvideAdminHandlers(
TLSFingerprintProfile
:
tlsFingerprintProfileHandler
,
APIKey
:
apiKeyHandler
,
ScheduledTest
:
scheduledTestHandler
,
Channel
:
channelHandler
,
}
}
...
...
@@ -150,6 +152,7 @@ var ProviderSet = wire.NewSet(
admin
.
NewTLSFingerprintProfileHandler
,
admin
.
NewAdminAPIKeyHandler
,
admin
.
NewScheduledTestHandler
,
admin
.
NewChannelHandler
,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers
,
...
...
backend/internal/pkg/antigravity/claude_types.go
View file @
7b83d6e7
...
...
@@ -125,6 +125,7 @@ type ClaudeUsage struct {
OutputTokens
int
`json:"output_tokens"`
CacheCreationInputTokens
int
`json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens
int
`json:"cache_read_input_tokens,omitempty"`
ImageOutputTokens
int
`json:"image_output_tokens,omitempty"`
}
// ClaudeError Claude 错误响应
...
...
backend/internal/pkg/antigravity/gemini_types.go
View file @
7b83d6e7
...
...
@@ -149,13 +149,31 @@ type GeminiCandidate struct {
GroundingMetadata
*
GeminiGroundingMetadata
`json:"groundingMetadata,omitempty"`
}
// GeminiTokenDetail Gemini token 详情(按模态分类)
type
GeminiTokenDetail
struct
{
Modality
string
`json:"modality"`
TokenCount
int
`json:"tokenCount"`
}
// GeminiUsageMetadata Gemini 用量元数据
type
GeminiUsageMetadata
struct
{
PromptTokenCount
int
`json:"promptTokenCount,omitempty"`
CandidatesTokenCount
int
`json:"candidatesTokenCount,omitempty"`
CachedContentTokenCount
int
`json:"cachedContentTokenCount,omitempty"`
TotalTokenCount
int
`json:"totalTokenCount,omitempty"`
ThoughtsTokenCount
int
`json:"thoughtsTokenCount,omitempty"`
// thinking tokens(按输出价格计费)
PromptTokenCount
int
`json:"promptTokenCount,omitempty"`
CandidatesTokenCount
int
`json:"candidatesTokenCount,omitempty"`
CachedContentTokenCount
int
`json:"cachedContentTokenCount,omitempty"`
TotalTokenCount
int
`json:"totalTokenCount,omitempty"`
ThoughtsTokenCount
int
`json:"thoughtsTokenCount,omitempty"`
// thinking tokens(按输出价格计费)
CandidatesTokensDetails
[]
GeminiTokenDetail
`json:"candidatesTokensDetails,omitempty"`
PromptTokensDetails
[]
GeminiTokenDetail
`json:"promptTokensDetails,omitempty"`
}
// ImageOutputTokens 从 CandidatesTokensDetails 中提取 IMAGE 模态的 token 数
func
(
m
*
GeminiUsageMetadata
)
ImageOutputTokens
()
int
{
for
_
,
d
:=
range
m
.
CandidatesTokensDetails
{
if
d
.
Modality
==
"IMAGE"
{
return
d
.
TokenCount
}
}
return
0
}
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
...
...
backend/internal/pkg/antigravity/response_transformer.go
View file @
7b83d6e7
...
...
@@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
usage
.
InputTokens
=
geminiResp
.
UsageMetadata
.
PromptTokenCount
-
cached
usage
.
OutputTokens
=
geminiResp
.
UsageMetadata
.
CandidatesTokenCount
+
geminiResp
.
UsageMetadata
.
ThoughtsTokenCount
usage
.
CacheReadInputTokens
=
cached
usage
.
ImageOutputTokens
=
geminiResp
.
UsageMetadata
.
ImageOutputTokens
()
}
// 生成响应 ID
...
...
backend/internal/pkg/antigravity/stream_transformer.go
View file @
7b83d6e7
...
...
@@ -32,9 +32,10 @@ type StreamingProcessor struct {
groundingChunks
[]
GeminiGroundingChunk
// 累计 usage
inputTokens
int
outputTokens
int
cacheReadTokens
int
inputTokens
int
outputTokens
int
cacheReadTokens
int
imageOutputTokens
int
}
// NewStreamingProcessor 创建流式响应处理器
...
...
@@ -87,6 +88,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
p
.
inputTokens
=
geminiResp
.
UsageMetadata
.
PromptTokenCount
-
cached
p
.
outputTokens
=
geminiResp
.
UsageMetadata
.
CandidatesTokenCount
+
geminiResp
.
UsageMetadata
.
ThoughtsTokenCount
p
.
cacheReadTokens
=
cached
p
.
imageOutputTokens
=
geminiResp
.
UsageMetadata
.
ImageOutputTokens
()
}
// 处理 parts
...
...
@@ -127,6 +129,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
InputTokens
:
p
.
inputTokens
,
OutputTokens
:
p
.
outputTokens
,
CacheReadInputTokens
:
p
.
cacheReadTokens
,
ImageOutputTokens
:
p
.
imageOutputTokens
,
}
if
!
p
.
messageStartSent
{
...
...
@@ -158,6 +161,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
usage
.
InputTokens
=
v1Resp
.
Response
.
UsageMetadata
.
PromptTokenCount
-
cached
usage
.
OutputTokens
=
v1Resp
.
Response
.
UsageMetadata
.
CandidatesTokenCount
+
v1Resp
.
Response
.
UsageMetadata
.
ThoughtsTokenCount
usage
.
CacheReadInputTokens
=
cached
usage
.
ImageOutputTokens
=
v1Resp
.
Response
.
UsageMetadata
.
ImageOutputTokens
()
}
responseID
:=
v1Resp
.
ResponseID
...
...
@@ -485,6 +489,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
InputTokens
:
p
.
inputTokens
,
OutputTokens
:
p
.
outputTokens
,
CacheReadInputTokens
:
p
.
cacheReadTokens
,
ImageOutputTokens
:
p
.
imageOutputTokens
,
}
deltaEvent
:=
map
[
string
]
any
{
...
...
backend/internal/pkg/usagestats/usage_log_types.go
View file @
7b83d6e7
...
...
@@ -175,6 +175,13 @@ type UserBreakdownDimension struct {
ModelType
string
// "requested", "upstream", or "mapping"
Endpoint
string
// filter by endpoint value (non-empty to enable)
EndpointType
string
// "inbound", "upstream", or "path"
// Additional filter conditions
UserID
int64
// filter by user_id (>0 to enable)
APIKeyID
int64
// filter by api_key_id (>0 to enable)
AccountID
int64
// filter by account_id (>0 to enable)
RequestType
*
int16
// filter by request_type (non-nil to enable)
Stream
*
bool
// filter by stream flag (non-nil to enable)
BillingType
*
int8
// filter by billing_type (non-nil to enable)
}
// APIKeyUsageTrendPoint represents API key usage trend data point
...
...
@@ -230,6 +237,7 @@ type UsageLogFilters struct {
RequestType
*
int16
Stream
*
bool
BillingType
*
int8
BillingMode
string
StartTime
*
time
.
Time
EndTime
*
time
.
Time
// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.
...
...
backend/internal/repository/channel_repo.go
0 → 100644
View file @
7b83d6e7
package
repository
import
(
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type
channelRepository
struct
{
db
*
sql
.
DB
}
// NewChannelRepository 创建渠道数据访问实例
func
NewChannelRepository
(
db
*
sql
.
DB
)
service
.
ChannelRepository
{
return
&
channelRepository
{
db
:
db
}
}
// runInTx 在事务中执行 fn,成功 commit,失败 rollback。
func
(
r
*
channelRepository
)
runInTx
(
ctx
context
.
Context
,
fn
func
(
tx
*
sql
.
Tx
)
error
)
error
{
tx
,
err
:=
r
.
db
.
BeginTx
(
ctx
,
nil
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"begin tx: %w"
,
err
)
}
defer
func
()
{
_
=
tx
.
Rollback
()
}()
if
err
:=
fn
(
tx
);
err
!=
nil
{
return
err
}
return
tx
.
Commit
()
}
func
(
r
*
channelRepository
)
Create
(
ctx
context
.
Context
,
channel
*
service
.
Channel
)
error
{
return
r
.
runInTx
(
ctx
,
func
(
tx
*
sql
.
Tx
)
error
{
modelMappingJSON
,
err
:=
marshalModelMapping
(
channel
.
ModelMapping
)
if
err
!=
nil
{
return
err
}
err
=
tx
.
QueryRowContext
(
ctx
,
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, created_at, updated_at`
,
channel
.
Name
,
channel
.
Description
,
channel
.
Status
,
modelMappingJSON
,
channel
.
BillingModelSource
,
channel
.
RestrictModels
,
)
.
Scan
(
&
channel
.
ID
,
&
channel
.
CreatedAt
,
&
channel
.
UpdatedAt
)
if
err
!=
nil
{
if
isUniqueViolation
(
err
)
{
return
service
.
ErrChannelExists
}
return
fmt
.
Errorf
(
"insert channel: %w"
,
err
)
}
// 设置分组关联
if
len
(
channel
.
GroupIDs
)
>
0
{
if
err
:=
setGroupIDsTx
(
ctx
,
tx
,
channel
.
ID
,
channel
.
GroupIDs
);
err
!=
nil
{
return
err
}
}
// 设置模型定价
if
len
(
channel
.
ModelPricing
)
>
0
{
if
err
:=
replaceModelPricingTx
(
ctx
,
tx
,
channel
.
ID
,
channel
.
ModelPricing
);
err
!=
nil
{
return
err
}
}
return
nil
})
}
func
(
r
*
channelRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Channel
,
error
)
{
ch
:=
&
service
.
Channel
{}
var
modelMappingJSON
[]
byte
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at
FROM channels WHERE id = $1`
,
id
,
)
.
Scan
(
&
ch
.
ID
,
&
ch
.
Name
,
&
ch
.
Description
,
&
ch
.
Status
,
&
modelMappingJSON
,
&
ch
.
BillingModelSource
,
&
ch
.
RestrictModels
,
&
ch
.
CreatedAt
,
&
ch
.
UpdatedAt
)
if
err
==
sql
.
ErrNoRows
{
return
nil
,
service
.
ErrChannelNotFound
}
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get channel: %w"
,
err
)
}
ch
.
ModelMapping
=
unmarshalModelMapping
(
modelMappingJSON
)
groupIDs
,
err
:=
r
.
GetGroupIDs
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
ch
.
GroupIDs
=
groupIDs
pricing
,
err
:=
r
.
ListModelPricing
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
ch
.
ModelPricing
=
pricing
return
ch
,
nil
}
func
(
r
*
channelRepository
)
Update
(
ctx
context
.
Context
,
channel
*
service
.
Channel
)
error
{
return
r
.
runInTx
(
ctx
,
func
(
tx
*
sql
.
Tx
)
error
{
modelMappingJSON
,
err
:=
marshalModelMapping
(
channel
.
ModelMapping
)
if
err
!=
nil
{
return
err
}
result
,
err
:=
tx
.
ExecContext
(
ctx
,
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW()
WHERE id = $7`
,
channel
.
Name
,
channel
.
Description
,
channel
.
Status
,
modelMappingJSON
,
channel
.
BillingModelSource
,
channel
.
RestrictModels
,
channel
.
ID
,
)
if
err
!=
nil
{
if
isUniqueViolation
(
err
)
{
return
service
.
ErrChannelExists
}
return
fmt
.
Errorf
(
"update channel: %w"
,
err
)
}
rows
,
_
:=
result
.
RowsAffected
()
if
rows
==
0
{
return
service
.
ErrChannelNotFound
}
// 更新分组关联
if
channel
.
GroupIDs
!=
nil
{
if
err
:=
setGroupIDsTx
(
ctx
,
tx
,
channel
.
ID
,
channel
.
GroupIDs
);
err
!=
nil
{
return
err
}
}
// 更新模型定价
if
channel
.
ModelPricing
!=
nil
{
if
err
:=
replaceModelPricingTx
(
ctx
,
tx
,
channel
.
ID
,
channel
.
ModelPricing
);
err
!=
nil
{
return
err
}
}
return
nil
})
}
func
(
r
*
channelRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
result
,
err
:=
r
.
db
.
ExecContext
(
ctx
,
`DELETE FROM channels WHERE id = $1`
,
id
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"delete channel: %w"
,
err
)
}
rows
,
_
:=
result
.
RowsAffected
()
if
rows
==
0
{
return
service
.
ErrChannelNotFound
}
return
nil
}
func
(
r
*
channelRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
status
,
search
string
)
([]
service
.
Channel
,
*
pagination
.
PaginationResult
,
error
)
{
where
:=
[]
string
{
"1=1"
}
args
:=
[]
any
{}
argIdx
:=
1
if
status
!=
""
{
where
=
append
(
where
,
fmt
.
Sprintf
(
"c.status = $%d"
,
argIdx
))
args
=
append
(
args
,
status
)
argIdx
++
}
if
search
!=
""
{
where
=
append
(
where
,
fmt
.
Sprintf
(
"(c.name ILIKE $%d OR c.description ILIKE $%d)"
,
argIdx
,
argIdx
))
args
=
append
(
args
,
"%"
+
escapeLike
(
search
)
+
"%"
)
argIdx
++
}
whereClause
:=
strings
.
Join
(
where
,
" AND "
)
// 计数
var
total
int64
countQuery
:=
fmt
.
Sprintf
(
"SELECT COUNT(*) FROM channels c WHERE %s"
,
whereClause
)
if
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
countQuery
,
args
...
)
.
Scan
(
&
total
);
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"count channels: %w"
,
err
)
}
pageSize
:=
params
.
Limit
()
// 约束在 [1, 100]
page
:=
params
.
Page
if
page
<
1
{
page
=
1
}
offset
:=
(
page
-
1
)
*
pageSize
// 查询 channel 列表
dataQuery
:=
fmt
.
Sprintf
(
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`
,
whereClause
,
argIdx
,
argIdx
+
1
,
)
args
=
append
(
args
,
pageSize
,
offset
)
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
dataQuery
,
args
...
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"query channels: %w"
,
err
)
}
defer
func
()
{
_
=
rows
.
Close
()
}()
var
channels
[]
service
.
Channel
var
channelIDs
[]
int64
for
rows
.
Next
()
{
var
ch
service
.
Channel
var
modelMappingJSON
[]
byte
if
err
:=
rows
.
Scan
(
&
ch
.
ID
,
&
ch
.
Name
,
&
ch
.
Description
,
&
ch
.
Status
,
&
modelMappingJSON
,
&
ch
.
BillingModelSource
,
&
ch
.
RestrictModels
,
&
ch
.
CreatedAt
,
&
ch
.
UpdatedAt
);
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"scan channel: %w"
,
err
)
}
ch
.
ModelMapping
=
unmarshalModelMapping
(
modelMappingJSON
)
channels
=
append
(
channels
,
ch
)
channelIDs
=
append
(
channelIDs
,
ch
.
ID
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"iterate channels: %w"
,
err
)
}
// 批量加载分组 ID 和模型定价(避免 N+1)
if
len
(
channelIDs
)
>
0
{
groupMap
,
err
:=
r
.
batchLoadGroupIDs
(
ctx
,
channelIDs
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
pricingMap
,
err
:=
r
.
batchLoadModelPricing
(
ctx
,
channelIDs
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
for
i
:=
range
channels
{
channels
[
i
]
.
GroupIDs
=
groupMap
[
channels
[
i
]
.
ID
]
channels
[
i
]
.
ModelPricing
=
pricingMap
[
channels
[
i
]
.
ID
]
}
}
pages
:=
0
if
total
>
0
{
pages
=
int
((
total
+
int64
(
pageSize
)
-
1
)
/
int64
(
pageSize
))
}
paginationResult
:=
&
pagination
.
PaginationResult
{
Total
:
total
,
Page
:
page
,
PageSize
:
pageSize
,
Pages
:
pages
,
}
return
channels
,
paginationResult
,
nil
}
func
(
r
*
channelRepository
)
ListAll
(
ctx
context
.
Context
)
([]
service
.
Channel
,
error
)
{
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`
,
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query all channels: %w"
,
err
)
}
defer
func
()
{
_
=
rows
.
Close
()
}()
var
channels
[]
service
.
Channel
var
channelIDs
[]
int64
for
rows
.
Next
()
{
var
ch
service
.
Channel
var
modelMappingJSON
[]
byte
if
err
:=
rows
.
Scan
(
&
ch
.
ID
,
&
ch
.
Name
,
&
ch
.
Description
,
&
ch
.
Status
,
&
modelMappingJSON
,
&
ch
.
BillingModelSource
,
&
ch
.
RestrictModels
,
&
ch
.
CreatedAt
,
&
ch
.
UpdatedAt
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan channel: %w"
,
err
)
}
ch
.
ModelMapping
=
unmarshalModelMapping
(
modelMappingJSON
)
channels
=
append
(
channels
,
ch
)
channelIDs
=
append
(
channelIDs
,
ch
.
ID
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"iterate channels: %w"
,
err
)
}
if
len
(
channelIDs
)
==
0
{
return
channels
,
nil
}
// 批量加载分组 ID
groupMap
,
err
:=
r
.
batchLoadGroupIDs
(
ctx
,
channelIDs
)
if
err
!=
nil
{
return
nil
,
err
}
// 批量加载模型定价
pricingMap
,
err
:=
r
.
batchLoadModelPricing
(
ctx
,
channelIDs
)
if
err
!=
nil
{
return
nil
,
err
}
for
i
:=
range
channels
{
channels
[
i
]
.
GroupIDs
=
groupMap
[
channels
[
i
]
.
ID
]
channels
[
i
]
.
ModelPricing
=
pricingMap
[
channels
[
i
]
.
ID
]
}
return
channels
,
nil
}
// --- 批量加载辅助方法 ---
// batchLoadGroupIDs 批量加载多个渠道的分组 ID
func
(
r
*
channelRepository
)
batchLoadGroupIDs
(
ctx
context
.
Context
,
channelIDs
[]
int64
)
(
map
[
int64
][]
int64
,
error
)
{
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT channel_id, group_id FROM channel_groups
WHERE channel_id = ANY($1) ORDER BY channel_id, group_id`
,
pq
.
Array
(
channelIDs
),
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"batch load group ids: %w"
,
err
)
}
defer
func
()
{
_
=
rows
.
Close
()
}()
groupMap
:=
make
(
map
[
int64
][]
int64
,
len
(
channelIDs
))
for
rows
.
Next
()
{
var
channelID
,
groupID
int64
if
err
:=
rows
.
Scan
(
&
channelID
,
&
groupID
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan group id: %w"
,
err
)
}
groupMap
[
channelID
]
=
append
(
groupMap
[
channelID
],
groupID
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"iterate group ids: %w"
,
err
)
}
return
groupMap
,
nil
}
func
(
r
*
channelRepository
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
var
exists
bool
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1)`
,
name
,
)
.
Scan
(
&
exists
)
return
exists
,
err
}
func
(
r
*
channelRepository
)
ExistsByNameExcluding
(
ctx
context
.
Context
,
name
string
,
excludeID
int64
)
(
bool
,
error
)
{
var
exists
bool
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1 AND id != $2)`
,
name
,
excludeID
,
)
.
Scan
(
&
exists
)
return
exists
,
err
}
// --- 分组关联 ---
func
(
r
*
channelRepository
)
GetGroupIDs
(
ctx
context
.
Context
,
channelID
int64
)
([]
int64
,
error
)
{
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT group_id FROM channel_groups WHERE channel_id = $1 ORDER BY group_id`
,
channelID
,
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group ids: %w"
,
err
)
}
defer
func
()
{
_
=
rows
.
Close
()
}()
var
ids
[]
int64
for
rows
.
Next
()
{
var
id
int64
if
err
:=
rows
.
Scan
(
&
id
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan group id: %w"
,
err
)
}
ids
=
append
(
ids
,
id
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"iterate group ids: %w"
,
err
)
}
return
ids
,
nil
}
func
(
r
*
channelRepository
)
SetGroupIDs
(
ctx
context
.
Context
,
channelID
int64
,
groupIDs
[]
int64
)
error
{
return
setGroupIDsTx
(
ctx
,
r
.
db
,
channelID
,
groupIDs
)
}
func
(
r
*
channelRepository
)
GetChannelIDByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
var
channelID
int64
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
`SELECT channel_id FROM channel_groups WHERE group_id = $1`
,
groupID
,
)
.
Scan
(
&
channelID
)
if
err
==
sql
.
ErrNoRows
{
return
0
,
nil
}
return
channelID
,
err
}
func
(
r
*
channelRepository
)
GetGroupsInOtherChannels
(
ctx
context
.
Context
,
channelID
int64
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
if
len
(
groupIDs
)
==
0
{
return
nil
,
nil
}
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT group_id FROM channel_groups WHERE group_id = ANY($1) AND channel_id != $2`
,
pq
.
Array
(
groupIDs
),
channelID
,
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get groups in other channels: %w"
,
err
)
}
defer
func
()
{
_
=
rows
.
Close
()
}()
var
conflicting
[]
int64
for
rows
.
Next
()
{
var
id
int64
if
err
:=
rows
.
Scan
(
&
id
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan conflicting group id: %w"
,
err
)
}
conflicting
=
append
(
conflicting
,
id
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"iterate conflicting group ids: %w"
,
err
)
}
return
conflicting
,
nil
}
// marshalModelMapping 将 model mapping 序列化为嵌套 JSON 字节
// 格式:{"platform": {"src": "dst"}, ...}
func
marshalModelMapping
(
m
map
[
string
]
map
[
string
]
string
)
([]
byte
,
error
)
{
if
len
(
m
)
==
0
{
return
[]
byte
(
"{}"
),
nil
}
data
,
err
:=
json
.
Marshal
(
m
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"marshal model_mapping: %w"
,
err
)
}
return
data
,
nil
}
// unmarshalModelMapping 将 JSON 字节反序列化为嵌套 model mapping
func
unmarshalModelMapping
(
data
[]
byte
)
map
[
string
]
map
[
string
]
string
{
if
len
(
data
)
==
0
{
return
nil
}
var
m
map
[
string
]
map
[
string
]
string
if
err
:=
json
.
Unmarshal
(
data
,
&
m
);
err
!=
nil
{
return
nil
}
return
m
}
// GetGroupPlatforms 批量查询分组 ID 对应的平台
func
(
r
*
channelRepository
)
GetGroupPlatforms
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
(
map
[
int64
]
string
,
error
)
{
if
len
(
groupIDs
)
==
0
{
return
make
(
map
[
int64
]
string
),
nil
}
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT id, platform FROM groups WHERE id = ANY($1)`
,
pq
.
Array
(
groupIDs
),
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group platforms: %w"
,
err
)
}
defer
rows
.
Close
()
//nolint:errcheck
result
:=
make
(
map
[
int64
]
string
,
len
(
groupIDs
))
for
rows
.
Next
()
{
var
id
int64
var
platform
string
if
err
:=
rows
.
Scan
(
&
id
,
&
platform
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan group platform: %w"
,
err
)
}
result
[
id
]
=
platform
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"iterate group platforms: %w"
,
err
)
}
return
result
,
nil
}
backend/internal/repository/channel_repo_pricing.go
0 → 100644
View file @
7b83d6e7
package
repository
import
(
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
// --- 模型定价 ---
func
(
r
*
channelRepository
)
ListModelPricing
(
ctx
context
.
Context
,
channelID
int64
)
([]
service
.
ChannelModelPricing
,
error
)
{
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`
,
channelID
,
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list model pricing: %w"
,
err
)
}
defer
func
()
{
_
=
rows
.
Close
()
}()
result
,
pricingIDs
,
err
:=
scanModelPricingRows
(
rows
)
if
err
!=
nil
{
return
nil
,
err
}
if
len
(
pricingIDs
)
>
0
{
intervalMap
,
err
:=
r
.
batchLoadIntervals
(
ctx
,
pricingIDs
)
if
err
!=
nil
{
return
nil
,
err
}
for
i
:=
range
result
{
result
[
i
]
.
Intervals
=
intervalMap
[
result
[
i
]
.
ID
]
}
}
return
result
,
nil
}
func
(
r
*
channelRepository
)
CreateModelPricing
(
ctx
context
.
Context
,
pricing
*
service
.
ChannelModelPricing
)
error
{
return
createModelPricingExec
(
ctx
,
r
.
db
,
pricing
)
}
func
(
r
*
channelRepository
)
UpdateModelPricing
(
ctx
context
.
Context
,
pricing
*
service
.
ChannelModelPricing
)
error
{
modelsJSON
,
err
:=
json
.
Marshal
(
pricing
.
Models
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal models: %w"
,
err
)
}
billingMode
:=
pricing
.
BillingMode
if
billingMode
==
""
{
billingMode
=
service
.
BillingModeToken
}
result
,
err
:=
r
.
db
.
ExecContext
(
ctx
,
`UPDATE channel_model_pricing
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, platform = $9, updated_at = NOW()
WHERE id = $10`
,
modelsJSON
,
billingMode
,
pricing
.
InputPrice
,
pricing
.
OutputPrice
,
pricing
.
CacheWritePrice
,
pricing
.
CacheReadPrice
,
pricing
.
ImageOutputPrice
,
pricing
.
PerRequestPrice
,
pricing
.
Platform
,
pricing
.
ID
,
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"update model pricing: %w"
,
err
)
}
rows
,
_
:=
result
.
RowsAffected
()
if
rows
==
0
{
return
fmt
.
Errorf
(
"pricing entry not found: %d"
,
pricing
.
ID
)
}
return
nil
}
func
(
r
*
channelRepository
)
DeleteModelPricing
(
ctx
context
.
Context
,
id
int64
)
error
{
_
,
err
:=
r
.
db
.
ExecContext
(
ctx
,
`DELETE FROM channel_model_pricing WHERE id = $1`
,
id
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"delete model pricing: %w"
,
err
)
}
return
nil
}
func
(
r
*
channelRepository
)
ReplaceModelPricing
(
ctx
context
.
Context
,
channelID
int64
,
pricingList
[]
service
.
ChannelModelPricing
)
error
{
return
r
.
runInTx
(
ctx
,
func
(
tx
*
sql
.
Tx
)
error
{
return
replaceModelPricingTx
(
ctx
,
tx
,
channelID
,
pricingList
)
})
}
// --- 批量加载辅助方法 ---
// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
func
(
r
*
channelRepository
)
batchLoadModelPricing
(
ctx
context
.
Context
,
channelIDs
[]
int64
)
(
map
[
int64
][]
service
.
ChannelModelPricing
,
error
)
{
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`
,
pq
.
Array
(
channelIDs
),
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"batch load model pricing: %w"
,
err
)
}
defer
func
()
{
_
=
rows
.
Close
()
}()
allPricing
,
allPricingIDs
,
err
:=
scanModelPricingRows
(
rows
)
if
err
!=
nil
{
return
nil
,
err
}
// 按 channelID 分组
pricingMap
:=
make
(
map
[
int64
][]
service
.
ChannelModelPricing
,
len
(
channelIDs
))
for
_
,
p
:=
range
allPricing
{
pricingMap
[
p
.
ChannelID
]
=
append
(
pricingMap
[
p
.
ChannelID
],
p
)
}
// 批量加载所有区间
if
len
(
allPricingIDs
)
>
0
{
intervalMap
,
err
:=
r
.
batchLoadIntervals
(
ctx
,
allPricingIDs
)
if
err
!=
nil
{
return
nil
,
err
}
for
chID
:=
range
pricingMap
{
for
i
:=
range
pricingMap
[
chID
]
{
pricingMap
[
chID
][
i
]
.
Intervals
=
intervalMap
[
pricingMap
[
chID
][
i
]
.
ID
]
}
}
}
return
pricingMap
,
nil
}
// batchLoadIntervals 批量加载多个定价条目的区间
func
(
r
*
channelRepository
)
batchLoadIntervals
(
ctx
context
.
Context
,
pricingIDs
[]
int64
)
(
map
[
int64
][]
service
.
PricingInterval
,
error
)
{
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
input_price, output_price, cache_write_price, cache_read_price,
per_request_price, sort_order, created_at, updated_at
FROM channel_pricing_intervals
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`
,
pq
.
Array
(
pricingIDs
),
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"batch load intervals: %w"
,
err
)
}
defer
func
()
{
_
=
rows
.
Close
()
}()
intervalMap
:=
make
(
map
[
int64
][]
service
.
PricingInterval
,
len
(
pricingIDs
))
for
rows
.
Next
()
{
var
iv
service
.
PricingInterval
if
err
:=
rows
.
Scan
(
&
iv
.
ID
,
&
iv
.
PricingID
,
&
iv
.
MinTokens
,
&
iv
.
MaxTokens
,
&
iv
.
TierLabel
,
&
iv
.
InputPrice
,
&
iv
.
OutputPrice
,
&
iv
.
CacheWritePrice
,
&
iv
.
CacheReadPrice
,
&
iv
.
PerRequestPrice
,
&
iv
.
SortOrder
,
&
iv
.
CreatedAt
,
&
iv
.
UpdatedAt
,
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan interval: %w"
,
err
)
}
intervalMap
[
iv
.
PricingID
]
=
append
(
intervalMap
[
iv
.
PricingID
],
iv
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"iterate intervals: %w"
,
err
)
}
return
intervalMap
,
nil
}
// --- 共享 scan 辅助 ---
// scanModelPricingRows 扫描 model pricing 行,返回结果列表和 ID 列表
func
scanModelPricingRows
(
rows
*
sql
.
Rows
)
([]
service
.
ChannelModelPricing
,
[]
int64
,
error
)
{
var
result
[]
service
.
ChannelModelPricing
var
pricingIDs
[]
int64
for
rows
.
Next
()
{
var
p
service
.
ChannelModelPricing
var
modelsJSON
[]
byte
if
err
:=
rows
.
Scan
(
&
p
.
ID
,
&
p
.
ChannelID
,
&
p
.
Platform
,
&
modelsJSON
,
&
p
.
BillingMode
,
&
p
.
InputPrice
,
&
p
.
OutputPrice
,
&
p
.
CacheWritePrice
,
&
p
.
CacheReadPrice
,
&
p
.
ImageOutputPrice
,
&
p
.
PerRequestPrice
,
&
p
.
CreatedAt
,
&
p
.
UpdatedAt
,
);
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"scan model pricing: %w"
,
err
)
}
if
err
:=
json
.
Unmarshal
(
modelsJSON
,
&
p
.
Models
);
err
!=
nil
{
p
.
Models
=
[]
string
{}
}
pricingIDs
=
append
(
pricingIDs
,
p
.
ID
)
result
=
append
(
result
,
p
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"iterate model pricing: %w"
,
err
)
}
return
result
,
pricingIDs
,
nil
}
// --- 事务内辅助方法 ---
// dbExec 是 *sql.DB 和 *sql.Tx 共享的最小 SQL 执行接口
type
dbExec
interface
{
ExecContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
(
sql
.
Result
,
error
)
QueryContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
(
*
sql
.
Rows
,
error
)
QueryRowContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
*
sql
.
Row
}
func
setGroupIDsTx
(
ctx
context
.
Context
,
exec
dbExec
,
channelID
int64
,
groupIDs
[]
int64
)
error
{
if
_
,
err
:=
exec
.
ExecContext
(
ctx
,
`DELETE FROM channel_groups WHERE channel_id = $1`
,
channelID
);
err
!=
nil
{
return
fmt
.
Errorf
(
"delete old group associations: %w"
,
err
)
}
if
len
(
groupIDs
)
==
0
{
return
nil
}
_
,
err
:=
exec
.
ExecContext
(
ctx
,
`INSERT INTO channel_groups (channel_id, group_id)
SELECT $1, unnest($2::bigint[])`
,
channelID
,
pq
.
Array
(
groupIDs
),
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"insert group associations: %w"
,
err
)
}
return
nil
}
func
createModelPricingExec
(
ctx
context
.
Context
,
exec
dbExec
,
pricing
*
service
.
ChannelModelPricing
)
error
{
modelsJSON
,
err
:=
json
.
Marshal
(
pricing
.
Models
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"marshal models: %w"
,
err
)
}
billingMode
:=
pricing
.
BillingMode
if
billingMode
==
""
{
billingMode
=
service
.
BillingModeToken
}
platform
:=
pricing
.
Platform
if
platform
==
""
{
platform
=
"anthropic"
}
err
=
exec
.
QueryRowContext
(
ctx
,
`INSERT INTO channel_model_pricing (channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`
,
pricing
.
ChannelID
,
platform
,
modelsJSON
,
billingMode
,
pricing
.
InputPrice
,
pricing
.
OutputPrice
,
pricing
.
CacheWritePrice
,
pricing
.
CacheReadPrice
,
pricing
.
ImageOutputPrice
,
pricing
.
PerRequestPrice
,
)
.
Scan
(
&
pricing
.
ID
,
&
pricing
.
CreatedAt
,
&
pricing
.
UpdatedAt
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"insert model pricing: %w"
,
err
)
}
for
i
:=
range
pricing
.
Intervals
{
pricing
.
Intervals
[
i
]
.
PricingID
=
pricing
.
ID
if
err
:=
createIntervalExec
(
ctx
,
exec
,
&
pricing
.
Intervals
[
i
]);
err
!=
nil
{
return
err
}
}
return
nil
}
func
createIntervalExec
(
ctx
context
.
Context
,
exec
dbExec
,
iv
*
service
.
PricingInterval
)
error
{
return
exec
.
QueryRowContext
(
ctx
,
`INSERT INTO channel_pricing_intervals
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`
,
iv
.
PricingID
,
iv
.
MinTokens
,
iv
.
MaxTokens
,
iv
.
TierLabel
,
iv
.
InputPrice
,
iv
.
OutputPrice
,
iv
.
CacheWritePrice
,
iv
.
CacheReadPrice
,
iv
.
PerRequestPrice
,
iv
.
SortOrder
,
)
.
Scan
(
&
iv
.
ID
,
&
iv
.
CreatedAt
,
&
iv
.
UpdatedAt
)
}
func
replaceModelPricingTx
(
ctx
context
.
Context
,
exec
dbExec
,
channelID
int64
,
pricingList
[]
service
.
ChannelModelPricing
)
error
{
if
_
,
err
:=
exec
.
ExecContext
(
ctx
,
`DELETE FROM channel_model_pricing WHERE channel_id = $1`
,
channelID
);
err
!=
nil
{
return
fmt
.
Errorf
(
"delete old model pricing: %w"
,
err
)
}
for
i
:=
range
pricingList
{
pricingList
[
i
]
.
ChannelID
=
channelID
if
err
:=
createModelPricingExec
(
ctx
,
exec
,
&
pricingList
[
i
]);
err
!=
nil
{
return
fmt
.
Errorf
(
"insert model pricing: %w"
,
err
)
}
}
return
nil
}
// isUniqueViolation 检查 pq 唯一约束违反错误
func
isUniqueViolation
(
err
error
)
bool
{
var
pqErr
*
pq
.
Error
if
errors
.
As
(
err
,
&
pqErr
)
&&
pqErr
!=
nil
{
return
pqErr
.
Code
==
"23505"
}
return
false
}
// escapeLike 转义 LIKE/ILIKE 模式中的特殊字符
func
escapeLike
(
s
string
)
string
{
s
=
strings
.
ReplaceAll
(
s
,
`\`
,
`\\`
)
s
=
strings
.
ReplaceAll
(
s
,
`%`
,
`\%`
)
s
=
strings
.
ReplaceAll
(
s
,
`_`
,
`\_`
)
return
s
}
backend/internal/repository/channel_repo_test.go
0 → 100644
View file @
7b83d6e7
//go:build unit
package
repository
import
(
"encoding/json"
"errors"
"fmt"
"testing"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
)
// --- marshalModelMapping ---
func
TestMarshalModelMapping
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
input
map
[
string
]
map
[
string
]
string
wantJSON
string
// expected JSON output (exact match)
}{
{
name
:
"empty map"
,
input
:
map
[
string
]
map
[
string
]
string
{},
wantJSON
:
"{}"
,
},
{
name
:
"nil map"
,
input
:
nil
,
wantJSON
:
"{}"
,
},
{
name
:
"populated map"
,
input
:
map
[
string
]
map
[
string
]
string
{
"openai"
:
{
"gpt-4"
:
"gpt-4-turbo"
},
},
},
{
name
:
"nested values"
,
input
:
map
[
string
]
map
[
string
]
string
{
"openai"
:
{
"*"
:
"gpt-5.4"
},
"anthropic"
:
{
"claude-old"
:
"claude-new"
},
},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
,
err
:=
marshalModelMapping
(
tt
.
input
)
require
.
NoError
(
t
,
err
)
if
tt
.
wantJSON
!=
""
{
require
.
Equal
(
t
,
[]
byte
(
tt
.
wantJSON
),
result
)
}
else
{
// round-trip: unmarshal and compare with input
var
parsed
map
[
string
]
map
[
string
]
string
require
.
NoError
(
t
,
json
.
Unmarshal
(
result
,
&
parsed
))
require
.
Equal
(
t
,
tt
.
input
,
parsed
)
}
})
}
}
// --- unmarshalModelMapping ---
func
TestUnmarshalModelMapping
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
input
[]
byte
wantNil
bool
want
map
[
string
]
map
[
string
]
string
}{
{
name
:
"nil data"
,
input
:
nil
,
wantNil
:
true
,
},
{
name
:
"empty data"
,
input
:
[]
byte
{},
wantNil
:
true
,
},
{
name
:
"invalid JSON"
,
input
:
[]
byte
(
"not-json"
),
wantNil
:
true
,
},
{
name
:
"type error - number"
,
input
:
[]
byte
(
"42"
),
wantNil
:
true
,
},
{
name
:
"type error - array"
,
input
:
[]
byte
(
"[1,2,3]"
),
wantNil
:
true
,
},
{
name
:
"valid JSON"
,
input
:
[]
byte
(
`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`
),
want
:
map
[
string
]
map
[
string
]
string
{
"openai"
:
{
"gpt-4"
:
"gpt-4-turbo"
},
"anthropic"
:
{
"old"
:
"new"
},
},
},
{
name
:
"empty object"
,
input
:
[]
byte
(
"{}"
),
want
:
map
[
string
]
map
[
string
]
string
{},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
unmarshalModelMapping
(
tt
.
input
)
if
tt
.
wantNil
{
require
.
Nil
(
t
,
result
)
}
else
{
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
tt
.
want
,
result
)
}
})
}
}
// --- escapeLike ---
func
TestEscapeLike
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
input
string
want
string
}{
{
name
:
"no special chars"
,
input
:
"hello"
,
want
:
"hello"
,
},
{
name
:
"backslash"
,
input
:
`a\b`
,
want
:
`a\\b`
,
},
{
name
:
"percent"
,
input
:
"50%"
,
want
:
`50\%`
,
},
{
name
:
"underscore"
,
input
:
"a_b"
,
want
:
`a\_b`
,
},
{
name
:
"all special chars"
,
input
:
`a\b%c_d`
,
want
:
`a\\b\%c\_d`
,
},
{
name
:
"empty string"
,
input
:
""
,
want
:
""
,
},
{
name
:
"consecutive special chars"
,
input
:
"%_%"
,
want
:
`\%\_\%`
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
tt
.
want
,
escapeLike
(
tt
.
input
))
})
}
}
// --- isUniqueViolation ---
func
TestIsUniqueViolation
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
err
error
want
bool
}{
{
name
:
"unique violation code 23505"
,
err
:
&
pq
.
Error
{
Code
:
"23505"
},
want
:
true
,
},
{
name
:
"different pq error code"
,
err
:
&
pq
.
Error
{
Code
:
"23503"
},
want
:
false
,
},
{
name
:
"non-pq error"
,
err
:
errors
.
New
(
"some generic error"
),
want
:
false
,
},
{
name
:
"typed nil pq.Error"
,
err
:
func
()
error
{
var
pqErr
*
pq
.
Error
return
pqErr
}(),
want
:
false
,
},
{
name
:
"bare nil"
,
err
:
nil
,
want
:
false
,
},
{
name
:
"wrapped pq error with 23505"
,
err
:
fmt
.
Errorf
(
"wrapped: %w"
,
&
pq
.
Error
{
Code
:
"23505"
}),
want
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
tt
.
want
,
isUniqueViolation
(
tt
.
err
))
})
}
}
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment