Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
6d90fb0b
"vscode:/vscode.git/clone" did not exist on "093a5a260edcd8e1c1562e0dfe73963e76e8a3e0"
Commit
6d90fb0b
authored
Feb 09, 2026
by
erio
Browse files
feat: detect client disconnect during streaming and continue draining upstream for billing
parent
b889d501
Changes
2
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/antigravity_gateway_service.go
View file @
6d90fb0b
...
@@ -1305,6 +1305,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -1305,6 +1305,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
var
usage
*
ClaudeUsage
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
var
firstTokenMs
*
int
var
clientDisconnect
bool
if
claudeReq
.
Stream
{
if
claudeReq
.
Stream
{
// 客户端要求流式,直接透传转换
// 客户端要求流式,直接透传转换
streamRes
,
err
:=
s
.
handleClaudeStreamingResponse
(
c
,
resp
,
startTime
,
originalModel
)
streamRes
,
err
:=
s
.
handleClaudeStreamingResponse
(
c
,
resp
,
startTime
,
originalModel
)
...
@@ -1314,6 +1315,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -1314,6 +1315,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
}
usage
=
streamRes
.
usage
usage
=
streamRes
.
usage
firstTokenMs
=
streamRes
.
firstTokenMs
firstTokenMs
=
streamRes
.
firstTokenMs
clientDisconnect
=
streamRes
.
clientDisconnect
}
else
{
}
else
{
// 客户端要求非流式,收集流式响应后转换返回
// 客户端要求非流式,收集流式响应后转换返回
streamRes
,
err
:=
s
.
handleClaudeStreamToNonStreaming
(
c
,
resp
,
startTime
,
originalModel
)
streamRes
,
err
:=
s
.
handleClaudeStreamToNonStreaming
(
c
,
resp
,
startTime
,
originalModel
)
...
@@ -1326,12 +1328,13 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -1326,12 +1328,13 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
}
return
&
ForwardResult
{
return
&
ForwardResult
{
RequestID
:
requestID
,
RequestID
:
requestID
,
Usage
:
*
usage
,
Usage
:
*
usage
,
Model
:
originalModel
,
// 使用原始模型用于计费和日志
Model
:
originalModel
,
// 使用原始模型用于计费和日志
Stream
:
claudeReq
.
Stream
,
Stream
:
claudeReq
.
Stream
,
Duration
:
time
.
Since
(
startTime
),
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
FirstTokenMs
:
firstTokenMs
,
ClientDisconnect
:
clientDisconnect
,
},
nil
},
nil
}
}
...
@@ -1860,6 +1863,7 @@ handleSuccess:
...
@@ -1860,6 +1863,7 @@ handleSuccess:
var
usage
*
ClaudeUsage
var
usage
*
ClaudeUsage
var
firstTokenMs
*
int
var
firstTokenMs
*
int
var
clientDisconnect
bool
if
stream
{
if
stream
{
// 客户端要求流式,直接透传
// 客户端要求流式,直接透传
...
@@ -1870,6 +1874,7 @@ handleSuccess:
...
@@ -1870,6 +1874,7 @@ handleSuccess:
}
}
usage
=
streamRes
.
usage
usage
=
streamRes
.
usage
firstTokenMs
=
streamRes
.
firstTokenMs
firstTokenMs
=
streamRes
.
firstTokenMs
clientDisconnect
=
streamRes
.
clientDisconnect
}
else
{
}
else
{
// 客户端要求非流式,收集流式响应后返回
// 客户端要求非流式,收集流式响应后返回
streamRes
,
err
:=
s
.
handleGeminiStreamToNonStreaming
(
c
,
resp
,
startTime
)
streamRes
,
err
:=
s
.
handleGeminiStreamToNonStreaming
(
c
,
resp
,
startTime
)
...
@@ -1893,14 +1898,15 @@ handleSuccess:
...
@@ -1893,14 +1898,15 @@ handleSuccess:
}
}
return
&
ForwardResult
{
return
&
ForwardResult
{
RequestID
:
requestID
,
RequestID
:
requestID
,
Usage
:
*
usage
,
Usage
:
*
usage
,
Model
:
originalModel
,
Model
:
originalModel
,
Stream
:
stream
,
Stream
:
stream
,
Duration
:
time
.
Since
(
startTime
),
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
FirstTokenMs
:
firstTokenMs
,
ImageCount
:
imageCount
,
ClientDisconnect
:
clientDisconnect
,
ImageSize
:
imageSize
,
ImageCount
:
imageCount
,
ImageSize
:
imageSize
,
},
nil
},
nil
}
}
...
@@ -2319,8 +2325,69 @@ func (s *AntigravityGatewayService) handleUpstreamError(
...
@@ -2319,8 +2325,69 @@ func (s *AntigravityGatewayService) handleUpstreamError(
}
}
type
antigravityStreamResult
struct
{
type
antigravityStreamResult
struct
{
usage
*
ClaudeUsage
usage
*
ClaudeUsage
firstTokenMs
*
int
firstTokenMs
*
int
clientDisconnect
bool
// 客户端是否在流式传输过程中断开
}
// antigravityClientWriter wraps the client-side writes of a streaming response
// and records when the client has gone away. After the first failed write every
// later write becomes a no-op; callers use Disconnected() to decide whether to
// keep draining the upstream.
type antigravityClientWriter struct {
	w            gin.ResponseWriter
	flusher      http.Flusher
	disconnected bool
	prefix       string // log prefix identifying the calling method
}
func
newAntigravityClientWriter
(
w
gin
.
ResponseWriter
,
flusher
http
.
Flusher
,
prefix
string
)
*
antigravityClientWriter
{
return
&
antigravityClientWriter
{
w
:
w
,
flusher
:
flusher
,
prefix
:
prefix
}
}
// Write 写入数据到客户端,写入失败时标记断开并返回 false
func
(
cw
*
antigravityClientWriter
)
Write
(
p
[]
byte
)
bool
{
if
cw
.
disconnected
{
return
false
}
if
_
,
err
:=
cw
.
w
.
Write
(
p
);
err
!=
nil
{
cw
.
markDisconnected
()
return
false
}
cw
.
flusher
.
Flush
()
return
true
}
// Fprintf 格式化写入数据到客户端,写入失败时标记断开并返回 false
func
(
cw
*
antigravityClientWriter
)
Fprintf
(
format
string
,
args
...
any
)
bool
{
if
cw
.
disconnected
{
return
false
}
if
_
,
err
:=
fmt
.
Fprintf
(
cw
.
w
,
format
,
args
...
);
err
!=
nil
{
cw
.
markDisconnected
()
return
false
}
cw
.
flusher
.
Flush
()
return
true
}
// Disconnected reports whether the client has been marked as disconnected,
// which happens when a write to it fails.
func (cw *antigravityClientWriter) Disconnected() bool {
	return cw.disconnected
}
func
(
cw
*
antigravityClientWriter
)
markDisconnected
()
{
cw
.
disconnected
=
true
log
.
Printf
(
"Client disconnected during streaming (%s), continuing to drain upstream for billing"
,
cw
.
prefix
)
}
// handleStreamReadError holds the shared handling for errors raised while
// reading the upstream stream.
//
// It returns (disconnect, handled). handled=true means the error has been dealt
// with and the caller should return the usage collected so far; that is the
// case when err is a context cancellation or deadline, or when the client had
// already disconnected. In both of those cases disconnect is true as well. Any
// other error yields (false, false) and is left to the caller.
func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) {
	contextEnded := errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)

	switch {
	case contextEnded:
		log.Printf("Context canceled during streaming (%s), returning collected usage", prefix)
	case clientDisconnected:
		log.Printf("Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err)
	default:
		return false, false
	}
	return true, true
}
}
func
(
s
*
AntigravityGatewayService
)
handleGeminiStreamingResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
,
startTime
time
.
Time
)
(
*
antigravityStreamResult
,
error
)
{
func
(
s
*
AntigravityGatewayService
)
handleGeminiStreamingResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
,
startTime
time
.
Time
)
(
*
antigravityStreamResult
,
error
)
{
...
@@ -2396,10 +2463,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
...
@@ -2396,10 +2463,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
intervalCh
=
intervalTicker
.
C
intervalCh
=
intervalTicker
.
C
}
}
cw
:=
newAntigravityClientWriter
(
c
.
Writer
,
flusher
,
"antigravity gemini"
)
// 仅发送一次错误事件,避免多次写入导致协议混乱
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent
:=
false
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
if
errorEventSent
||
cw
.
Disconnected
()
{
return
return
}
}
errorEventSent
=
true
errorEventSent
=
true
...
@@ -2411,9 +2480,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
...
@@ -2411,9 +2480,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
select
{
select
{
case
ev
,
ok
:=
<-
events
:
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
if
!
ok
{
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
cw
.
Disconnected
()
},
nil
}
}
if
ev
.
err
!=
nil
{
if
ev
.
err
!=
nil
{
if
disconnect
,
handled
:=
handleStreamReadError
(
ev
.
err
,
cw
.
Disconnected
(),
"antigravity gemini"
);
handled
{
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
disconnect
},
nil
}
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long (antigravity): max_size=%d error=%v"
,
maxLineSize
,
ev
.
err
)
log
.
Printf
(
"SSE line too long (antigravity): max_size=%d error=%v"
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
sendErrorEvent
(
"response_too_large"
)
...
@@ -2428,11 +2500,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
...
@@ -2428,11 +2500,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if
strings
.
HasPrefix
(
trimmed
,
"data:"
)
{
if
strings
.
HasPrefix
(
trimmed
,
"data:"
)
{
payload
:=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmed
,
"data:"
))
payload
:=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmed
,
"data:"
))
if
payload
==
""
||
payload
==
"[DONE]"
{
if
payload
==
""
||
payload
==
"[DONE]"
{
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"%s
\n
"
,
line
);
err
!=
nil
{
cw
.
Fprintf
(
"%s
\n
"
,
line
)
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
continue
continue
}
}
...
@@ -2468,27 +2536,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
...
@@ -2468,27 +2536,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
firstTokenMs
=
&
ms
firstTokenMs
=
&
ms
}
}
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"data: %s
\n\n
"
,
payload
);
err
!=
nil
{
cw
.
Fprintf
(
"data: %s
\n\n
"
,
payload
)
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
continue
continue
}
}
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"%s
\n
"
,
line
);
err
!=
nil
{
cw
.
Fprintf
(
"%s
\n
"
,
line
)
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
case
<-
intervalCh
:
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
continue
}
}
if
cw
.
Disconnected
()
{
log
.
Printf
(
"Upstream timeout after client disconnect (antigravity gemini), returning collected usage"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
log
.
Printf
(
"Stream data interval timeout (antigravity)"
)
log
.
Printf
(
"Stream data interval timeout (antigravity)"
)
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
sendErrorEvent
(
"stream_timeout"
)
sendErrorEvent
(
"stream_timeout"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
}
...
@@ -3186,10 +3249,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
...
@@ -3186,10 +3249,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
intervalCh
=
intervalTicker
.
C
intervalCh
=
intervalTicker
.
C
}
}
cw
:=
newAntigravityClientWriter
(
c
.
Writer
,
flusher
,
"antigravity claude"
)
// 仅发送一次错误事件,避免多次写入导致协议混乱
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent
:=
false
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
if
errorEventSent
||
cw
.
Disconnected
()
{
return
return
}
}
errorEventSent
=
true
errorEventSent
=
true
...
@@ -3197,19 +3262,27 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
...
@@ -3197,19 +3262,27 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
flusher
.
Flush
()
flusher
.
Flush
()
}
}
// finishUsage 是获取 processor 最终 usage 的辅助函数
finishUsage
:=
func
()
*
ClaudeUsage
{
_
,
agUsage
:=
processor
.
Finish
()
return
convertUsage
(
agUsage
)
}
for
{
for
{
select
{
select
{
case
ev
,
ok
:=
<-
events
:
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
if
!
ok
{
// 发送结束事件
//
上游完成,
发送结束事件
finalEvents
,
agUsage
:=
processor
.
Finish
()
finalEvents
,
agUsage
:=
processor
.
Finish
()
if
len
(
finalEvents
)
>
0
{
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
cw
.
Write
(
finalEvents
)
flusher
.
Flush
()
}
}
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
nil
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
cw
.
Disconnected
()
},
nil
}
}
if
ev
.
err
!=
nil
{
if
ev
.
err
!=
nil
{
if
disconnect
,
handled
:=
handleStreamReadError
(
ev
.
err
,
cw
.
Disconnected
(),
"antigravity claude"
);
handled
{
return
&
antigravityStreamResult
{
usage
:
finishUsage
(),
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
disconnect
},
nil
}
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long (antigravity): max_size=%d error=%v"
,
maxLineSize
,
ev
.
err
)
log
.
Printf
(
"SSE line too long (antigravity): max_size=%d error=%v"
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
sendErrorEvent
(
"response_too_large"
)
...
@@ -3219,25 +3292,14 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
...
@@ -3219,25 +3292,14 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
return
nil
,
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
return
nil
,
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
}
line
:=
ev
.
line
// 处理 SSE 行,转换为 Claude 格式
// 处理 SSE 行,转换为 Claude 格式
claudeEvents
:=
processor
.
ProcessLine
(
strings
.
TrimRight
(
line
,
"
\r\n
"
))
claudeEvents
:=
processor
.
ProcessLine
(
strings
.
TrimRight
(
ev
.
line
,
"
\r\n
"
))
if
len
(
claudeEvents
)
>
0
{
if
len
(
claudeEvents
)
>
0
{
if
firstTokenMs
==
nil
{
if
firstTokenMs
==
nil
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
firstTokenMs
=
&
ms
}
}
cw
.
Write
(
claudeEvents
)
if
_
,
writeErr
:=
c
.
Writer
.
Write
(
claudeEvents
);
writeErr
!=
nil
{
finalEvents
,
agUsage
:=
processor
.
Finish
()
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
}
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
writeErr
}
flusher
.
Flush
()
}
}
case
<-
intervalCh
:
case
<-
intervalCh
:
...
@@ -3245,13 +3307,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
...
@@ -3245,13 +3307,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if
time
.
Since
(
lastRead
)
<
streamInterval
{
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
continue
}
}
if
cw
.
Disconnected
()
{
log
.
Printf
(
"Upstream timeout after client disconnect (antigravity claude), returning collected usage"
)
return
&
antigravityStreamResult
{
usage
:
finishUsage
(),
firstTokenMs
:
firstTokenMs
,
clientDisconnect
:
true
},
nil
}
log
.
Printf
(
"Stream data interval timeout (antigravity)"
)
log
.
Printf
(
"Stream data interval timeout (antigravity)"
)
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
sendErrorEvent
(
"stream_timeout"
)
sendErrorEvent
(
"stream_timeout"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
nil
),
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
nil
),
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
}
}
}
}
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
// extractImageSize 从 Gemini 请求中提取 image_size 参数
...
@@ -3390,3 +3454,289 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
...
@@ -3390,3 +3454,289 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
payload
[
"contents"
]
=
filtered
payload
[
"contents"
]
=
filtered
return
json
.
Marshal
(
payload
)
return
json
.
Marshal
(
payload
)
}
}
// ForwardUpstream passes a Claude request through to the account's upstream at
// base_url + /v1/messages, authenticating with both the Authorization and
// x-api-key headers.
//
// The request body is forwarded unchanged; it is parsed only to read the model
// and the stream flag. An upstream status >= 400 is relayed to the client as-is
// and reported with a ForwardResult carrying only the model and a nil error.
// On success the response is relayed (streaming or buffered) and the returned
// ForwardResult carries the usage extracted from it.
//
// An error is returned when the account lacks base_url or api_key, the body is
// not a valid Claude request or has no model, the upstream request cannot be
// created or sent, or a non-streaming response body cannot be read.
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
	startTime := time.Now()
	sessionID := getSessionID(c)
	prefix := logPrefix(sessionID, account.Name)

	// Read the upstream configuration from the account credentials.
	baseURL := strings.TrimSpace(account.GetCredential("base_url"))
	apiKey := strings.TrimSpace(account.GetCredential("api_key"))
	if baseURL == "" || apiKey == "" {
		return nil, fmt.Errorf("upstream account missing base_url or api_key")
	}
	baseURL = strings.TrimSuffix(baseURL, "/")

	// Parse the request to obtain the model information.
	var claudeReq antigravity.ClaudeRequest
	if err := json.Unmarshal(body, &claudeReq); err != nil {
		return nil, fmt.Errorf("parse claude request: %w", err)
	}
	if strings.TrimSpace(claudeReq.Model) == "" {
		return nil, fmt.Errorf("missing model")
	}
	originalModel := claudeReq.Model
	billingModel := originalModel

	// Build the upstream request URL.
	upstreamURL := baseURL + "/v1/messages"

	// Create the request.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("create upstream request: %w", err)
	}

	// Set the request headers.
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("x-api-key", apiKey) // Claude API compatibility

	// Pass through the Claude-specific headers sent by the client.
	if v := c.GetHeader("anthropic-version"); v != "" {
		req.Header.Set("anthropic-version", v)
	}
	if v := c.GetHeader("anthropic-beta"); v != "" {
		req.Header.Set("anthropic-beta", v)
	}

	// Proxy URL (empty when the account has no proxy attached).
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	// Send the request.
	resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		log.Printf("%s upstream request failed: %v", prefix, err)
		return nil, fmt.Errorf("upstream request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	// Handle error responses.
	if resp.StatusCode >= 400 {
		// At most 2 MiB of the error body is read; a read error is ignored.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))

		// On a 429, mark the account as rate limited.
		if resp.StatusCode == http.StatusTooManyRequests {
			quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
			s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", false)
		}

		// Relay the upstream error to the client.
		c.Header("Content-Type", resp.Header.Get("Content-Type"))
		c.Status(resp.StatusCode)
		_, _ = c.Writer.Write(respBody)

		return &ForwardResult{
			Model: billingModel,
		}, nil
	}

	// Handle the successful response (streaming / non-streaming).
	var usage *ClaudeUsage
	var firstTokenMs *int
	var clientDisconnect bool
	if claudeReq.Stream {
		// Streaming response: pass it through.
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		c.Header("X-Accel-Buffering", "no")
		c.Status(http.StatusOK)

		streamRes := s.streamUpstreamResponse(c, resp, startTime)
		usage = streamRes.usage
		firstTokenMs = streamRes.firstTokenMs
		clientDisconnect = streamRes.clientDisconnect
	} else {
		// Non-streaming response: pass it through directly.
		respBody, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("read upstream response: %w", err)
		}

		// Extract usage.
		usage = s.extractClaudeUsage(respBody)

		// NOTE(review): the client always receives 200 here, whatever sub-400
		// status the upstream returned — confirm this is intended.
		c.Header("Content-Type", resp.Header.Get("Content-Type"))
		c.Status(http.StatusOK)
		_, _ = c.Writer.Write(respBody)
	}

	// Build the billing result.
	duration := time.Since(startTime)
	log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())

	return &ForwardResult{
		Model:            billingModel,
		Stream:           claudeReq.Stream,
		Duration:         duration,
		FirstTokenMs:     firstTokenMs,
		ClientDisconnect: clientDisconnect,
		Usage: ClaudeUsage{
			InputTokens:              usage.InputTokens,
			OutputTokens:             usage.OutputTokens,
			CacheReadInputTokens:     usage.CacheReadInputTokens,
			CacheCreationInputTokens: usage.CacheCreationInputTokens,
		},
	}, nil
}
// streamUpstreamResponse relays the upstream SSE stream to the client line by
// line and collects Claude usage from it.
//
// It always returns a non-nil result holding the usage and first-token latency
// gathered so far. It returns when the upstream stream ends, when reading it
// fails, or when no data has arrived within the configured stream data
// interval. A failed client write does not stop the loop: the upstream keeps
// being read, and clientDisconnect is reported in the result.
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) *antigravityStreamResult {
	usage := &ClaudeUsage{}
	var firstTokenMs *int

	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 64*1024), maxLineSize)

	type scanEvent struct {
		line string
		err  error
	}
	events := make(chan scanEvent, 16)
	done := make(chan struct{})
	// sendEvent delivers ev to the loop below, or gives up once done is closed
	// so the reader goroutine cannot block forever after this function returns.
	sendEvent := func(ev scanEvent) bool {
		select {
		case events <- ev:
			return true
		case <-done:
			return false
		}
	}

	// lastReadAt holds the UnixNano time of the most recent upstream line; it is
	// written by the reader goroutine and read by the interval check below.
	var lastReadAt int64
	atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
	go func() {
		defer close(events)
		for scanner.Scan() {
			atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
			if !sendEvent(scanEvent{line: scanner.Text()}) {
				return
			}
		}
		if err := scanner.Err(); err != nil {
			_ = sendEvent(scanEvent{err: err})
		}
	}()
	defer close(done)

	// The interval timeout is enabled only when configured with a positive
	// number of seconds; otherwise intervalCh stays nil and never fires.
	streamInterval := time.Duration(0)
	if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
		streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
	}
	var intervalTicker *time.Ticker
	if streamInterval > 0 {
		intervalTicker = time.NewTicker(streamInterval)
		defer intervalTicker.Stop()
	}
	var intervalCh <-chan time.Time
	if intervalTicker != nil {
		intervalCh = intervalTicker.C
	}

	// NOTE(review): the ok result of this assertion is discarded, so flusher is
	// nil if c.Writer does not implement http.Flusher — confirm it always does.
	flusher, _ := c.Writer.(http.Flusher)
	cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream")

	for {
		select {
		case ev, ok := <-events:
			if !ok {
				// The reader goroutine closed the channel: the upstream is finished.
				return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}
			}
			if ev.err != nil {
				if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity upstream"); handled {
					return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}
				}
				log.Printf("Stream read error (antigravity upstream): %v", ev.err)
				return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
			}

			line := ev.line

			// Record the first-token time on the first non-empty line.
			if firstTokenMs == nil && len(line) > 0 {
				ms := int(time.Since(startTime).Milliseconds())
				firstTokenMs = &ms
			}

			// Update usage from any usage data carried on this line.
			s.extractSSEUsage(line, usage)

			// Relay the line; the result is ignored because writes become
			// no-ops after a disconnect while the upstream keeps being read.
			cw.Fprintf("%s\n", line)

		case <-intervalCh:
			lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
			if time.Since(lastRead) < streamInterval {
				continue
			}
			if cw.Disconnected() {
				log.Printf("Upstream timeout after client disconnect (antigravity upstream), returning collected usage")
				return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}
			}
			log.Printf("Stream data interval timeout (antigravity upstream)")
			return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
		}
	}
}
// extractSSEUsage 从 SSE data 行中提取 Claude usage(用于流式透传场景)
func
(
s
*
AntigravityGatewayService
)
extractSSEUsage
(
line
string
,
usage
*
ClaudeUsage
)
{
if
!
strings
.
HasPrefix
(
line
,
"data: "
)
{
return
}
dataStr
:=
strings
.
TrimPrefix
(
line
,
"data: "
)
var
event
map
[
string
]
any
if
json
.
Unmarshal
([]
byte
(
dataStr
),
&
event
)
!=
nil
{
return
}
u
,
ok
:=
event
[
"usage"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
}
if
v
,
ok
:=
u
[
"input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
InputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"output_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
OutputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_read_input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
CacheReadInputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_creation_input_tokens"
]
.
(
float64
);
ok
&&
int
(
v
)
>
0
{
usage
.
CacheCreationInputTokens
=
int
(
v
)
}
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
func
(
s
*
AntigravityGatewayService
)
extractClaudeUsage
(
body
[]
byte
)
*
ClaudeUsage
{
usage
:=
&
ClaudeUsage
{}
var
resp
map
[
string
]
any
if
json
.
Unmarshal
(
body
,
&
resp
)
!=
nil
{
return
usage
}
if
u
,
ok
:=
resp
[
"usage"
]
.
(
map
[
string
]
any
);
ok
{
if
v
,
ok
:=
u
[
"input_tokens"
]
.
(
float64
);
ok
{
usage
.
InputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"output_tokens"
]
.
(
float64
);
ok
{
usage
.
OutputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_read_input_tokens"
]
.
(
float64
);
ok
{
usage
.
CacheReadInputTokens
=
int
(
v
)
}
if
v
,
ok
:=
u
[
"cache_creation_input_tokens"
]
.
(
float64
);
ok
{
usage
.
CacheCreationInputTokens
=
int
(
v
)
}
}
return
usage
}
backend/internal/service/antigravity_gateway_service_test.go
View file @
6d90fb0b
...
@@ -4,17 +4,42 @@ import (
...
@@ -4,17 +4,42 @@ import (
"bytes"
"bytes"
"context"
"context"
"encoding/json"
"encoding/json"
"errors"
"fmt"
"io"
"io"
"net/http"
"net/http"
"net/http/httptest"
"net/http/httptest"
"testing"
"testing"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
)
)
// antigravityFailingWriter is a gin.ResponseWriter used to simulate a client
// that disconnects: its Write starts failing after a set number of writes.
type antigravityFailingWriter struct {
	gin.ResponseWriter
	failAfter int // number of writes allowed to succeed; every later write returns an error
	writes    int
}
func
(
w
*
antigravityFailingWriter
)
Write
(
p
[]
byte
)
(
int
,
error
)
{
if
w
.
writes
>=
w
.
failAfter
{
return
0
,
errors
.
New
(
"write failed: client disconnected"
)
}
w
.
writes
++
return
w
.
ResponseWriter
.
Write
(
p
)
}
// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService
func
newAntigravityTestService
(
cfg
*
config
.
Config
)
*
AntigravityGatewayService
{
return
&
AntigravityGatewayService
{
settingService
:
&
SettingService
{
cfg
:
cfg
},
}
}
func
TestStripSignatureSensitiveBlocksFromClaudeRequest
(
t
*
testing
.
T
)
{
func
TestStripSignatureSensitiveBlocksFromClaudeRequest
(
t
*
testing
.
T
)
{
req
:=
&
antigravity
.
ClaudeRequest
{
req
:=
&
antigravity
.
ClaudeRequest
{
Model
:
"claude-sonnet-4-5"
,
Model
:
"claude-sonnet-4-5"
,
...
@@ -337,8 +362,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes
...
@@ -337,8 +362,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes
require
.
True
(
t
,
failoverErr
.
ForceCacheBilling
,
"ForceCacheBilling should be true for sticky session switch"
)
require
.
True
(
t
,
failoverErr
.
ForceCacheBilling
,
"ForceCacheBilling should be true for sticky session switch"
)
}
}
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
verifies
//
验证:
ForwardGemini
粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true
//
that
ForwardGemini
sets ForceCacheBilling=true for sticky session switch.
func
TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
(
t
*
testing
.
T
)
{
func
TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
gin
.
SetMode
(
gin
.
TestMode
)
writer
:=
httptest
.
NewRecorder
()
writer
:=
httptest
.
NewRecorder
()
...
@@ -391,3 +416,438 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
...
@@ -391,3 +416,438 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
require
.
Equal
(
t
,
http
.
StatusServiceUnavailable
,
failoverErr
.
StatusCode
)
require
.
Equal
(
t
,
http
.
StatusServiceUnavailable
,
failoverErr
.
StatusCode
)
require
.
True
(
t
,
failoverErr
.
ForceCacheBilling
,
"ForceCacheBilling should be true for sticky session switch"
)
require
.
True
(
t
,
failoverErr
.
ForceCacheBilling
,
"ForceCacheBilling should be true for sticky session switch"
)
}
}
// --- Streaming happy-path tests ---

// TestStreamUpstreamResponse_NormalComplete verifies that when a passthrough
// stream completes normally, the data is relayed to the client, usage is
// collected, and clientDisconnect is false.
func TestStreamUpstreamResponse_NormalComplete(t *testing.T) {
	gin.SetMode(gin.TestMode)
	svc := newAntigravityTestService(&config.Config{
		Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
	})

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	// The pipe stands in for the upstream body; the goroutine below plays the
	// upstream and closes the write end once all events are sent.
	pr, pw := io.Pipe()
	resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}

	go func() {
		defer func() { _ = pw.Close() }()
		fmt.Fprintln(pw, `event: message_start`)
		fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
		fmt.Fprintln(pw, "")
		fmt.Fprintln(pw, `event: content_block_delta`)
		fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`)
		fmt.Fprintln(pw, "")
		fmt.Fprintln(pw, `event: message_delta`)
		fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`)
		fmt.Fprintln(pw, "")
	}()

	result := svc.streamUpstreamResponse(c, resp, time.Now())
	_ = pr.Close()

	require.NotNil(t, result)
	require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
	require.NotNil(t, result.usage)
	require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta")
	require.NotNil(t, result.firstTokenMs, "should record first token time")

	// Verify the data was relayed to the client.
	body := rec.Body.String()
	require.Contains(t, body, "event: message_start")
	require.Contains(t, body, "content_block_delta")
	require.Contains(t, body, "message_delta")
}
// TestHandleGeminiStreamingResponse_NormalComplete verifies that a normal
// Gemini stream is relayed to the client and that its usage is collected.
func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) {
	gin.SetMode(gin.TestMode)
	svc := newAntigravityTestService(&config.Config{
		Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
	})

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	// The pipe stands in for the upstream body; the goroutine below plays the
	// upstream and closes the write end once all chunks are sent.
	pr, pw := io.Pipe()
	resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}

	go func() {
		defer func() { _ = pw.Close() }()
		// First chunk (partial content).
		fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`)
		fmt.Fprintln(pw, "")
		// Second chunk (final content plus complete usage).
		fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`)
		fmt.Fprintln(pw, "")
	}()

	result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
	_ = pr.Close()

	require.NoError(t, err)
	require.NotNil(t, result)
	require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
	require.NotNil(t, result.usage)
	// Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2
	// → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2
	require.Equal(t, 8, result.usage.InputTokens)
	require.Equal(t, 8, result.usage.OutputTokens)
	require.Equal(t, 2, result.usage.CacheReadInputTokens)
	require.NotNil(t, result.firstTokenMs, "should record first token time")

	// Verify the data was relayed to the client.
	body := rec.Body.String()
	require.Contains(t, body, "Hello")
	require.Contains(t, body, "world")
	// No error event should have been emitted.
	require.NotContains(t, body, "event: error")
}
// TestHandleClaudeStreamingResponse_NormalComplete
// 验证:正常 Claude 流式转发(Gemini→Claude 转换),数据正确转换并输出
func
TestHandleClaudeStreamingResponse_NormalComplete
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{}}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
// v1internal 包装格式:Gemini 数据嵌套在 "response" 字段下
// ProcessLine 先尝试反序列化为 V1InternalResponse,裸格式会导致 Response.UsageMetadata 为空
fmt
.
Fprintln
(
pw
,
`data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`
)
fmt
.
Fprintln
(
pw
,
""
)
}()
result
,
err
:=
svc
.
handleClaudeStreamingResponse
(
c
,
resp
,
time
.
Now
(),
"claude-sonnet-4-5"
)
_
=
pr
.
Close
()
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
False
(
t
,
result
.
clientDisconnect
,
"normal completion should not set clientDisconnect"
)
require
.
NotNil
(
t
,
result
.
usage
)
// Gemini→Claude 转换的 usage:promptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3
require
.
Equal
(
t
,
5
,
result
.
usage
.
InputTokens
)
require
.
Equal
(
t
,
3
,
result
.
usage
.
OutputTokens
)
require
.
NotNil
(
t
,
result
.
firstTokenMs
,
"should record first token time"
)
// 验证输出是 Claude SSE 格式(processor 会转换)
body
:=
rec
.
Body
.
String
()
require
.
Contains
(
t
,
body
,
"event: message_start"
,
"should contain Claude message_start event"
)
require
.
Contains
(
t
,
body
,
"event: message_stop"
,
"should contain Claude message_stop event"
)
// 不应包含错误事件
require
.
NotContains
(
t
,
body
,
"event: error"
)
}
// --- 流式客户端断开检测测试 ---
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
// 验证:客户端写入失败后,streamUpstreamResponse 继续读取上游以收集 usage
func
TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
c
.
Writer
=
&
antigravityFailingWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{}}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
fmt
.
Fprintln
(
pw
,
`event: message_start`
)
fmt
.
Fprintln
(
pw
,
`data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`
)
fmt
.
Fprintln
(
pw
,
""
)
fmt
.
Fprintln
(
pw
,
`event: message_delta`
)
fmt
.
Fprintln
(
pw
,
`data: {"type":"message_delta","usage":{"output_tokens":20}}`
)
fmt
.
Fprintln
(
pw
,
""
)
}()
result
:=
svc
.
streamUpstreamResponse
(
c
,
resp
,
time
.
Now
())
_
=
pr
.
Close
()
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
require
.
NotNil
(
t
,
result
.
usage
)
require
.
Equal
(
t
,
20
,
result
.
usage
.
OutputTokens
)
}
// TestStreamUpstreamResponse_ContextCanceled
// 验证:context 取消时返回 usage 且标记 clientDisconnect
func
TestStreamUpstreamResponse_ContextCanceled
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
.
WithContext
(
ctx
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
cancelReadCloser
{},
Header
:
http
.
Header
{}}
result
:=
svc
.
streamUpstreamResponse
(
c
,
resp
,
time
.
Now
())
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
require
.
NotContains
(
t
,
rec
.
Body
.
String
(),
"event: error"
)
}
// TestStreamUpstreamResponse_Timeout
// 验证:上游超时时返回已收集的 usage
func
TestStreamUpstreamResponse_Timeout
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
1
,
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{}}
result
:=
svc
.
streamUpstreamResponse
(
c
,
resp
,
time
.
Now
())
_
=
pw
.
Close
()
_
=
pr
.
Close
()
require
.
NotNil
(
t
,
result
)
require
.
False
(
t
,
result
.
clientDisconnect
)
}
// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect
// 验证:客户端断开后上游超时,返回 usage 并标记 clientDisconnect
func
TestStreamUpstreamResponse_TimeoutAfterClientDisconnect
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
1
,
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
c
.
Writer
=
&
antigravityFailingWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{}}
go
func
()
{
fmt
.
Fprintln
(
pw
,
`data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`
)
fmt
.
Fprintln
(
pw
,
""
)
// 不关闭 pw → 等待超时
}()
result
:=
svc
.
streamUpstreamResponse
(
c
,
resp
,
time
.
Now
())
_
=
pw
.
Close
()
_
=
pr
.
Close
()
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
}
// TestHandleGeminiStreamingResponse_ClientDisconnect
// 验证:Gemini 流式转发中客户端断开后继续 drain 上游
func
TestHandleGeminiStreamingResponse_ClientDisconnect
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
c
.
Writer
=
&
antigravityFailingWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{}}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
fmt
.
Fprintln
(
pw
,
`data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`
)
fmt
.
Fprintln
(
pw
,
""
)
}()
result
,
err
:=
svc
.
handleGeminiStreamingResponse
(
c
,
resp
,
time
.
Now
())
_
=
pr
.
Close
()
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
require
.
NotContains
(
t
,
rec
.
Body
.
String
(),
"write_failed"
)
}
// TestHandleGeminiStreamingResponse_ContextCanceled
// 验证:context 取消时不注入错误事件
func
TestHandleGeminiStreamingResponse_ContextCanceled
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
.
WithContext
(
ctx
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
cancelReadCloser
{},
Header
:
http
.
Header
{}}
result
,
err
:=
svc
.
handleGeminiStreamingResponse
(
c
,
resp
,
time
.
Now
())
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
require
.
NotContains
(
t
,
rec
.
Body
.
String
(),
"event: error"
)
}
// TestHandleClaudeStreamingResponse_ClientDisconnect
// 验证:Claude 流式转发中客户端断开后继续 drain 上游
func
TestHandleClaudeStreamingResponse_ClientDisconnect
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
c
.
Writer
=
&
antigravityFailingWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{}}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
// v1internal 包装格式
fmt
.
Fprintln
(
pw
,
`data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`
)
fmt
.
Fprintln
(
pw
,
""
)
}()
result
,
err
:=
svc
.
handleClaudeStreamingResponse
(
c
,
resp
,
time
.
Now
(),
"claude-sonnet-4-5"
)
_
=
pr
.
Close
()
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
}
// TestHandleClaudeStreamingResponse_ContextCanceled
// 验证:context 取消时不注入错误事件
func
TestHandleClaudeStreamingResponse_ContextCanceled
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
newAntigravityTestService
(
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
MaxLineSize
:
defaultMaxLineSize
},
})
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
.
WithContext
(
ctx
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
cancelReadCloser
{},
Header
:
http
.
Header
{}}
result
,
err
:=
svc
.
handleClaudeStreamingResponse
(
c
,
resp
,
time
.
Now
(),
"claude-sonnet-4-5"
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
True
(
t
,
result
.
clientDisconnect
)
require
.
NotContains
(
t
,
rec
.
Body
.
String
(),
"event: error"
)
}
// TestExtractSSEUsage 验证 extractSSEUsage 从 SSE data 行正确提取 usage
func
TestExtractSSEUsage
(
t
*
testing
.
T
)
{
svc
:=
&
AntigravityGatewayService
{}
tests
:=
[]
struct
{
name
string
line
string
expected
ClaudeUsage
}{
{
name
:
"message_delta with output_tokens"
,
line
:
`data: {"type":"message_delta","usage":{"output_tokens":42}}`
,
expected
:
ClaudeUsage
{
OutputTokens
:
42
},
},
{
name
:
"non-data line ignored"
,
line
:
`event: message_start`
,
expected
:
ClaudeUsage
{},
},
{
name
:
"top-level usage with all fields"
,
line
:
`data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`
,
expected
:
ClaudeUsage
{
InputTokens
:
10
,
OutputTokens
:
20
,
CacheReadInputTokens
:
5
,
CacheCreationInputTokens
:
3
},
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
usage
:=
&
ClaudeUsage
{}
svc
.
extractSSEUsage
(
tt
.
line
,
usage
)
require
.
Equal
(
t
,
tt
.
expected
,
*
usage
)
})
}
}
// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测
func
TestAntigravityClientWriter
(
t
*
testing
.
T
)
{
t
.
Run
(
"normal write succeeds"
,
func
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
flusher
,
_
:=
c
.
Writer
.
(
http
.
Flusher
)
cw
:=
newAntigravityClientWriter
(
c
.
Writer
,
flusher
,
"test"
)
ok
:=
cw
.
Write
([]
byte
(
"hello"
))
require
.
True
(
t
,
ok
)
require
.
False
(
t
,
cw
.
Disconnected
())
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"hello"
)
})
t
.
Run
(
"write failure marks disconnected"
,
func
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
fw
:=
&
antigravityFailingWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
flusher
,
_
:=
c
.
Writer
.
(
http
.
Flusher
)
cw
:=
newAntigravityClientWriter
(
fw
,
flusher
,
"test"
)
ok
:=
cw
.
Write
([]
byte
(
"hello"
))
require
.
False
(
t
,
ok
)
require
.
True
(
t
,
cw
.
Disconnected
())
})
t
.
Run
(
"subsequent writes are no-op"
,
func
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
fw
:=
&
antigravityFailingWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
flusher
,
_
:=
c
.
Writer
.
(
http
.
Flusher
)
cw
:=
newAntigravityClientWriter
(
fw
,
flusher
,
"test"
)
cw
.
Write
([]
byte
(
"first"
))
ok
:=
cw
.
Fprintf
(
"second %d"
,
2
)
require
.
False
(
t
,
ok
)
require
.
True
(
t
,
cw
.
Disconnected
())
})
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment