Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
3718d6dc
Unverified
Commit
3718d6dc
authored
Mar 15, 2026
by
IanShaw
Committed by
GitHub
Mar 15, 2026
Browse files
Merge branch 'Wei-Shaw:main' into fix/open-issues-cleanup
parents
90b38381
8321e4a6
Changes
38
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/dto/mappers.go
View file @
3718d6dc
...
@@ -523,6 +523,8 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
...
@@ -523,6 +523,8 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
Model
:
l
.
Model
,
Model
:
l
.
Model
,
ServiceTier
:
l
.
ServiceTier
,
ServiceTier
:
l
.
ServiceTier
,
ReasoningEffort
:
l
.
ReasoningEffort
,
ReasoningEffort
:
l
.
ReasoningEffort
,
InboundEndpoint
:
l
.
InboundEndpoint
,
UpstreamEndpoint
:
l
.
UpstreamEndpoint
,
GroupID
:
l
.
GroupID
,
GroupID
:
l
.
GroupID
,
SubscriptionID
:
l
.
SubscriptionID
,
SubscriptionID
:
l
.
SubscriptionID
,
InputTokens
:
l
.
InputTokens
,
InputTokens
:
l
.
InputTokens
,
...
...
backend/internal/handler/dto/mappers_usage_test.go
View file @
3718d6dc
...
@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
...
@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
t
.
Parallel
()
t
.
Parallel
()
serviceTier
:=
"priority"
serviceTier
:=
"priority"
inboundEndpoint
:=
"/v1/chat/completions"
upstreamEndpoint
:=
"/v1/responses"
log
:=
&
service
.
UsageLog
{
log
:=
&
service
.
UsageLog
{
RequestID
:
"req_3"
,
RequestID
:
"req_3"
,
Model
:
"gpt-5.4"
,
Model
:
"gpt-5.4"
,
ServiceTier
:
&
serviceTier
,
ServiceTier
:
&
serviceTier
,
InboundEndpoint
:
&
inboundEndpoint
,
UpstreamEndpoint
:
&
upstreamEndpoint
,
AccountRateMultiplier
:
f64Ptr
(
1.5
),
AccountRateMultiplier
:
f64Ptr
(
1.5
),
}
}
...
@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
...
@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
require
.
NotNil
(
t
,
userDTO
.
ServiceTier
)
require
.
NotNil
(
t
,
userDTO
.
ServiceTier
)
require
.
Equal
(
t
,
serviceTier
,
*
userDTO
.
ServiceTier
)
require
.
Equal
(
t
,
serviceTier
,
*
userDTO
.
ServiceTier
)
require
.
NotNil
(
t
,
userDTO
.
InboundEndpoint
)
require
.
Equal
(
t
,
inboundEndpoint
,
*
userDTO
.
InboundEndpoint
)
require
.
NotNil
(
t
,
userDTO
.
UpstreamEndpoint
)
require
.
Equal
(
t
,
upstreamEndpoint
,
*
userDTO
.
UpstreamEndpoint
)
require
.
NotNil
(
t
,
adminDTO
.
ServiceTier
)
require
.
NotNil
(
t
,
adminDTO
.
ServiceTier
)
require
.
Equal
(
t
,
serviceTier
,
*
adminDTO
.
ServiceTier
)
require
.
Equal
(
t
,
serviceTier
,
*
adminDTO
.
ServiceTier
)
require
.
NotNil
(
t
,
adminDTO
.
InboundEndpoint
)
require
.
Equal
(
t
,
inboundEndpoint
,
*
adminDTO
.
InboundEndpoint
)
require
.
NotNil
(
t
,
adminDTO
.
UpstreamEndpoint
)
require
.
Equal
(
t
,
upstreamEndpoint
,
*
adminDTO
.
UpstreamEndpoint
)
require
.
NotNil
(
t
,
adminDTO
.
AccountRateMultiplier
)
require
.
NotNil
(
t
,
adminDTO
.
AccountRateMultiplier
)
require
.
InDelta
(
t
,
1.5
,
*
adminDTO
.
AccountRateMultiplier
,
1e-12
)
require
.
InDelta
(
t
,
1.5
,
*
adminDTO
.
AccountRateMultiplier
,
1e-12
)
}
}
...
...
backend/internal/handler/dto/types.go
View file @
3718d6dc
...
@@ -334,9 +334,13 @@ type UsageLog struct {
...
@@ -334,9 +334,13 @@ type UsageLog struct {
Model
string
`json:"model"`
Model
string
`json:"model"`
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier
*
string
`json:"service_tier,omitempty"`
ServiceTier
*
string
`json:"service_tier,omitempty"`
// ReasoningEffort is the request's reasoning effort level
(OpenAI Responses API)
.
// ReasoningEffort is the request's reasoning effort level.
//
nil means not provided / not applicable
.
//
OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max"
.
ReasoningEffort
*
string
`json:"reasoning_effort,omitempty"`
ReasoningEffort
*
string
`json:"reasoning_effort,omitempty"`
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
InboundEndpoint
*
string
`json:"inbound_endpoint,omitempty"`
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
UpstreamEndpoint
*
string
`json:"upstream_endpoint,omitempty"`
GroupID
*
int64
`json:"group_id"`
GroupID
*
int64
`json:"group_id"`
SubscriptionID
*
int64
`json:"subscription_id"`
SubscriptionID
*
int64
`json:"subscription_id"`
...
...
backend/internal/handler/gateway_handler.go
View file @
3718d6dc
...
@@ -391,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -391,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
fs
.
SwitchCount
>
0
{
if
fs
.
SwitchCount
>
0
{
requestCtx
=
service
.
WithAccountSwitchCount
(
requestCtx
,
fs
.
SwitchCount
,
h
.
metadataBridgeEnabled
())
requestCtx
=
service
.
WithAccountSwitchCount
(
requestCtx
,
fs
.
SwitchCount
,
h
.
metadataBridgeEnabled
())
}
}
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
writerSizeBeforeForward
:=
c
.
Writer
.
Size
()
if
account
.
Platform
==
service
.
PlatformAntigravity
{
if
account
.
Platform
==
service
.
PlatformAntigravity
{
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
requestCtx
,
c
,
account
,
reqModel
,
"generateContent"
,
reqStream
,
body
,
hasBoundSession
)
result
,
err
=
h
.
antigravityGatewayService
.
ForwardGemini
(
requestCtx
,
c
,
account
,
reqModel
,
"generateContent"
,
reqStream
,
body
,
hasBoundSession
)
}
else
{
}
else
{
...
@@ -402,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -402,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
err
!=
nil
{
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
if
errors
.
As
(
err
,
&
failoverErr
)
{
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
if
c
.
Writer
.
Size
()
!=
writerSizeBeforeForward
{
h
.
handleFailoverExhausted
(
c
,
failoverErr
,
service
.
PlatformGemini
,
true
)
return
}
action
:=
fs
.
HandleFailoverError
(
c
.
Request
.
Context
(),
h
.
gatewayService
,
account
.
ID
,
account
.
Platform
,
failoverErr
)
action
:=
fs
.
HandleFailoverError
(
c
.
Request
.
Context
(),
h
.
gatewayService
,
account
.
ID
,
account
.
Platform
,
failoverErr
)
switch
action
{
switch
action
{
case
FailoverContinue
:
case
FailoverContinue
:
...
@@ -436,6 +443,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -436,6 +443,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
clientIP
:=
ip
.
GetClientIP
(
c
)
clientIP
:=
ip
.
GetClientIP
(
c
)
requestPayloadHash
:=
service
.
HashUsageRequestPayload
(
body
)
requestPayloadHash
:=
service
.
HashUsageRequestPayload
(
body
)
if
result
.
ReasoningEffort
==
nil
{
result
.
ReasoningEffort
=
service
.
NormalizeClaudeOutputEffort
(
parsedReq
.
OutputEffort
)
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
...
@@ -637,6 +648,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -637,6 +648,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
fs
.
SwitchCount
>
0
{
if
fs
.
SwitchCount
>
0
{
requestCtx
=
service
.
WithAccountSwitchCount
(
requestCtx
,
fs
.
SwitchCount
,
h
.
metadataBridgeEnabled
())
requestCtx
=
service
.
WithAccountSwitchCount
(
requestCtx
,
fs
.
SwitchCount
,
h
.
metadataBridgeEnabled
())
}
}
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
writerSizeBeforeForward
:=
c
.
Writer
.
Size
()
if
account
.
Platform
==
service
.
PlatformAntigravity
&&
account
.
Type
!=
service
.
AccountTypeAPIKey
{
if
account
.
Platform
==
service
.
PlatformAntigravity
&&
account
.
Type
!=
service
.
AccountTypeAPIKey
{
result
,
err
=
h
.
antigravityGatewayService
.
Forward
(
requestCtx
,
c
,
account
,
body
,
hasBoundSession
)
result
,
err
=
h
.
antigravityGatewayService
.
Forward
(
requestCtx
,
c
,
account
,
body
,
hasBoundSession
)
}
else
{
}
else
{
...
@@ -706,6 +719,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -706,6 +719,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
var
failoverErr
*
service
.
UpstreamFailoverError
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
if
errors
.
As
(
err
,
&
failoverErr
)
{
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
if
c
.
Writer
.
Size
()
!=
writerSizeBeforeForward
{
h
.
handleFailoverExhausted
(
c
,
failoverErr
,
account
.
Platform
,
true
)
return
}
action
:=
fs
.
HandleFailoverError
(
c
.
Request
.
Context
(),
h
.
gatewayService
,
account
.
ID
,
account
.
Platform
,
failoverErr
)
action
:=
fs
.
HandleFailoverError
(
c
.
Request
.
Context
(),
h
.
gatewayService
,
account
.
ID
,
account
.
Platform
,
failoverErr
)
switch
action
{
switch
action
{
case
FailoverContinue
:
case
FailoverContinue
:
...
@@ -740,6 +758,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -740,6 +758,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
clientIP
:=
ip
.
GetClientIP
(
c
)
clientIP
:=
ip
.
GetClientIP
(
c
)
requestPayloadHash
:=
service
.
HashUsageRequestPayload
(
body
)
requestPayloadHash
:=
service
.
HashUsageRequestPayload
(
body
)
if
result
.
ReasoningEffort
==
nil
{
result
.
ReasoningEffort
=
service
.
NormalizeClaudeOutputEffort
(
parsedReq
.
OutputEffort
)
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
...
...
backend/internal/handler/gateway_handler_stream_failover_test.go
0 → 100644
View file @
3718d6dc
package
handler
import
(
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// partialMessageStartSSE 模拟 handleStreamingResponse 已写入的首批 SSE 事件。
const
partialMessageStartSSE
=
"event: message_start
\n
data: {
\"
type
\"
:
\"
message_start
\"
,
\"
message
\"
:{
\"
id
\"
:
\"
msg_01
\"
,
\"
type
\"
:
\"
message
\"
,
\"
role
\"
:
\"
assistant
\"
,
\"
content
\"
:[],
\"
model
\"
:
\"
claude-sonnet-4-5
\"
,
\"
stop_reason
\"
:null,
\"
stop_sequence
\"
:null,
\"
usage
\"
:{
\"
input_tokens
\"
:10,
\"
output_tokens
\"
:1}}}
\n\n
"
+
"event: content_block_start
\n
data: {
\"
type
\"
:
\"
content_block_start
\"
,
\"
index
\"
:0,
\"
content_block
\"
:{
\"
type
\"
:
\"
text
\"
,
\"
text
\"
:
\"\"
}}
\n\n
"
// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten 验证:
// 当 Forward 在返回 UpstreamFailoverError 前已向客户端写入 SSE 内容时,
// 故障转移保护逻辑必须终止循环并发送 SSE 错误事件,而不是进行下一次 Forward。
// 具体验证:
// 1. c.Writer.Size() 检测条件正确触发(字节数已增加)
// 2. handleFailoverExhausted 以 streamStarted=true 调用后,响应体以 SSE 错误事件结尾
// 3. 响应体中只出现一个 message_start,不存在第二个(防止流拼接腐化)
func
TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
// 步骤 1:记录 Forward 前的 writer size(模拟 writerSizeBeforeForward := c.Writer.Size())
sizeBeforeForward
:=
c
.
Writer
.
Size
()
require
.
Equal
(
t
,
-
1
,
sizeBeforeForward
,
"gin writer 初始 Size 应为 -1(未写入任何字节)"
)
// 步骤 2:模拟 Forward 已向客户端写入部分 SSE 内容(message_start + content_block_start)
_
,
err
:=
c
.
Writer
.
Write
([]
byte
(
partialMessageStartSSE
))
require
.
NoError
(
t
,
err
)
// 步骤 3:验证守卫条件成立(c.Writer.Size() != sizeBeforeForward)
require
.
NotEqual
(
t
,
sizeBeforeForward
,
c
.
Writer
.
Size
(),
"写入 SSE 内容后 writer size 必须增加,守卫条件应为 true"
)
// 步骤 4:模拟 UpstreamFailoverError(上游在流中途返回 403)
failoverErr
:=
&
service
.
UpstreamFailoverError
{
StatusCode
:
http
.
StatusForbidden
,
ResponseBody
:
[]
byte
(
`{"error":{"type":"permission_error","message":"forbidden"}}`
),
}
// 步骤 5:守卫触发 → 调用 handleFailoverExhausted,streamStarted=true
h
:=
&
GatewayHandler
{}
h
.
handleFailoverExhausted
(
c
,
failoverErr
,
service
.
PlatformAnthropic
,
true
)
body
:=
w
.
Body
.
String
()
// 断言 A:响应体中包含最初写入的 message_start SSE 事件行
require
.
Contains
(
t
,
body
,
"event: message_start"
,
"响应体应包含已写入的 message_start SSE 事件"
)
// 断言 B:响应体以 SSE 错误事件结尾(data: {"type":"error",...}\n\n)
require
.
True
(
t
,
strings
.
HasSuffix
(
strings
.
TrimRight
(
body
,
"
\n
"
),
"}"
),
"响应体应以 JSON 对象结尾(SSE error event 的 data 字段)"
)
require
.
Contains
(
t
,
body
,
`"type":"error"`
,
"响应体末尾必须包含 SSE 错误事件"
)
// 断言 C:SSE event 行 "event: message_start" 只出现一次(防止双 message_start 拼接腐化)
firstIdx
:=
strings
.
Index
(
body
,
"event: message_start"
)
lastIdx
:=
strings
.
LastIndex
(
body
,
"event: message_start"
)
assert
.
Equal
(
t
,
firstIdx
,
lastIdx
,
"响应体中 'event: message_start' 必须只出现一次,不得因 failover 拼接导致两次"
)
}
// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten 与上述测试相同,
// 验证 Gemini 路径使用 service.PlatformGemini(而非 account.Platform)时行为一致。
func
TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1beta/models/gemini-2.0-flash:streamGenerateContent"
,
nil
)
sizeBeforeForward
:=
c
.
Writer
.
Size
()
_
,
err
:=
c
.
Writer
.
Write
([]
byte
(
partialMessageStartSSE
))
require
.
NoError
(
t
,
err
)
require
.
NotEqual
(
t
,
sizeBeforeForward
,
c
.
Writer
.
Size
())
failoverErr
:=
&
service
.
UpstreamFailoverError
{
StatusCode
:
http
.
StatusForbidden
,
}
h
:=
&
GatewayHandler
{}
h
.
handleFailoverExhausted
(
c
,
failoverErr
,
service
.
PlatformGemini
,
true
)
body
:=
w
.
Body
.
String
()
require
.
Contains
(
t
,
body
,
"event: message_start"
)
require
.
Contains
(
t
,
body
,
`"type":"error"`
)
firstIdx
:=
strings
.
Index
(
body
,
"event: message_start"
)
lastIdx
:=
strings
.
LastIndex
(
body
,
"event: message_start"
)
assert
.
Equal
(
t
,
firstIdx
,
lastIdx
,
"Gemini 路径不得出现双 message_start"
)
}
// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered 验证反向场景:
// 当 Forward 返回 UpstreamFailoverError 时若未向客户端写入任何 SSE 内容,
// 守卫条件(c.Writer.Size() != sizeBeforeForward)为 false,不应中止 failover。
func
TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
// 模拟 writerSizeBeforeForward:初始为 -1
sizeBeforeForward
:=
c
.
Writer
.
Size
()
// Forward 未写入任何字节直接返回错误(例如 401 发生在连接建立前)
// c.Writer.Size() 仍为 -1
// 守卫条件:sizeBeforeForward == c.Writer.Size() → 不触发
guardTriggered
:=
c
.
Writer
.
Size
()
!=
sizeBeforeForward
require
.
False
(
t
,
guardTriggered
,
"未写入任何字节时,守卫条件必须为 false,应允许正常 failover 继续"
)
}
backend/internal/handler/openai_chat_completions.go
View file @
3718d6dc
...
@@ -256,14 +256,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
...
@@ -256,14 +256,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
h
.
submitUsageRecordTask
(
func
(
ctx
context
.
Context
)
{
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
OpenAIRecordUsageInput
{
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
OpenAIRecordUsageInput
{
Result
:
result
,
Result
:
result
,
APIKey
:
apiKey
,
APIKey
:
apiKey
,
User
:
apiKey
.
User
,
User
:
apiKey
.
User
,
Account
:
account
,
Account
:
account
,
Subscription
:
subscription
,
Subscription
:
subscription
,
UserAgent
:
userAgent
,
InboundEndpoint
:
normalizedOpenAIInboundEndpoint
(
c
,
openAIInboundEndpointChatCompletions
),
IPAddress
:
clientIP
,
UpstreamEndpoint
:
normalizedOpenAIUpstreamEndpoint
(
c
,
openAIUpstreamEndpointResponses
),
APIKeyService
:
h
.
apiKeyService
,
UserAgent
:
userAgent
,
IPAddress
:
clientIP
,
APIKeyService
:
h
.
apiKeyService
,
});
err
!=
nil
{
});
err
!=
nil
{
logger
.
L
()
.
With
(
logger
.
L
()
.
With
(
zap
.
String
(
"component"
,
"handler.openai_gateway.chat_completions"
),
zap
.
String
(
"component"
,
"handler.openai_gateway.chat_completions"
),
...
...
backend/internal/handler/openai_gateway_endpoint_normalization_test.go
0 → 100644
View file @
3718d6dc
package
handler
import
(
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestNormalizedOpenAIUpstreamEndpoint
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
tests
:=
[]
struct
{
name
string
path
string
fallback
string
want
string
}{
{
name
:
"responses root maps to responses upstream"
,
path
:
"/v1/responses"
,
fallback
:
openAIUpstreamEndpointResponses
,
want
:
"/v1/responses"
,
},
{
name
:
"responses compact keeps compact suffix"
,
path
:
"/openai/v1/responses/compact"
,
fallback
:
openAIUpstreamEndpointResponses
,
want
:
"/v1/responses/compact"
,
},
{
name
:
"responses nested suffix preserved"
,
path
:
"/openai/v1/responses/compact/detail"
,
fallback
:
openAIUpstreamEndpointResponses
,
want
:
"/v1/responses/compact/detail"
,
},
{
name
:
"non responses path uses fallback"
,
path
:
"/v1/messages"
,
fallback
:
openAIUpstreamEndpointResponses
,
want
:
"/v1/responses"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
tt
.
path
,
nil
)
got
:=
normalizedOpenAIUpstreamEndpoint
(
c
,
tt
.
fallback
)
require
.
Equal
(
t
,
tt
.
want
,
got
)
})
}
}
backend/internal/handler/openai_gateway_handler.go
View file @
3718d6dc
...
@@ -37,6 +37,13 @@ type OpenAIGatewayHandler struct {
...
@@ -37,6 +37,13 @@ type OpenAIGatewayHandler struct {
cfg
*
config
.
Config
cfg
*
config
.
Config
}
}
const
(
openAIInboundEndpointResponses
=
"/v1/responses"
openAIInboundEndpointMessages
=
"/v1/messages"
openAIInboundEndpointChatCompletions
=
"/v1/chat/completions"
openAIUpstreamEndpointResponses
=
"/v1/responses"
)
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func
NewOpenAIGatewayHandler
(
func
NewOpenAIGatewayHandler
(
gatewayService
*
service
.
OpenAIGatewayService
,
gatewayService
*
service
.
OpenAIGatewayService
,
...
@@ -362,6 +369,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
...
@@ -362,6 +369,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
User
:
apiKey
.
User
,
User
:
apiKey
.
User
,
Account
:
account
,
Account
:
account
,
Subscription
:
subscription
,
Subscription
:
subscription
,
InboundEndpoint
:
normalizedOpenAIInboundEndpoint
(
c
,
openAIInboundEndpointResponses
),
UpstreamEndpoint
:
normalizedOpenAIUpstreamEndpoint
(
c
,
openAIUpstreamEndpointResponses
),
UserAgent
:
userAgent
,
UserAgent
:
userAgent
,
IPAddress
:
clientIP
,
IPAddress
:
clientIP
,
RequestPayloadHash
:
requestPayloadHash
,
RequestPayloadHash
:
requestPayloadHash
,
...
@@ -738,6 +747,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
...
@@ -738,6 +747,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
User
:
apiKey
.
User
,
User
:
apiKey
.
User
,
Account
:
account
,
Account
:
account
,
Subscription
:
subscription
,
Subscription
:
subscription
,
InboundEndpoint
:
normalizedOpenAIInboundEndpoint
(
c
,
openAIInboundEndpointMessages
),
UpstreamEndpoint
:
normalizedOpenAIUpstreamEndpoint
(
c
,
openAIUpstreamEndpointResponses
),
UserAgent
:
userAgent
,
UserAgent
:
userAgent
,
IPAddress
:
clientIP
,
IPAddress
:
clientIP
,
RequestPayloadHash
:
requestPayloadHash
,
RequestPayloadHash
:
requestPayloadHash
,
...
@@ -1235,6 +1246,8 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
...
@@ -1235,6 +1246,8 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
User
:
apiKey
.
User
,
User
:
apiKey
.
User
,
Account
:
account
,
Account
:
account
,
Subscription
:
subscription
,
Subscription
:
subscription
,
InboundEndpoint
:
normalizedOpenAIInboundEndpoint
(
c
,
openAIInboundEndpointResponses
),
UpstreamEndpoint
:
normalizedOpenAIUpstreamEndpoint
(
c
,
openAIUpstreamEndpointResponses
),
UserAgent
:
userAgent
,
UserAgent
:
userAgent
,
IPAddress
:
clientIP
,
IPAddress
:
clientIP
,
RequestPayloadHash
:
service
.
HashUsageRequestPayload
(
firstMessage
),
RequestPayloadHash
:
service
.
HashUsageRequestPayload
(
firstMessage
),
...
@@ -1530,6 +1543,62 @@ func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64)
...
@@ -1530,6 +1543,62 @@ func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64)
return
fmt
.
Sprintf
(
"openai_ws_ingress:%d:%d:%d"
,
gid
,
userID
,
apiKeyID
)
return
fmt
.
Sprintf
(
"openai_ws_ingress:%d:%d:%d"
,
gid
,
userID
,
apiKeyID
)
}
}
func
normalizedOpenAIInboundEndpoint
(
c
*
gin
.
Context
,
fallback
string
)
string
{
path
:=
strings
.
TrimSpace
(
fallback
)
if
c
!=
nil
{
if
fullPath
:=
strings
.
TrimSpace
(
c
.
FullPath
());
fullPath
!=
""
{
path
=
fullPath
}
else
if
c
.
Request
!=
nil
&&
c
.
Request
.
URL
!=
nil
{
if
requestPath
:=
strings
.
TrimSpace
(
c
.
Request
.
URL
.
Path
);
requestPath
!=
""
{
path
=
requestPath
}
}
}
switch
{
case
strings
.
Contains
(
path
,
openAIInboundEndpointChatCompletions
)
:
return
openAIInboundEndpointChatCompletions
case
strings
.
Contains
(
path
,
openAIInboundEndpointMessages
)
:
return
openAIInboundEndpointMessages
case
strings
.
Contains
(
path
,
openAIInboundEndpointResponses
)
:
return
openAIInboundEndpointResponses
default
:
return
path
}
}
func
normalizedOpenAIUpstreamEndpoint
(
c
*
gin
.
Context
,
fallback
string
)
string
{
base
:=
strings
.
TrimSpace
(
fallback
)
if
base
==
""
{
base
=
openAIUpstreamEndpointResponses
}
base
=
strings
.
TrimRight
(
base
,
"/"
)
if
c
==
nil
||
c
.
Request
==
nil
||
c
.
Request
.
URL
==
nil
{
return
base
}
path
:=
strings
.
TrimRight
(
strings
.
TrimSpace
(
c
.
Request
.
URL
.
Path
),
"/"
)
if
path
==
""
{
return
base
}
idx
:=
strings
.
LastIndex
(
path
,
"/responses"
)
if
idx
<
0
{
return
base
}
suffix
:=
strings
.
TrimSpace
(
path
[
idx
+
len
(
"/responses"
)
:
])
if
suffix
==
""
||
suffix
==
"/"
{
return
base
}
if
!
strings
.
HasPrefix
(
suffix
,
"/"
)
{
return
base
}
return
base
+
suffix
}
func
isOpenAIWSUpgradeRequest
(
r
*
http
.
Request
)
bool
{
func
isOpenAIWSUpgradeRequest
(
r
*
http
.
Request
)
bool
{
if
r
==
nil
{
if
r
==
nil
{
return
false
return
false
...
...
backend/internal/handler/sora_gateway_handler_test.go
View file @
3718d6dc
...
@@ -334,6 +334,14 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
...
@@ -334,6 +334,14 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
func
(
s
*
stubUsageLogRepo
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
{
func
(
s
*
stubUsageLogRepo
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
{
return
nil
,
nil
return
nil
,
nil
}
}
func
(
s
*
stubUsageLogRepo
)
GetEndpointStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
EndpointStat
,
error
)
{
return
[]
usagestats
.
EndpointStat
{},
nil
}
func
(
s
*
stubUsageLogRepo
)
GetUpstreamEndpointStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
EndpointStat
,
error
)
{
return
[]
usagestats
.
EndpointStat
{},
nil
}
func
(
s
*
stubUsageLogRepo
)
GetGroupStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
GroupStat
,
error
)
{
func
(
s
*
stubUsageLogRepo
)
GetGroupStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
GroupStat
,
error
)
{
return
nil
,
nil
return
nil
,
nil
}
}
...
...
backend/internal/pkg/usagestats/usage_log_types.go
View file @
3718d6dc
...
@@ -81,6 +81,15 @@ type ModelStat struct {
...
@@ -81,6 +81,15 @@ type ModelStat struct {
ActualCost
float64
`json:"actual_cost"`
// 实际扣除
ActualCost
float64
`json:"actual_cost"`
// 实际扣除
}
}
// EndpointStat represents usage statistics for a single request endpoint.
type
EndpointStat
struct
{
Endpoint
string
`json:"endpoint"`
Requests
int64
`json:"requests"`
TotalTokens
int64
`json:"total_tokens"`
Cost
float64
`json:"cost"`
// 标准计费
ActualCost
float64
`json:"actual_cost"`
// 实际扣除
}
// GroupStat represents usage statistics for a single group
// GroupStat represents usage statistics for a single group
type
GroupStat
struct
{
type
GroupStat
struct
{
GroupID
int64
`json:"group_id"`
GroupID
int64
`json:"group_id"`
...
@@ -179,15 +188,18 @@ type UsageLogFilters struct {
...
@@ -179,15 +188,18 @@ type UsageLogFilters struct {
// UsageStats represents usage statistics
// UsageStats represents usage statistics
type
UsageStats
struct
{
type
UsageStats
struct
{
TotalRequests
int64
`json:"total_requests"`
TotalRequests
int64
`json:"total_requests"`
TotalInputTokens
int64
`json:"total_input_tokens"`
TotalInputTokens
int64
`json:"total_input_tokens"`
TotalOutputTokens
int64
`json:"total_output_tokens"`
TotalOutputTokens
int64
`json:"total_output_tokens"`
TotalCacheTokens
int64
`json:"total_cache_tokens"`
TotalCacheTokens
int64
`json:"total_cache_tokens"`
TotalTokens
int64
`json:"total_tokens"`
TotalTokens
int64
`json:"total_tokens"`
TotalCost
float64
`json:"total_cost"`
TotalCost
float64
`json:"total_cost"`
TotalActualCost
float64
`json:"total_actual_cost"`
TotalActualCost
float64
`json:"total_actual_cost"`
TotalAccountCost
*
float64
`json:"total_account_cost,omitempty"`
TotalAccountCost
*
float64
`json:"total_account_cost,omitempty"`
AverageDurationMs
float64
`json:"average_duration_ms"`
AverageDurationMs
float64
`json:"average_duration_ms"`
Endpoints
[]
EndpointStat
`json:"endpoints,omitempty"`
UpstreamEndpoints
[]
EndpointStat
`json:"upstream_endpoints,omitempty"`
EndpointPaths
[]
EndpointStat
`json:"endpoint_paths,omitempty"`
}
}
// BatchUserUsageStats represents usage stats for a single user
// BatchUserUsageStats represents usage stats for a single user
...
@@ -254,7 +266,9 @@ type AccountUsageSummary struct {
...
@@ -254,7 +266,9 @@ type AccountUsageSummary struct {
// AccountUsageStatsResponse represents the full usage statistics response for an account
// AccountUsageStatsResponse represents the full usage statistics response for an account
type
AccountUsageStatsResponse
struct
{
type
AccountUsageStatsResponse
struct
{
History
[]
AccountUsageHistory
`json:"history"`
History
[]
AccountUsageHistory
`json:"history"`
Summary
AccountUsageSummary
`json:"summary"`
Summary
AccountUsageSummary
`json:"summary"`
Models
[]
ModelStat
`json:"models"`
Models
[]
ModelStat
`json:"models"`
Endpoints
[]
EndpointStat
`json:"endpoints"`
UpstreamEndpoints
[]
EndpointStat
`json:"upstream_endpoints"`
}
}
backend/internal/repository/usage_billing_repo.go
View file @
3718d6dc
...
@@ -132,7 +132,7 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
...
@@ -132,7 +132,7 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
}
}
}
}
if
cmd
.
AccountQuotaCost
>
0
&&
strings
.
EqualFold
(
cmd
.
AccountType
,
service
.
AccountTypeAPIKey
)
{
if
cmd
.
AccountQuotaCost
>
0
&&
(
strings
.
EqualFold
(
cmd
.
AccountType
,
service
.
AccountTypeAPIKey
)
||
strings
.
EqualFold
(
cmd
.
AccountType
,
service
.
AccountTypeBedrock
))
{
if
err
:=
incrementUsageBillingAccountQuota
(
ctx
,
tx
,
cmd
.
AccountID
,
cmd
.
AccountQuotaCost
);
err
!=
nil
{
if
err
:=
incrementUsageBillingAccountQuota
(
ctx
,
tx
,
cmd
.
AccountID
,
cmd
.
AccountQuotaCost
);
err
!=
nil
{
return
err
return
err
}
}
...
...
backend/internal/repository/usage_log_repo.go
View file @
3718d6dc
...
@@ -28,7 +28,7 @@ import (
...
@@ -28,7 +28,7 @@ import (
gocache
"github.com/patrickmn/go-cache"
gocache
"github.com/patrickmn/go-cache"
)
)
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
const
usageLogSelectColumns
=
"id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort,
inbound_endpoint, upstream_endpoint,
cache_ttl_overridden, created_at"
var
usageLogInsertArgTypes
=
[
...
]
string
{
var
usageLogInsertArgTypes
=
[
...
]
string
{
"bigint"
,
"bigint"
,
...
@@ -65,6 +65,8 @@ var usageLogInsertArgTypes = [...]string{
...
@@ -65,6 +65,8 @@ var usageLogInsertArgTypes = [...]string{
"text"
,
"text"
,
"text"
,
"text"
,
"text"
,
"text"
,
"text"
,
"text"
,
"boolean"
,
"boolean"
,
"timestamptz"
,
"timestamptz"
,
}
}
...
@@ -304,6 +306,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
...
@@ -304,6 +306,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
media_type,
media_type,
service_tier,
service_tier,
reasoning_effort,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
cache_ttl_overridden,
created_at
created_at
) VALUES (
) VALUES (
...
@@ -312,7 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
...
@@ -312,7 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$8, $9, $10, $11,
$8, $9, $10, $11,
$12, $13,
$12, $13,
$14, $15, $16, $17, $18, $19,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
, $37, $38
)
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
RETURNING id, created_at
...
@@ -732,11 +736,13 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
...
@@ -732,11 +736,13 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
media_type,
media_type,
service_tier,
service_tier,
reasoning_effort,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
cache_ttl_overridden,
created_at
created_at
) AS (VALUES `
)
) AS (VALUES `
)
args
:=
make
([]
any
,
0
,
len
(
keys
)
*
3
7
)
args
:=
make
([]
any
,
0
,
len
(
keys
)
*
3
8
)
argPos
:=
1
argPos
:=
1
for
idx
,
key
:=
range
keys
{
for
idx
,
key
:=
range
keys
{
if
idx
>
0
{
if
idx
>
0
{
...
@@ -799,6 +805,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
...
@@ -799,6 +805,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
media_type,
media_type,
service_tier,
service_tier,
reasoning_effort,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
cache_ttl_overridden,
created_at
created_at
)
)
...
@@ -837,6 +845,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
...
@@ -837,6 +845,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
media_type,
media_type,
service_tier,
service_tier,
reasoning_effort,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
cache_ttl_overridden,
created_at
created_at
FROM input
FROM input
...
@@ -915,11 +925,13 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
...
@@ -915,11 +925,13 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
media_type,
media_type,
service_tier,
service_tier,
reasoning_effort,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
cache_ttl_overridden,
created_at
created_at
) AS (VALUES `
)
) AS (VALUES `
)
args
:=
make
([]
any
,
0
,
len
(
preparedList
)
*
3
6
)
args
:=
make
([]
any
,
0
,
len
(
preparedList
)
*
3
8
)
argPos
:=
1
argPos
:=
1
for
idx
,
prepared
:=
range
preparedList
{
for
idx
,
prepared
:=
range
preparedList
{
if
idx
>
0
{
if
idx
>
0
{
...
@@ -979,6 +991,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
...
@@ -979,6 +991,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
media_type,
media_type,
service_tier,
service_tier,
reasoning_effort,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
cache_ttl_overridden,
created_at
created_at
)
)
...
@@ -1017,6 +1031,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
...
@@ -1017,6 +1031,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
media_type,
media_type,
service_tier,
service_tier,
reasoning_effort,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
cache_ttl_overridden,
created_at
created_at
FROM input
FROM input
...
@@ -1063,6 +1079,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
...
@@ -1063,6 +1079,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
media_type,
media_type,
service_tier,
service_tier,
reasoning_effort,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
cache_ttl_overridden,
created_at
created_at
) VALUES (
) VALUES (
...
@@ -1071,7 +1089,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
...
@@ -1071,7 +1089,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$8, $9, $10, $11,
$8, $9, $10, $11,
$12, $13,
$12, $13,
$14, $15, $16, $17, $18, $19,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
, $37, $38
)
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
ON CONFLICT (request_id, api_key_id) DO NOTHING
`
,
prepared
.
args
...
)
`
,
prepared
.
args
...
)
...
@@ -1101,6 +1119,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
...
@@ -1101,6 +1119,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
mediaType
:=
nullString
(
log
.
MediaType
)
mediaType
:=
nullString
(
log
.
MediaType
)
serviceTier
:=
nullString
(
log
.
ServiceTier
)
serviceTier
:=
nullString
(
log
.
ServiceTier
)
reasoningEffort
:=
nullString
(
log
.
ReasoningEffort
)
reasoningEffort
:=
nullString
(
log
.
ReasoningEffort
)
inboundEndpoint
:=
nullString
(
log
.
InboundEndpoint
)
upstreamEndpoint
:=
nullString
(
log
.
UpstreamEndpoint
)
var
requestIDArg
any
var
requestIDArg
any
if
requestID
!=
""
{
if
requestID
!=
""
{
...
@@ -1147,6 +1167,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
...
@@ -1147,6 +1167,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
mediaType
,
mediaType
,
serviceTier
,
serviceTier
,
reasoningEffort
,
reasoningEffort
,
inboundEndpoint
,
upstreamEndpoint
,
log
.
CacheTTLOverridden
,
log
.
CacheTTLOverridden
,
createdAt
,
createdAt
,
},
},
...
@@ -2505,7 +2527,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
...
@@ -2505,7 +2527,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
args
=
append
(
args
,
*
filters
.
StartTime
)
args
=
append
(
args
,
*
filters
.
StartTime
)
}
}
if
filters
.
EndTime
!=
nil
{
if
filters
.
EndTime
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"created_at <
=
$%d"
,
len
(
args
)
+
1
))
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"created_at < $%d"
,
len
(
args
)
+
1
))
args
=
append
(
args
,
*
filters
.
EndTime
)
args
=
append
(
args
,
*
filters
.
EndTime
)
}
}
...
@@ -3040,7 +3062,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
...
@@ -3040,7 +3062,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
args
=
append
(
args
,
*
filters
.
StartTime
)
args
=
append
(
args
,
*
filters
.
StartTime
)
}
}
if
filters
.
EndTime
!=
nil
{
if
filters
.
EndTime
!=
nil
{
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"created_at <
=
$%d"
,
len
(
args
)
+
1
))
conditions
=
append
(
conditions
,
fmt
.
Sprintf
(
"created_at < $%d"
,
len
(
args
)
+
1
))
args
=
append
(
args
,
*
filters
.
EndTime
)
args
=
append
(
args
,
*
filters
.
EndTime
)
}
}
...
@@ -3080,6 +3102,35 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
...
@@ -3080,6 +3102,35 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
stats
.
TotalAccountCost
=
&
totalAccountCost
stats
.
TotalAccountCost
=
&
totalAccountCost
}
}
stats
.
TotalTokens
=
stats
.
TotalInputTokens
+
stats
.
TotalOutputTokens
+
stats
.
TotalCacheTokens
stats
.
TotalTokens
=
stats
.
TotalInputTokens
+
stats
.
TotalOutputTokens
+
stats
.
TotalCacheTokens
start
:=
time
.
Unix
(
0
,
0
)
.
UTC
()
if
filters
.
StartTime
!=
nil
{
start
=
*
filters
.
StartTime
}
end
:=
time
.
Now
()
.
UTC
()
if
filters
.
EndTime
!=
nil
{
end
=
*
filters
.
EndTime
}
endpoints
,
endpointErr
:=
r
.
GetEndpointStatsWithFilters
(
ctx
,
start
,
end
,
filters
.
UserID
,
filters
.
APIKeyID
,
filters
.
AccountID
,
filters
.
GroupID
,
filters
.
Model
,
filters
.
RequestType
,
filters
.
Stream
,
filters
.
BillingType
)
if
endpointErr
!=
nil
{
logger
.
LegacyPrintf
(
"repository.usage_log"
,
"GetEndpointStatsWithFilters failed in GetStatsWithFilters: %v"
,
endpointErr
)
endpoints
=
[]
EndpointStat
{}
}
upstreamEndpoints
,
upstreamEndpointErr
:=
r
.
GetUpstreamEndpointStatsWithFilters
(
ctx
,
start
,
end
,
filters
.
UserID
,
filters
.
APIKeyID
,
filters
.
AccountID
,
filters
.
GroupID
,
filters
.
Model
,
filters
.
RequestType
,
filters
.
Stream
,
filters
.
BillingType
)
if
upstreamEndpointErr
!=
nil
{
logger
.
LegacyPrintf
(
"repository.usage_log"
,
"GetUpstreamEndpointStatsWithFilters failed in GetStatsWithFilters: %v"
,
upstreamEndpointErr
)
upstreamEndpoints
=
[]
EndpointStat
{}
}
endpointPaths
,
endpointPathErr
:=
r
.
getEndpointPathStatsWithFilters
(
ctx
,
start
,
end
,
filters
.
UserID
,
filters
.
APIKeyID
,
filters
.
AccountID
,
filters
.
GroupID
,
filters
.
Model
,
filters
.
RequestType
,
filters
.
Stream
,
filters
.
BillingType
)
if
endpointPathErr
!=
nil
{
logger
.
LegacyPrintf
(
"repository.usage_log"
,
"getEndpointPathStatsWithFilters failed in GetStatsWithFilters: %v"
,
endpointPathErr
)
endpointPaths
=
[]
EndpointStat
{}
}
stats
.
Endpoints
=
endpoints
stats
.
UpstreamEndpoints
=
upstreamEndpoints
stats
.
EndpointPaths
=
endpointPaths
return
stats
,
nil
return
stats
,
nil
}
}
...
@@ -3092,6 +3143,163 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
...
@@ -3092,6 +3143,163 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
// AccountUsageStatsResponse represents the full usage statistics response for an account
// AccountUsageStatsResponse represents the full usage statistics response for an account
type
AccountUsageStatsResponse
=
usagestats
.
AccountUsageStatsResponse
type
AccountUsageStatsResponse
=
usagestats
.
AccountUsageStatsResponse
// EndpointStat represents endpoint usage statistics row.
type
EndpointStat
=
usagestats
.
EndpointStat
func
(
r
*
usageLogRepository
)
getEndpointStatsByColumnWithFilters
(
ctx
context
.
Context
,
endpointColumn
string
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
(
results
[]
EndpointStat
,
err
error
)
{
actualCostExpr
:=
"COALESCE(SUM(actual_cost), 0) as actual_cost"
if
accountID
>
0
&&
userID
==
0
&&
apiKeyID
==
0
{
actualCostExpr
=
"COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query
:=
fmt
.
Sprintf
(
`
SELECT
COALESCE(NULLIF(TRIM(%s), ''), 'unknown') AS endpoint,
COUNT(*) AS requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
,
endpointColumn
,
actualCostExpr
)
args
:=
[]
any
{
startTime
,
endTime
}
if
userID
>
0
{
query
+=
fmt
.
Sprintf
(
" AND user_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
userID
)
}
if
apiKeyID
>
0
{
query
+=
fmt
.
Sprintf
(
" AND api_key_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
apiKeyID
)
}
if
accountID
>
0
{
query
+=
fmt
.
Sprintf
(
" AND account_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
accountID
)
}
if
groupID
>
0
{
query
+=
fmt
.
Sprintf
(
" AND group_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
groupID
)
}
if
model
!=
""
{
query
+=
fmt
.
Sprintf
(
" AND model = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
model
)
}
query
,
args
=
appendRequestTypeOrStreamQueryFilter
(
query
,
args
,
requestType
,
stream
)
if
billingType
!=
nil
{
query
+=
fmt
.
Sprintf
(
" AND billing_type = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
int16
(
*
billingType
))
}
query
+=
" GROUP BY endpoint ORDER BY requests DESC"
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
args
...
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
if
closeErr
:=
rows
.
Close
();
closeErr
!=
nil
&&
err
==
nil
{
err
=
closeErr
results
=
nil
}
}()
results
=
make
([]
EndpointStat
,
0
)
for
rows
.
Next
()
{
var
row
EndpointStat
if
err
:=
rows
.
Scan
(
&
row
.
Endpoint
,
&
row
.
Requests
,
&
row
.
TotalTokens
,
&
row
.
Cost
,
&
row
.
ActualCost
);
err
!=
nil
{
return
nil
,
err
}
results
=
append
(
results
,
row
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
results
,
nil
}
func
(
r
*
usageLogRepository
)
getEndpointPathStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
(
results
[]
EndpointStat
,
err
error
)
{
actualCostExpr
:=
"COALESCE(SUM(actual_cost), 0) as actual_cost"
if
accountID
>
0
&&
userID
==
0
&&
apiKeyID
==
0
{
actualCostExpr
=
"COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query
:=
fmt
.
Sprintf
(
`
SELECT
CONCAT(
COALESCE(NULLIF(TRIM(inbound_endpoint), ''), 'unknown'),
' -> ',
COALESCE(NULLIF(TRIM(upstream_endpoint), ''), 'unknown')
) AS endpoint,
COUNT(*) AS requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
,
actualCostExpr
)
args
:=
[]
any
{
startTime
,
endTime
}
if
userID
>
0
{
query
+=
fmt
.
Sprintf
(
" AND user_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
userID
)
}
if
apiKeyID
>
0
{
query
+=
fmt
.
Sprintf
(
" AND api_key_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
apiKeyID
)
}
if
accountID
>
0
{
query
+=
fmt
.
Sprintf
(
" AND account_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
accountID
)
}
if
groupID
>
0
{
query
+=
fmt
.
Sprintf
(
" AND group_id = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
groupID
)
}
if
model
!=
""
{
query
+=
fmt
.
Sprintf
(
" AND model = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
model
)
}
query
,
args
=
appendRequestTypeOrStreamQueryFilter
(
query
,
args
,
requestType
,
stream
)
if
billingType
!=
nil
{
query
+=
fmt
.
Sprintf
(
" AND billing_type = $%d"
,
len
(
args
)
+
1
)
args
=
append
(
args
,
int16
(
*
billingType
))
}
query
+=
" GROUP BY endpoint ORDER BY requests DESC"
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
args
...
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
if
closeErr
:=
rows
.
Close
();
closeErr
!=
nil
&&
err
==
nil
{
err
=
closeErr
results
=
nil
}
}()
results
=
make
([]
EndpointStat
,
0
)
for
rows
.
Next
()
{
var
row
EndpointStat
if
err
:=
rows
.
Scan
(
&
row
.
Endpoint
,
&
row
.
Requests
,
&
row
.
TotalTokens
,
&
row
.
Cost
,
&
row
.
ActualCost
);
err
!=
nil
{
return
nil
,
err
}
results
=
append
(
results
,
row
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
results
,
nil
}
// GetEndpointStatsWithFilters returns inbound endpoint statistics with optional filters.
func
(
r
*
usageLogRepository
)
GetEndpointStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
EndpointStat
,
error
)
{
return
r
.
getEndpointStatsByColumnWithFilters
(
ctx
,
"inbound_endpoint"
,
startTime
,
endTime
,
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
requestType
,
stream
,
billingType
)
}
// GetUpstreamEndpointStatsWithFilters returns upstream endpoint statistics with optional filters.
func
(
r
*
usageLogRepository
)
GetUpstreamEndpointStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
EndpointStat
,
error
)
{
return
r
.
getEndpointStatsByColumnWithFilters
(
ctx
,
"upstream_endpoint"
,
startTime
,
endTime
,
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
requestType
,
stream
,
billingType
)
}
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func
(
r
*
usageLogRepository
)
GetAccountUsageStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
(
resp
*
AccountUsageStatsResponse
,
err
error
)
{
func
(
r
*
usageLogRepository
)
GetAccountUsageStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
,
endTime
time
.
Time
)
(
resp
*
AccountUsageStatsResponse
,
err
error
)
{
daysCount
:=
int
(
endTime
.
Sub
(
startTime
)
.
Hours
()
/
24
)
+
1
daysCount
:=
int
(
endTime
.
Sub
(
startTime
)
.
Hours
()
/
24
)
+
1
...
@@ -3254,11 +3462,23 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
...
@@ -3254,11 +3462,23 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
if
err
!=
nil
{
if
err
!=
nil
{
models
=
[]
ModelStat
{}
models
=
[]
ModelStat
{}
}
}
endpoints
,
endpointErr
:=
r
.
GetEndpointStatsWithFilters
(
ctx
,
startTime
,
endTime
,
0
,
0
,
accountID
,
0
,
""
,
nil
,
nil
,
nil
)
if
endpointErr
!=
nil
{
logger
.
LegacyPrintf
(
"repository.usage_log"
,
"GetEndpointStatsWithFilters failed in GetAccountUsageStats: %v"
,
endpointErr
)
endpoints
=
[]
EndpointStat
{}
}
upstreamEndpoints
,
upstreamEndpointErr
:=
r
.
GetUpstreamEndpointStatsWithFilters
(
ctx
,
startTime
,
endTime
,
0
,
0
,
accountID
,
0
,
""
,
nil
,
nil
,
nil
)
if
upstreamEndpointErr
!=
nil
{
logger
.
LegacyPrintf
(
"repository.usage_log"
,
"GetUpstreamEndpointStatsWithFilters failed in GetAccountUsageStats: %v"
,
upstreamEndpointErr
)
upstreamEndpoints
=
[]
EndpointStat
{}
}
resp
=
&
AccountUsageStatsResponse
{
resp
=
&
AccountUsageStatsResponse
{
History
:
history
,
History
:
history
,
Summary
:
summary
,
Summary
:
summary
,
Models
:
models
,
Models
:
models
,
Endpoints
:
endpoints
,
UpstreamEndpoints
:
upstreamEndpoints
,
}
}
return
resp
,
nil
return
resp
,
nil
}
}
...
@@ -3541,6 +3761,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
...
@@ -3541,6 +3761,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
mediaType
sql
.
NullString
mediaType
sql
.
NullString
serviceTier
sql
.
NullString
serviceTier
sql
.
NullString
reasoningEffort
sql
.
NullString
reasoningEffort
sql
.
NullString
inboundEndpoint
sql
.
NullString
upstreamEndpoint
sql
.
NullString
cacheTTLOverridden
bool
cacheTTLOverridden
bool
createdAt
time
.
Time
createdAt
time
.
Time
)
)
...
@@ -3581,6 +3803,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
...
@@ -3581,6 +3803,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&
mediaType
,
&
mediaType
,
&
serviceTier
,
&
serviceTier
,
&
reasoningEffort
,
&
reasoningEffort
,
&
inboundEndpoint
,
&
upstreamEndpoint
,
&
cacheTTLOverridden
,
&
cacheTTLOverridden
,
&
createdAt
,
&
createdAt
,
);
err
!=
nil
{
);
err
!=
nil
{
...
@@ -3656,6 +3880,12 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
...
@@ -3656,6 +3880,12 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if
reasoningEffort
.
Valid
{
if
reasoningEffort
.
Valid
{
log
.
ReasoningEffort
=
&
reasoningEffort
.
String
log
.
ReasoningEffort
=
&
reasoningEffort
.
String
}
}
if
inboundEndpoint
.
Valid
{
log
.
InboundEndpoint
=
&
inboundEndpoint
.
String
}
if
upstreamEndpoint
.
Valid
{
log
.
UpstreamEndpoint
=
&
upstreamEndpoint
.
String
}
return
log
,
nil
return
log
,
nil
}
}
...
...
backend/internal/repository/usage_log_repo_request_type_test.go
View file @
3718d6dc
...
@@ -73,6 +73,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
...
@@ -73,6 +73,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock
.
AnyArg
(),
// media_type
sqlmock
.
AnyArg
(),
// media_type
sqlmock
.
AnyArg
(),
// service_tier
sqlmock
.
AnyArg
(),
// service_tier
sqlmock
.
AnyArg
(),
// reasoning_effort
sqlmock
.
AnyArg
(),
// reasoning_effort
sqlmock
.
AnyArg
(),
// inbound_endpoint
sqlmock
.
AnyArg
(),
// upstream_endpoint
log
.
CacheTTLOverridden
,
log
.
CacheTTLOverridden
,
createdAt
,
createdAt
,
)
.
)
.
...
@@ -141,6 +143,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
...
@@ -141,6 +143,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock
.
AnyArg
(),
sqlmock
.
AnyArg
(),
serviceTier
,
serviceTier
,
sqlmock
.
AnyArg
(),
sqlmock
.
AnyArg
(),
sqlmock
.
AnyArg
(),
sqlmock
.
AnyArg
(),
log
.
CacheTTLOverridden
,
log
.
CacheTTLOverridden
,
createdAt
,
createdAt
,
)
.
)
.
...
@@ -376,6 +380,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
...
@@ -376,6 +380,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql
.
NullString
{},
sql
.
NullString
{},
sql
.
NullString
{
Valid
:
true
,
String
:
"priority"
},
sql
.
NullString
{
Valid
:
true
,
String
:
"priority"
},
sql
.
NullString
{},
sql
.
NullString
{},
sql
.
NullString
{},
sql
.
NullString
{},
false
,
false
,
now
,
now
,
}})
}})
...
@@ -415,6 +421,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
...
@@ -415,6 +421,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql
.
NullString
{},
sql
.
NullString
{},
sql
.
NullString
{
Valid
:
true
,
String
:
"flex"
},
sql
.
NullString
{
Valid
:
true
,
String
:
"flex"
},
sql
.
NullString
{},
sql
.
NullString
{},
sql
.
NullString
{},
sql
.
NullString
{},
false
,
false
,
now
,
now
,
}})
}})
...
@@ -454,6 +462,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
...
@@ -454,6 +462,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql
.
NullString
{},
sql
.
NullString
{},
sql
.
NullString
{
Valid
:
true
,
String
:
"priority"
},
sql
.
NullString
{
Valid
:
true
,
String
:
"priority"
},
sql
.
NullString
{},
sql
.
NullString
{},
sql
.
NullString
{},
sql
.
NullString
{},
false
,
false
,
now
,
now
,
}})
}})
...
...
backend/internal/server/api_contract_test.go
View file @
3718d6dc
...
@@ -1624,6 +1624,14 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
...
@@ -1624,6 +1624,14 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return
nil
,
errors
.
New
(
"not implemented"
)
return
nil
,
errors
.
New
(
"not implemented"
)
}
}
func
(
r
*
stubUsageLogRepo
)
GetEndpointStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
EndpointStat
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetUpstreamEndpointStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
EndpointStat
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
}
func
(
r
*
stubUsageLogRepo
)
GetGroupStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
GroupStat
,
error
)
{
func
(
r
*
stubUsageLogRepo
)
GetGroupStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
GroupStat
,
error
)
{
return
nil
,
errors
.
New
(
"not implemented"
)
return
nil
,
errors
.
New
(
"not implemented"
)
}
}
...
...
backend/internal/service/account_usage_service.go
View file @
3718d6dc
...
@@ -45,6 +45,8 @@ type UsageLogRepository interface {
...
@@ -45,6 +45,8 @@ type UsageLogRepository interface {
GetDashboardStats
(
ctx
context
.
Context
)
(
*
usagestats
.
DashboardStats
,
error
)
GetDashboardStats
(
ctx
context
.
Context
)
(
*
usagestats
.
DashboardStats
,
error
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
GetEndpointStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
EndpointStat
,
error
)
GetUpstreamEndpointStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
EndpointStat
,
error
)
GetGroupStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
GroupStat
,
error
)
GetGroupStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
GroupStat
,
error
)
GetAPIKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
APIKeyUsageTrendPoint
,
error
)
GetAPIKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
APIKeyUsageTrendPoint
,
error
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
UserUsageTrendPoint
,
error
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
UserUsageTrendPoint
,
error
)
...
...
backend/internal/service/admin_service.go
View file @
3718d6dc
...
@@ -832,7 +832,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
...
@@ -832,7 +832,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
subscriptionType
=
SubscriptionTypeStandard
subscriptionType
=
SubscriptionTypeStandard
}
}
// 限额字段:
0 和 nil 都表示"无限制"
// 限额字段:
nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额
dailyLimit
:=
normalizeLimit
(
input
.
DailyLimitUSD
)
dailyLimit
:=
normalizeLimit
(
input
.
DailyLimitUSD
)
weeklyLimit
:=
normalizeLimit
(
input
.
WeeklyLimitUSD
)
weeklyLimit
:=
normalizeLimit
(
input
.
WeeklyLimitUSD
)
monthlyLimit
:=
normalizeLimit
(
input
.
MonthlyLimitUSD
)
monthlyLimit
:=
normalizeLimit
(
input
.
MonthlyLimitUSD
)
...
@@ -944,9 +944,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
...
@@ -944,9 +944,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return
group
,
nil
return
group
,
nil
}
}
// normalizeLimit 将
0 或
负数转换为 nil(表示无限制)
// normalizeLimit 将负数转换为 nil(表示无限制)
,0 保留(表示限额为零)
func
normalizeLimit
(
limit
*
float64
)
*
float64
{
func
normalizeLimit
(
limit
*
float64
)
*
float64
{
if
limit
==
nil
||
*
limit
<
=
0
{
if
limit
==
nil
||
*
limit
<
0
{
return
nil
return
nil
}
}
return
limit
return
limit
...
@@ -1058,16 +1058,11 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
...
@@ -1058,16 +1058,11 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if
input
.
SubscriptionType
!=
""
{
if
input
.
SubscriptionType
!=
""
{
group
.
SubscriptionType
=
input
.
SubscriptionType
group
.
SubscriptionType
=
input
.
SubscriptionType
}
}
// 限额字段:0 和 nil 都表示"无限制",正数表示具体限额
// 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额
if
input
.
DailyLimitUSD
!=
nil
{
// 前端始终发送这三个字段,无需 nil 守卫
group
.
DailyLimitUSD
=
normalizeLimit
(
input
.
DailyLimitUSD
)
group
.
DailyLimitUSD
=
normalizeLimit
(
input
.
DailyLimitUSD
)
}
group
.
WeeklyLimitUSD
=
normalizeLimit
(
input
.
WeeklyLimitUSD
)
if
input
.
WeeklyLimitUSD
!=
nil
{
group
.
MonthlyLimitUSD
=
normalizeLimit
(
input
.
MonthlyLimitUSD
)
group
.
WeeklyLimitUSD
=
normalizeLimit
(
input
.
WeeklyLimitUSD
)
}
if
input
.
MonthlyLimitUSD
!=
nil
{
group
.
MonthlyLimitUSD
=
normalizeLimit
(
input
.
MonthlyLimitUSD
)
}
// 图片生成计费配置:负数表示清除(使用默认价格)
// 图片生成计费配置:负数表示清除(使用默认价格)
if
input
.
ImagePrice1K
!=
nil
{
if
input
.
ImagePrice1K
!=
nil
{
group
.
ImagePrice1K
=
normalizePrice
(
input
.
ImagePrice1K
)
group
.
ImagePrice1K
=
normalizePrice
(
input
.
ImagePrice1K
)
...
...
backend/internal/service/api_key.go
View file @
3718d6dc
...
@@ -22,8 +22,9 @@ const (
...
@@ -22,8 +22,9 @@ const (
)
)
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
// A nil windowStart is treated as expired — no initialized window means any accumulated usage is stale.
func
IsWindowExpired
(
windowStart
*
time
.
Time
,
duration
time
.
Duration
)
bool
{
func
IsWindowExpired
(
windowStart
*
time
.
Time
,
duration
time
.
Duration
)
bool
{
return
windowStart
!
=
nil
&&
time
.
Since
(
*
windowStart
)
>=
duration
return
windowStart
=
=
nil
||
time
.
Since
(
*
windowStart
)
>=
duration
}
}
type
APIKey
struct
{
type
APIKey
struct
{
...
...
backend/internal/service/api_key_rate_limit_test.go
View file @
3718d6dc
...
@@ -15,10 +15,10 @@ func TestIsWindowExpired(t *testing.T) {
...
@@ -15,10 +15,10 @@ func TestIsWindowExpired(t *testing.T) {
want
bool
want
bool
}{
}{
{
{
name
:
"nil window start"
,
name
:
"nil window start
(treated as expired)
"
,
start
:
nil
,
start
:
nil
,
duration
:
RateLimitWindow5h
,
duration
:
RateLimitWindow5h
,
want
:
fals
e
,
want
:
tru
e
,
},
},
{
{
name
:
"active window (started 1h ago, 5h window)"
,
name
:
"active window (started 1h ago, 5h window)"
,
...
@@ -113,7 +113,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
...
@@ -113,7 +113,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
want7d
:
0
,
want7d
:
0
,
},
},
{
{
name
:
"nil window starts return
raw usage
"
,
name
:
"nil window starts return
0 (stale usage reset)
"
,
key
:
APIKey
{
key
:
APIKey
{
Usage5h
:
5.0
,
Usage5h
:
5.0
,
Usage1d
:
10.0
,
Usage1d
:
10.0
,
...
@@ -122,9 +122,9 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
...
@@ -122,9 +122,9 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
Window1dStart
:
nil
,
Window1dStart
:
nil
,
Window7dStart
:
nil
,
Window7dStart
:
nil
,
},
},
want5h
:
5.
0
,
want5h
:
0
,
want1d
:
10.
0
,
want1d
:
0
,
want7d
:
50.
0
,
want7d
:
0
,
},
},
{
{
name
:
"mixed: 5h expired, 1d active, 7d nil"
,
name
:
"mixed: 5h expired, 1d active, 7d nil"
,
...
@@ -138,7 +138,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
...
@@ -138,7 +138,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
},
},
want5h
:
0
,
want5h
:
0
,
want1d
:
10.0
,
want1d
:
10.0
,
want7d
:
50.
0
,
want7d
:
0
,
},
},
{
{
name
:
"zero usage with active windows"
,
name
:
"zero usage with active windows"
,
...
@@ -210,7 +210,7 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
...
@@ -210,7 +210,7 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
want7d
:
0
,
want7d
:
0
,
},
},
{
{
name
:
"nil window starts return
raw usage
"
,
name
:
"nil window starts return
0 (stale usage reset)
"
,
data
:
APIKeyRateLimitData
{
data
:
APIKeyRateLimitData
{
Usage5h
:
3.0
,
Usage5h
:
3.0
,
Usage1d
:
8.0
,
Usage1d
:
8.0
,
...
@@ -219,9 +219,9 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
...
@@ -219,9 +219,9 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
Window1dStart
:
nil
,
Window1dStart
:
nil
,
Window7dStart
:
nil
,
Window7dStart
:
nil
,
},
},
want5h
:
3.
0
,
want5h
:
0
,
want1d
:
8.
0
,
want1d
:
0
,
want7d
:
40.
0
,
want7d
:
0
,
},
},
}
}
...
...
backend/internal/service/gateway_record_usage_test.go
View file @
3718d6dc
...
@@ -369,3 +369,54 @@ func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T)
...
@@ -369,3 +369,54 @@ func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T)
require
.
Equal
(
t
,
1
,
billingRepo
.
calls
)
require
.
Equal
(
t
,
1
,
billingRepo
.
calls
)
require
.
Equal
(
t
,
0
,
usageRepo
.
calls
)
require
.
Equal
(
t
,
0
,
usageRepo
.
calls
)
}
}
func
TestGatewayServiceRecordUsage_ReasoningEffortPersisted
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageBestEffortLogRepoStub
{}
svc
:=
newGatewayRecordUsageServiceForTest
(
usageRepo
,
&
openAIRecordUsageUserRepoStub
{},
&
openAIRecordUsageSubRepoStub
{})
effort
:=
"max"
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
RecordUsageInput
{
Result
:
&
ForwardResult
{
RequestID
:
"effort_test"
,
Usage
:
ClaudeUsage
{
InputTokens
:
10
,
OutputTokens
:
5
,
},
Model
:
"claude-opus-4-6"
,
Duration
:
time
.
Second
,
ReasoningEffort
:
&
effort
,
},
APIKey
:
&
APIKey
{
ID
:
1
},
User
:
&
User
{
ID
:
1
},
Account
:
&
Account
{
ID
:
1
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
usageRepo
.
lastLog
)
require
.
NotNil
(
t
,
usageRepo
.
lastLog
.
ReasoningEffort
)
require
.
Equal
(
t
,
"max"
,
*
usageRepo
.
lastLog
.
ReasoningEffort
)
}
func
TestGatewayServiceRecordUsage_ReasoningEffortNil
(
t
*
testing
.
T
)
{
usageRepo
:=
&
openAIRecordUsageBestEffortLogRepoStub
{}
svc
:=
newGatewayRecordUsageServiceForTest
(
usageRepo
,
&
openAIRecordUsageUserRepoStub
{},
&
openAIRecordUsageSubRepoStub
{})
err
:=
svc
.
RecordUsage
(
context
.
Background
(),
&
RecordUsageInput
{
Result
:
&
ForwardResult
{
RequestID
:
"no_effort_test"
,
Usage
:
ClaudeUsage
{
InputTokens
:
10
,
OutputTokens
:
5
,
},
Model
:
"claude-sonnet-4"
,
Duration
:
time
.
Second
,
},
APIKey
:
&
APIKey
{
ID
:
1
},
User
:
&
User
{
ID
:
1
},
Account
:
&
Account
{
ID
:
1
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
usageRepo
.
lastLog
)
require
.
Nil
(
t
,
usageRepo
.
lastLog
.
ReasoningEffort
)
}
backend/internal/service/gateway_request.go
View file @
3718d6dc
...
@@ -60,6 +60,7 @@ type ParsedRequest struct {
...
@@ -60,6 +60,7 @@ type ParsedRequest struct {
Messages
[]
any
// messages 数组
Messages
[]
any
// messages 数组
HasSystem
bool
// 是否包含 system 字段(包含 null 也视为显式传入)
HasSystem
bool
// 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled
bool
// 是否开启 thinking(部分平台会影响最终模型名)
ThinkingEnabled
bool
// 是否开启 thinking(部分平台会影响最终模型名)
OutputEffort
string
// output_config.effort(Claude API 的推理强度控制)
MaxTokens
int
// max_tokens 值(用于探测请求拦截)
MaxTokens
int
// max_tokens 值(用于探测请求拦截)
SessionContext
*
SessionContext
// 可选:请求上下文区分因子(nil 时行为不变)
SessionContext
*
SessionContext
// 可选:请求上下文区分因子(nil 时行为不变)
...
@@ -116,6 +117,9 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
...
@@ -116,6 +117,9 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
parsed
.
ThinkingEnabled
=
true
parsed
.
ThinkingEnabled
=
true
}
}
// output_config.effort: Claude API 的推理强度控制参数
parsed
.
OutputEffort
=
strings
.
TrimSpace
(
gjson
.
Get
(
jsonStr
,
"output_config.effort"
)
.
String
())
// max_tokens: 仅接受整数值
// max_tokens: 仅接受整数值
maxTokensResult
:=
gjson
.
Get
(
jsonStr
,
"max_tokens"
)
maxTokensResult
:=
gjson
.
Get
(
jsonStr
,
"max_tokens"
)
if
maxTokensResult
.
Exists
()
&&
maxTokensResult
.
Type
==
gjson
.
Number
{
if
maxTokensResult
.
Exists
()
&&
maxTokensResult
.
Type
==
gjson
.
Number
{
...
@@ -747,6 +751,21 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
...
@@ -747,6 +751,21 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
return
newBody
return
newBody
}
}
// NormalizeClaudeOutputEffort normalizes Claude's output_config.effort value.
// Returns nil for empty or unrecognized values.
func NormalizeClaudeOutputEffort(raw string) *string {
	normalized := strings.ToLower(strings.TrimSpace(raw))
	switch normalized {
	case "low", "medium", "high", "max":
		// Recognized effort level: return the canonical lowercase form.
		return &normalized
	default:
		// Empty string and any unknown level are treated the same way.
		return nil
	}
}
// =========================
// =========================
// Thinking Budget Rectifier
// Thinking Budget Rectifier
// =========================
// =========================
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment