Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
195e227c
Commit
195e227c
authored
Jan 06, 2026
by
song
Browse files
merge: 合并 upstream/main 并保留本地图片计费功能
parents
6fa704d6
752882a0
Changes
187
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/admin_service.go
View file @
195e227c
...
...
@@ -123,6 +123,7 @@ type UpdateGroupInput struct {
type
CreateAccountInput
struct
{
Name
string
Notes
*
string
Platform
string
Type
string
Credentials
map
[
string
]
any
...
...
@@ -138,6 +139,7 @@ type CreateAccountInput struct {
type
UpdateAccountInput
struct
{
Name
string
Notes
*
string
Type
string
// Account type: oauth, setup-token, apikey
Credentials
map
[
string
]
any
Extra
map
[
string
]
any
...
...
@@ -687,6 +689,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
account
:=
&
Account
{
Name
:
input
.
Name
,
Notes
:
normalizeAccountNotes
(
input
.
Notes
),
Platform
:
input
.
Platform
,
Type
:
input
.
Type
,
Credentials
:
input
.
Credentials
,
...
...
@@ -723,6 +726,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if
input
.
Type
!=
""
{
account
.
Type
=
input
.
Type
}
if
input
.
Notes
!=
nil
{
account
.
Notes
=
normalizeAccountNotes
(
input
.
Notes
)
}
if
len
(
input
.
Credentials
)
>
0
{
account
.
Credentials
=
input
.
Credentials
}
...
...
@@ -730,7 +736,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account
.
Extra
=
input
.
Extra
}
if
input
.
ProxyID
!=
nil
{
account
.
ProxyID
=
input
.
ProxyID
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
if
*
input
.
ProxyID
==
0
{
account
.
ProxyID
=
nil
}
else
{
account
.
ProxyID
=
input
.
ProxyID
}
account
.
Proxy
=
nil
// 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID
}
// 只在指针非 nil 时更新 Concurrency(支持设置为 0)
...
...
backend/internal/service/antigravity_gateway_service.go
View file @
195e227c
...
...
@@ -9,8 +9,10 @@ import (
"fmt"
"io"
"log"
mathrand
"math/rand"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
...
...
@@ -255,6 +257,16 @@ func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedMode
return
antigravity
.
TransformClaudeToGemini
(
claudeReq
,
projectID
,
mappedModel
)
}
func
(
s
*
AntigravityGatewayService
)
getClaudeTransformOptions
(
ctx
context
.
Context
)
antigravity
.
TransformOptions
{
opts
:=
antigravity
.
DefaultTransformOptions
()
if
s
.
settingService
==
nil
{
return
opts
}
opts
.
EnableIdentityPatch
=
s
.
settingService
.
IsIdentityPatchEnabled
(
ctx
)
opts
.
IdentityPatch
=
s
.
settingService
.
GetIdentityPatchPrompt
(
ctx
)
return
opts
}
// extractGeminiResponseText 从 Gemini 响应中提取文本
func
extractGeminiResponseText
(
respBody
[]
byte
)
string
{
var
resp
map
[
string
]
any
...
...
@@ -380,7 +392,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
// 转换 Claude 请求为 Gemini 格式
geminiBody
,
err
:=
antigravity
.
TransformClaudeToGemini
(
&
claudeReq
,
projectID
,
mappedModel
)
geminiBody
,
err
:=
antigravity
.
TransformClaudeToGemini
WithOptions
(
&
claudeReq
,
projectID
,
mappedModel
,
s
.
getClaudeTransformOptions
(
ctx
)
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"transform request: %w"
,
err
)
}
...
...
@@ -394,6 +406,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 重试循环
var
resp
*
http
.
Response
for
attempt
:=
1
;
attempt
<=
antigravityMaxRetries
;
attempt
++
{
// 检查 context 是否已取消(客户端断开连接)
select
{
case
<-
ctx
.
Done
()
:
log
.
Printf
(
"%s status=context_canceled error=%v"
,
prefix
,
ctx
.
Err
())
return
nil
,
ctx
.
Err
()
default
:
}
upstreamReq
,
err
:=
antigravity
.
NewAPIRequest
(
ctx
,
action
,
accessToken
,
geminiBody
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -403,7 +423,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if
err
!=
nil
{
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"%s status=request_failed retry=%d/%d error=%v"
,
prefix
,
attempt
,
antigravityMaxRetries
,
err
)
sleepAntigravityBackoff
(
attempt
)
if
!
sleepAntigravityBackoffWithContext
(
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
prefix
)
return
nil
,
ctx
.
Err
()
}
continue
}
log
.
Printf
(
"%s status=request_failed retries_exhausted error=%v"
,
prefix
,
err
)
...
...
@@ -416,7 +439,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"%s status=%d retry=%d/%d"
,
prefix
,
resp
.
StatusCode
,
attempt
,
antigravityMaxRetries
)
sleepAntigravityBackoff
(
attempt
)
if
!
sleepAntigravityBackoffWithContext
(
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
prefix
)
return
nil
,
ctx
.
Err
()
}
continue
}
// 所有重试都失败,标记限流状态
...
...
@@ -443,35 +469,70 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
if
resp
.
StatusCode
==
http
.
StatusBadRequest
&&
isSignatureRelatedError
(
respBody
)
{
retryClaudeReq
:=
claudeReq
retryClaudeReq
.
Messages
=
append
([]
antigravity
.
ClaudeMessage
(
nil
),
claudeReq
.
Messages
...
)
stripped
,
stripErr
:=
stripThinkingFromClaudeRequest
(
&
retryClaudeReq
)
if
stripErr
==
nil
&&
stripped
{
log
.
Printf
(
"Antigravity account %d: detected signature-related 400, retrying once without thinking blocks"
,
account
.
ID
)
retryGeminiBody
,
txErr
:=
antigravity
.
TransformClaudeToGemini
(
&
retryClaudeReq
,
projectID
,
mappedModel
)
if
txErr
==
nil
{
retryReq
,
buildErr
:=
antigravity
.
NewAPIRequest
(
ctx
,
action
,
accessToken
,
retryGeminiBody
)
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
retryErr
==
nil
{
// Retry success: continue normal success flow with the new response.
if
retryResp
.
StatusCode
<
400
{
_
=
resp
.
Body
.
Close
()
resp
=
retryResp
respBody
=
nil
}
else
{
// Retry still errored: replace error context with retry response.
retryBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
retryResp
.
Body
,
2
<<
20
))
_
=
retryResp
.
Body
.
Close
()
respBody
=
retryBody
resp
=
retryResp
}
}
else
{
log
.
Printf
(
"Antigravity account %d: signature retry request failed: %v"
,
account
.
ID
,
retryErr
)
}
// Conservative two-stage fallback:
// 1) Disable top-level thinking + thinking->text
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
retryStages
:=
[]
struct
{
name
string
strip
func
(
*
antigravity
.
ClaudeRequest
)
(
bool
,
error
)
}{
{
name
:
"thinking-only"
,
strip
:
stripThinkingFromClaudeRequest
},
{
name
:
"thinking+tools"
,
strip
:
stripSignatureSensitiveBlocksFromClaudeRequest
},
}
for
_
,
stage
:=
range
retryStages
{
retryClaudeReq
:=
claudeReq
retryClaudeReq
.
Messages
=
append
([]
antigravity
.
ClaudeMessage
(
nil
),
claudeReq
.
Messages
...
)
stripped
,
stripErr
:=
stage
.
strip
(
&
retryClaudeReq
)
if
stripErr
!=
nil
||
!
stripped
{
continue
}
log
.
Printf
(
"Antigravity account %d: detected signature-related 400, retrying once (%s)"
,
account
.
ID
,
stage
.
name
)
retryGeminiBody
,
txErr
:=
antigravity
.
TransformClaudeToGeminiWithOptions
(
&
retryClaudeReq
,
projectID
,
mappedModel
,
s
.
getClaudeTransformOptions
(
ctx
))
if
txErr
!=
nil
{
continue
}
retryReq
,
buildErr
:=
antigravity
.
NewAPIRequest
(
ctx
,
action
,
accessToken
,
retryGeminiBody
)
if
buildErr
!=
nil
{
continue
}
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
retryErr
!=
nil
{
log
.
Printf
(
"Antigravity account %d: signature retry request failed (%s): %v"
,
account
.
ID
,
stage
.
name
,
retryErr
)
continue
}
if
retryResp
.
StatusCode
<
400
{
_
=
resp
.
Body
.
Close
()
resp
=
retryResp
respBody
=
nil
break
}
retryBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
retryResp
.
Body
,
2
<<
20
))
_
=
retryResp
.
Body
.
Close
()
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
if
retryResp
.
StatusCode
!=
http
.
StatusBadRequest
||
!
isSignatureRelatedError
(
retryBody
)
{
respBody
=
retryBody
resp
=
&
http
.
Response
{
StatusCode
:
retryResp
.
StatusCode
,
Header
:
retryResp
.
Header
.
Clone
(),
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
retryBody
)),
}
break
}
// Still signature-related; capture context and allow next stage.
respBody
=
retryBody
resp
=
&
http
.
Response
{
StatusCode
:
retryResp
.
StatusCode
,
Header
:
retryResp
.
Header
.
Clone
(),
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
retryBody
)),
}
}
}
...
...
@@ -528,7 +589,17 @@ func isSignatureRelatedError(respBody []byte) bool {
}
// Keep this intentionally broad: different upstreams may use "signature" or "thought_signature".
return
strings
.
Contains
(
msg
,
"thought_signature"
)
||
strings
.
Contains
(
msg
,
"signature"
)
if
strings
.
Contains
(
msg
,
"thought_signature"
)
||
strings
.
Contains
(
msg
,
"signature"
)
{
return
true
}
// Also detect thinking block structural errors:
// "Expected `thinking` or `redacted_thinking`, but found `text`"
if
strings
.
Contains
(
msg
,
"expected"
)
&&
(
strings
.
Contains
(
msg
,
"thinking"
)
||
strings
.
Contains
(
msg
,
"redacted_thinking"
))
{
return
true
}
return
false
}
func
extractAntigravityErrorMessage
(
body
[]
byte
)
string
{
...
...
@@ -555,7 +626,7 @@ func extractAntigravityErrorMessage(body []byte) string {
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
// This preserves the thinking content while avoiding signature validation errors.
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
// It also disables top-level `thinking` to
prevent dummy-thought injection during retry
.
// It also disables top-level `thinking` to
avoid upstream structural constraints for thinking mode
.
func
stripThinkingFromClaudeRequest
(
req
*
antigravity
.
ClaudeRequest
)
(
bool
,
error
)
{
if
req
==
nil
{
return
false
,
nil
...
...
@@ -585,6 +656,92 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
continue
}
filtered
:=
make
([]
map
[
string
]
any
,
0
,
len
(
blocks
))
modifiedAny
:=
false
for
_
,
block
:=
range
blocks
{
t
,
_
:=
block
[
"type"
]
.
(
string
)
switch
t
{
case
"thinking"
:
thinkingText
,
_
:=
block
[
"thinking"
]
.
(
string
)
if
thinkingText
!=
""
{
filtered
=
append
(
filtered
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
thinkingText
,
})
}
modifiedAny
=
true
case
"redacted_thinking"
:
modifiedAny
=
true
case
""
:
if
thinkingText
,
hasThinking
:=
block
[
"thinking"
]
.
(
string
);
hasThinking
{
if
thinkingText
!=
""
{
filtered
=
append
(
filtered
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
thinkingText
,
})
}
modifiedAny
=
true
}
else
{
filtered
=
append
(
filtered
,
block
)
}
default
:
filtered
=
append
(
filtered
,
block
)
}
}
if
!
modifiedAny
{
continue
}
if
len
(
filtered
)
==
0
{
filtered
=
append
(
filtered
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"(content removed)"
,
})
}
newRaw
,
err
:=
json
.
Marshal
(
filtered
)
if
err
!=
nil
{
return
changed
,
err
}
req
.
Messages
[
i
]
.
Content
=
newRaw
changed
=
true
}
return
changed
,
nil
}
// stripSignatureSensitiveBlocksFromClaudeRequest is a stronger retry degradation that additionally converts
// tool blocks to plain text. Use this only after a thinking-only retry still fails with signature errors.
func
stripSignatureSensitiveBlocksFromClaudeRequest
(
req
*
antigravity
.
ClaudeRequest
)
(
bool
,
error
)
{
if
req
==
nil
{
return
false
,
nil
}
changed
:=
false
if
req
.
Thinking
!=
nil
{
req
.
Thinking
=
nil
changed
=
true
}
for
i
:=
range
req
.
Messages
{
raw
:=
req
.
Messages
[
i
]
.
Content
if
len
(
raw
)
==
0
{
continue
}
// If content is a string, nothing to strip.
var
str
string
if
json
.
Unmarshal
(
raw
,
&
str
)
==
nil
{
continue
}
// Otherwise treat as an array of blocks and convert signature-sensitive blocks to text.
var
blocks
[]
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
raw
,
&
blocks
);
err
!=
nil
{
continue
}
filtered
:=
make
([]
map
[
string
]
any
,
0
,
len
(
blocks
))
modifiedAny
:=
false
for
_
,
block
:=
range
blocks
{
...
...
@@ -603,6 +760,49 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
case
"redacted_thinking"
:
// Remove redacted_thinking (cannot convert encrypted content)
modifiedAny
=
true
case
"tool_use"
:
// Convert tool_use to text to avoid upstream signature/thought_signature validation errors.
// This is a retry-only degradation path, so we prioritise request validity over tool semantics.
name
,
_
:=
block
[
"name"
]
.
(
string
)
id
,
_
:=
block
[
"id"
]
.
(
string
)
input
:=
block
[
"input"
]
inputJSON
,
_
:=
json
.
Marshal
(
input
)
text
:=
"(tool_use)"
if
name
!=
""
{
text
+=
" name="
+
name
}
if
id
!=
""
{
text
+=
" id="
+
id
}
if
len
(
inputJSON
)
>
0
&&
string
(
inputJSON
)
!=
"null"
{
text
+=
" input="
+
string
(
inputJSON
)
}
filtered
=
append
(
filtered
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
text
,
})
modifiedAny
=
true
case
"tool_result"
:
// Convert tool_result to text so it stays consistent when tool_use is downgraded.
toolUseID
,
_
:=
block
[
"tool_use_id"
]
.
(
string
)
isError
,
_
:=
block
[
"is_error"
]
.
(
bool
)
content
:=
block
[
"content"
]
contentJSON
,
_
:=
json
.
Marshal
(
content
)
text
:=
"(tool_result)"
if
toolUseID
!=
""
{
text
+=
" tool_use_id="
+
toolUseID
}
if
isError
{
text
+=
" is_error=true"
}
if
len
(
contentJSON
)
>
0
&&
string
(
contentJSON
)
!=
"null"
{
text
+=
"
\n
"
+
string
(
contentJSON
)
}
filtered
=
append
(
filtered
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
text
,
})
modifiedAny
=
true
case
""
:
// Handle untyped block with "thinking" field
if
thinkingText
,
hasThinking
:=
block
[
"thinking"
]
.
(
string
);
hasThinking
{
...
...
@@ -625,6 +825,14 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
continue
}
if
len
(
filtered
)
==
0
{
// Keep request valid: upstream rejects empty content arrays.
filtered
=
append
(
filtered
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"(content removed)"
,
})
}
newRaw
,
err
:=
json
.
Marshal
(
filtered
)
if
err
!=
nil
{
return
changed
,
err
...
...
@@ -711,6 +919,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 重试循环
var
resp
*
http
.
Response
for
attempt
:=
1
;
attempt
<=
antigravityMaxRetries
;
attempt
++
{
// 检查 context 是否已取消(客户端断开连接)
select
{
case
<-
ctx
.
Done
()
:
log
.
Printf
(
"%s status=context_canceled error=%v"
,
prefix
,
ctx
.
Err
())
return
nil
,
ctx
.
Err
()
default
:
}
upstreamReq
,
err
:=
antigravity
.
NewAPIRequest
(
ctx
,
upstreamAction
,
accessToken
,
wrappedBody
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -720,7 +936,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if
err
!=
nil
{
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"%s status=request_failed retry=%d/%d error=%v"
,
prefix
,
attempt
,
antigravityMaxRetries
,
err
)
sleepAntigravityBackoff
(
attempt
)
if
!
sleepAntigravityBackoffWithContext
(
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
prefix
)
return
nil
,
ctx
.
Err
()
}
continue
}
log
.
Printf
(
"%s status=request_failed retries_exhausted error=%v"
,
prefix
,
err
)
...
...
@@ -733,7 +952,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"%s status=%d retry=%d/%d"
,
prefix
,
resp
.
StatusCode
,
attempt
,
antigravityMaxRetries
)
sleepAntigravityBackoff
(
attempt
)
if
!
sleepAntigravityBackoffWithContext
(
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
prefix
)
return
nil
,
ctx
.
Err
()
}
continue
}
// 所有重试都失败,标记限流状态
...
...
@@ -750,11 +972,18 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
break
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
defer
func
()
{
if
resp
!=
nil
&&
resp
.
Body
!=
nil
{
_
=
resp
.
Body
.
Close
()
}
}()
// 处理错误响应
if
resp
.
StatusCode
>=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
// 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsModelFallbackEnabled
(
ctx
)
&&
...
...
@@ -763,15 +992,13 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if
fallbackModel
!=
""
&&
fallbackModel
!=
mappedModel
{
log
.
Printf
(
"[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)"
,
mappedModel
,
fallbackModel
,
account
.
Name
)
// 关闭原始响应,释放连接(respBody 已读取到内存)
_
=
resp
.
Body
.
Close
()
fallbackWrapped
,
err
:=
s
.
wrapV1InternalRequest
(
projectID
,
fallbackModel
,
body
)
if
err
==
nil
{
fallbackReq
,
err
:=
antigravity
.
NewAPIRequest
(
ctx
,
upstreamAction
,
accessToken
,
fallbackWrapped
)
if
err
==
nil
{
fallbackResp
,
err
:=
s
.
httpUpstream
.
Do
(
fallbackReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
==
nil
&&
fallbackResp
.
StatusCode
<
400
{
_
=
resp
.
Body
.
Close
()
resp
=
fallbackResp
}
else
if
fallbackResp
!=
nil
{
_
=
fallbackResp
.
Body
.
Close
()
...
...
@@ -872,8 +1099,28 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
}
}
func
sleepAntigravityBackoff
(
attempt
int
)
{
sleepGeminiBackoff
(
attempt
)
// 复用 Gemini 的退避逻辑
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
// 返回 true 表示正常完成等待,false 表示 context 已取消
func
sleepAntigravityBackoffWithContext
(
ctx
context
.
Context
,
attempt
int
)
bool
{
delay
:=
geminiRetryBaseDelay
*
time
.
Duration
(
1
<<
uint
(
attempt
-
1
))
if
delay
>
geminiRetryMaxDelay
{
delay
=
geminiRetryMaxDelay
}
// +/- 20% jitter
r
:=
mathrand
.
New
(
mathrand
.
NewSource
(
time
.
Now
()
.
UnixNano
()))
jitter
:=
time
.
Duration
(
float64
(
delay
)
*
0.2
*
(
r
.
Float64
()
*
2
-
1
))
sleepFor
:=
delay
+
jitter
if
sleepFor
<
0
{
sleepFor
=
0
}
select
{
case
<-
ctx
.
Done
()
:
return
false
case
<-
time
.
After
(
sleepFor
)
:
return
true
}
}
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
)
{
...
...
@@ -928,57 +1175,145 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
return
nil
,
errors
.
New
(
"streaming not supported"
)
}
reader
:=
bufio
.
NewReader
(
resp
.
Body
)
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
usage
:=
&
ClaudeUsage
{}
var
firstTokenMs
*
int
type
scanEvent
struct
{
line
string
err
error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events
:=
make
(
chan
scanEvent
,
16
)
done
:=
make
(
chan
struct
{})
sendEvent
:=
func
(
ev
scanEvent
)
bool
{
select
{
case
events
<-
ev
:
return
true
case
<-
done
:
return
false
}
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
if
!
sendEvent
(
scanEvent
{
line
:
scanner
.
Text
()})
{
return
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
defer
close
(
done
)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval
:=
time
.
Duration
(
0
)
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
>
0
{
streamInterval
=
time
.
Duration
(
s
.
settingService
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
)
*
time
.
Second
}
var
intervalTicker
*
time
.
Ticker
if
streamInterval
>
0
{
intervalTicker
=
time
.
NewTicker
(
streamInterval
)
defer
intervalTicker
.
Stop
()
}
var
intervalCh
<-
chan
time
.
Time
if
intervalTicker
!=
nil
{
intervalCh
=
intervalTicker
.
C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
return
}
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
c
.
Writer
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
flusher
.
Flush
()
}
for
{
line
,
err
:=
reader
.
ReadString
(
'\n'
)
if
len
(
line
)
>
0
{
select
{
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
if
ev
.
err
!=
nil
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long (antigravity): max_size=%d error=%v"
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
ev
.
err
}
sendErrorEvent
(
"stream_read_error"
)
return
nil
,
ev
.
err
}
line
:=
ev
.
line
trimmed
:=
strings
.
TrimRight
(
line
,
"
\r\n
"
)
if
strings
.
HasPrefix
(
trimmed
,
"data:"
)
{
payload
:=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmed
,
"data:"
))
if
payload
==
""
||
payload
==
"[DONE]"
{
_
,
_
=
io
.
WriteString
(
c
.
Writer
,
line
)
flusher
.
Flush
()
}
else
{
// 解包 v1internal 响应
inner
,
parseErr
:=
s
.
unwrapV1InternalResponse
([]
byte
(
payload
))
if
parseErr
==
nil
&&
inner
!=
nil
{
payload
=
string
(
inner
)
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
continue
}
// 解析 usage
var
parsed
map
[
string
]
any
if
json
.
Unmarshal
(
inner
,
&
parsed
)
==
nil
{
if
u
:=
extractGeminiUsage
(
parsed
);
u
!=
nil
{
usage
=
u
}
}
// 解包 v1internal 响应
inner
,
parseErr
:=
s
.
unwrapV1InternalResponse
([]
byte
(
payload
))
if
parseErr
==
nil
&&
inner
!=
nil
{
payload
=
string
(
inner
)
}
if
firstTokenMs
==
nil
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
// 解析 usage
var
parsed
map
[
string
]
any
if
json
.
Unmarshal
(
inner
,
&
parsed
)
==
nil
{
if
u
:=
extractGeminiUsage
(
parsed
);
u
!=
nil
{
usage
=
u
}
}
_
,
_
=
fmt
.
Fprintf
(
c
.
Writer
,
"data: %s
\n\n
"
,
payload
)
flusher
.
Flush
()
if
firstTokenMs
==
nil
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"data: %s
\n\n
"
,
payload
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
}
else
{
_
,
_
=
io
.
WriteString
(
c
.
Writer
,
line
)
flusher
.
Flush
()
continue
}
}
if
errors
.
Is
(
err
,
io
.
EOF
)
{
break
}
if
err
!=
nil
{
return
nil
,
err
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
}
log
.
Printf
(
"Stream data interval timeout (antigravity)"
)
sendErrorEvent
(
"stream_timeout"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
}
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
func
(
s
*
AntigravityGatewayService
)
handleGeminiNonStreamingResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
)
(
*
ClaudeUsage
,
error
)
{
...
...
@@ -1117,7 +1452,13 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
processor
:=
antigravity
.
NewStreamingProcessor
(
originalModel
)
var
firstTokenMs
*
int
reader
:=
bufio
.
NewReader
(
resp
.
Body
)
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage
:=
func
(
agUsage
*
antigravity
.
ClaudeUsage
)
*
ClaudeUsage
{
...
...
@@ -1132,13 +1473,85 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
}
for
{
line
,
err
:=
reader
.
ReadString
(
'\n'
)
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
io
.
EOF
)
{
return
nil
,
fmt
.
Errorf
(
"stream read error: %w"
,
err
)
type
scanEvent
struct
{
line
string
err
error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events
:=
make
(
chan
scanEvent
,
16
)
done
:=
make
(
chan
struct
{})
sendEvent
:=
func
(
ev
scanEvent
)
bool
{
select
{
case
events
<-
ev
:
return
true
case
<-
done
:
return
false
}
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
if
!
sendEvent
(
scanEvent
{
line
:
scanner
.
Text
()})
{
return
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
defer
close
(
done
)
streamInterval
:=
time
.
Duration
(
0
)
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
>
0
{
streamInterval
=
time
.
Duration
(
s
.
settingService
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
)
*
time
.
Second
}
var
intervalTicker
*
time
.
Ticker
if
streamInterval
>
0
{
intervalTicker
=
time
.
NewTicker
(
streamInterval
)
defer
intervalTicker
.
Stop
()
}
var
intervalCh
<-
chan
time
.
Time
if
intervalTicker
!=
nil
{
intervalCh
=
intervalTicker
.
C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
return
}
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
c
.
Writer
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
flusher
.
Flush
()
}
for
{
select
{
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
// 发送结束事件
finalEvents
,
agUsage
:=
processor
.
Finish
()
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
flusher
.
Flush
()
}
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
nil
}
if
ev
.
err
!=
nil
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long (antigravity): max_size=%d error=%v"
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
nil
),
firstTokenMs
:
firstTokenMs
},
ev
.
err
}
sendErrorEvent
(
"stream_read_error"
)
return
nil
,
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
if
len
(
line
)
>
0
{
line
:=
ev
.
line
// 处理 SSE 行,转换为 Claude 格式
claudeEvents
:=
processor
.
ProcessLine
(
strings
.
TrimRight
(
line
,
"
\r\n
"
))
...
...
@@ -1153,25 +1566,23 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
}
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
writeErr
}
flusher
.
Flush
()
}
}
if
errors
.
Is
(
err
,
io
.
EOF
)
{
break
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
}
log
.
Printf
(
"Stream data interval timeout (antigravity)"
)
sendErrorEvent
(
"stream_timeout"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
nil
),
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
}
// 发送结束事件
finalEvents
,
agUsage
:=
processor
.
Finish
()
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
flusher
.
Flush
()
}
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
nil
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
...
...
backend/internal/service/antigravity_gateway_service_test.go
0 → 100644
View file @
195e227c
package
service
import
(
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
func
TestStripSignatureSensitiveBlocksFromClaudeRequest
(
t
*
testing
.
T
)
{
req
:=
&
antigravity
.
ClaudeRequest
{
Model
:
"claude-sonnet-4-5"
,
Thinking
:
&
antigravity
.
ThinkingConfig
{
Type
:
"enabled"
,
BudgetTokens
:
1024
,
},
Messages
:
[]
antigravity
.
ClaudeMessage
{
{
Role
:
"assistant"
,
Content
:
json
.
RawMessage
(
`[
{"type":"thinking","thinking":"secret plan","signature":""},
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}
]`
),
},
{
Role
:
"user"
,
Content
:
json
.
RawMessage
(
`[
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false},
{"type":"redacted_thinking","data":"..."}
]`
),
},
},
}
changed
,
err
:=
stripSignatureSensitiveBlocksFromClaudeRequest
(
req
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
changed
)
require
.
Nil
(
t
,
req
.
Thinking
)
require
.
Len
(
t
,
req
.
Messages
,
2
)
var
blocks0
[]
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
req
.
Messages
[
0
]
.
Content
,
&
blocks0
))
require
.
Len
(
t
,
blocks0
,
2
)
require
.
Equal
(
t
,
"text"
,
blocks0
[
0
][
"type"
])
require
.
Equal
(
t
,
"secret plan"
,
blocks0
[
0
][
"text"
])
require
.
Equal
(
t
,
"text"
,
blocks0
[
1
][
"type"
])
var
blocks1
[]
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
req
.
Messages
[
1
]
.
Content
,
&
blocks1
))
require
.
Len
(
t
,
blocks1
,
1
)
require
.
Equal
(
t
,
"text"
,
blocks1
[
0
][
"type"
])
require
.
NotEmpty
(
t
,
blocks1
[
0
][
"text"
])
}
func
TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools
(
t
*
testing
.
T
)
{
req
:=
&
antigravity
.
ClaudeRequest
{
Model
:
"claude-sonnet-4-5"
,
Thinking
:
&
antigravity
.
ThinkingConfig
{
Type
:
"enabled"
,
BudgetTokens
:
1024
,
},
Messages
:
[]
antigravity
.
ClaudeMessage
{
{
Role
:
"assistant"
,
Content
:
json
.
RawMessage
(
`[{"type":"thinking","thinking":"secret plan"},{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}]`
),
},
},
}
changed
,
err
:=
stripThinkingFromClaudeRequest
(
req
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
changed
)
require
.
Nil
(
t
,
req
.
Thinking
)
var
blocks
[]
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
req
.
Messages
[
0
]
.
Content
,
&
blocks
))
require
.
Len
(
t
,
blocks
,
2
)
require
.
Equal
(
t
,
"text"
,
blocks
[
0
][
"type"
])
require
.
Equal
(
t
,
"secret plan"
,
blocks
[
0
][
"text"
])
require
.
Equal
(
t
,
"tool_use"
,
blocks
[
1
][
"type"
])
}
backend/internal/service/auth_service.go
View file @
195e227c
...
...
@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// VerifyTurnstile 验证Turnstile token
// VerifyTurnstile validates a Cloudflare Turnstile token for the given client IP.
//
// Enforcement is "required" only when all of: cfg is present, the server runs in
// "release" mode, and Turnstile.Required is set. In that case a missing settings
// service, a disabled toggle, or an empty secret key is treated as a hard
// misconfiguration error rather than silently skipping verification.
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
	required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
	if required {
		if s.settingService == nil {
			log.Println("[Auth] Turnstile required but settings service is not configured")
			return ErrTurnstileNotConfigured
		}
		enabled := s.settingService.IsTurnstileEnabled(ctx)
		secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
		if !enabled || !secretConfigured {
			log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
			return ErrTurnstileNotConfigured
		}
	}
	if s.turnstileService == nil {
		if required {
			log.Println("[Auth] Turnstile required but service not configured")
			return ErrTurnstileNotConfigured
		}
		return nil // Skip verification when the Turnstile service is not configured.
	}
	// Best-effort warning: Turnstile is enabled in settings but has no secret key.
	if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
		log.Println("[Auth] Turnstile enabled but secret key not configured")
	}
	return s.turnstileService.VerifyToken(ctx, token, remoteIP)
}
...
...
backend/internal/service/billing_cache_service.go
View file @
195e227c
...
...
@@ -16,7 +16,8 @@ import (
// 注:ErrInsufficientBalance在redeem_service.go中定义
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var
(
ErrSubscriptionInvalid
=
infraerrors
.
Forbidden
(
"SUBSCRIPTION_INVALID"
,
"subscription is invalid or expired"
)
ErrSubscriptionInvalid
=
infraerrors
.
Forbidden
(
"SUBSCRIPTION_INVALID"
,
"subscription is invalid or expired"
)
ErrBillingServiceUnavailable
=
infraerrors
.
ServiceUnavailable
(
"BILLING_SERVICE_ERROR"
,
"Billing service temporarily unavailable. Please retry later."
)
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
...
...
@@ -72,10 +73,11 @@ type cacheWriteTask struct {
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type
BillingCacheService
struct
{
cache
BillingCache
userRepo
UserRepository
subRepo
UserSubscriptionRepository
cfg
*
config
.
Config
cache
BillingCache
userRepo
UserRepository
subRepo
UserSubscriptionRepository
cfg
*
config
.
Config
circuitBreaker
*
billingCircuitBreaker
cacheWriteChan
chan
cacheWriteTask
cacheWriteWg
sync
.
WaitGroup
...
...
@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
subRepo
:
subRepo
,
cfg
:
cfg
,
}
svc
.
circuitBreaker
=
newBillingCircuitBreaker
(
cfg
.
Billing
.
CircuitBreaker
)
svc
.
startCacheWriteWorkers
()
return
svc
}
...
...
@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
if
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
return
nil
}
if
s
.
circuitBreaker
!=
nil
&&
!
s
.
circuitBreaker
.
Allow
()
{
return
ErrBillingServiceUnavailable
}
// 判断计费模式
isSubscriptionMode
:=
group
!=
nil
&&
group
.
IsSubscriptionType
()
&&
subscription
!=
nil
...
...
@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
func
(
s
*
BillingCacheService
)
checkBalanceEligibility
(
ctx
context
.
Context
,
userID
int64
)
error
{
balance
,
err
:=
s
.
GetUserBalance
(
ctx
,
userID
)
if
err
!=
nil
{
// 缓存/数据库错误,允许通过(降级处理)
log
.
Printf
(
"Warning: get user balance failed, allowing request: %v"
,
err
)
return
nil
if
s
.
circuitBreaker
!=
nil
{
s
.
circuitBreaker
.
OnFailure
(
err
)
}
log
.
Printf
(
"ALERT: billing balance check failed for user %d: %v"
,
userID
,
err
)
return
ErrBillingServiceUnavailable
.
WithCause
(
err
)
}
if
s
.
circuitBreaker
!=
nil
{
s
.
circuitBreaker
.
OnSuccess
()
}
if
balance
<=
0
{
...
...
@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
// 获取订阅缓存数据
subData
,
err
:=
s
.
GetSubscriptionStatus
(
ctx
,
userID
,
group
.
ID
)
if
err
!=
nil
{
// 缓存/数据库错误,降级使用传入的subscription进行检查
log
.
Printf
(
"Warning: get subscription cache failed, using fallback: %v"
,
err
)
return
s
.
checkSubscriptionLimitsFallback
(
subscription
,
group
)
if
s
.
circuitBreaker
!=
nil
{
s
.
circuitBreaker
.
OnFailure
(
err
)
}
log
.
Printf
(
"ALERT: billing subscription check failed for user %d group %d: %v"
,
userID
,
group
.
ID
,
err
)
return
ErrBillingServiceUnavailable
.
WithCause
(
err
)
}
if
s
.
circuitBreaker
!=
nil
{
s
.
circuitBreaker
.
OnSuccess
()
}
// 检查订阅状态
...
...
@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
return
nil
}
// checkSubscriptionLimitsFallback 降级检查订阅限额
func
(
s
*
BillingCacheService
)
checkSubscriptionLimitsFallback
(
subscription
*
UserSubscription
,
group
*
Group
)
error
{
if
subscription
==
nil
{
return
ErrSubscriptionInvalid
// billingCircuitBreakerState enumerates the breaker's states
// (closed -> open -> half-open -> closed).
type billingCircuitBreakerState int

const (
	billingCircuitClosed billingCircuitBreakerState = iota
	billingCircuitOpen
	billingCircuitHalfOpen
)

// billingCircuitBreaker is a small mutex-guarded circuit breaker used to shed
// billing eligibility checks while the backing store keeps failing.
type billingCircuitBreaker struct {
	mu       sync.Mutex
	state    billingCircuitBreakerState
	failures int       // consecutive failures observed while closed
	openedAt time.Time // when the breaker last tripped open

	failureThreshold  int           // failures needed to trip from closed to open
	resetTimeout      time.Duration // how long to stay open before probing
	halfOpenRequests  int           // probe budget granted when entering half-open
	halfOpenRemaining int           // probes left in the current half-open window
}
func
newBillingCircuitBreaker
(
cfg
config
.
CircuitBreakerConfig
)
*
billingCircuitBreaker
{
if
!
cfg
.
Enabled
{
return
nil
}
resetTimeout
:=
time
.
Duration
(
cfg
.
ResetTimeoutSeconds
)
*
time
.
Second
if
resetTimeout
<=
0
{
resetTimeout
=
30
*
time
.
Second
}
halfOpen
:=
cfg
.
HalfOpenRequests
if
halfOpen
<=
0
{
halfOpen
=
1
}
threshold
:=
cfg
.
FailureThreshold
if
threshold
<=
0
{
threshold
=
5
}
return
&
billingCircuitBreaker
{
state
:
billingCircuitClosed
,
failureThreshold
:
threshold
,
resetTimeout
:
resetTimeout
,
halfOpenRequests
:
halfOpen
,
}
}
if
!
subscription
.
IsActive
()
{
return
ErrSubscriptionInvalid
// Allow reports whether a billing check may proceed right now.
//
// Closed: always allow. Open: block until resetTimeout has elapsed since the
// breaker tripped, then transition to half-open with a fresh probe budget and
// fall through to the half-open path. Half-open: allow while probe budget
// remains, consuming one probe per call.
func (b *billingCircuitBreaker) Allow() bool {
	b.mu.Lock()
	defer b.mu.Unlock()
	switch b.state {
	case billingCircuitClosed:
		return true
	case billingCircuitOpen:
		if time.Since(b.openedAt) < b.resetTimeout {
			return false
		}
		// Cool-down elapsed: start probing upstream health.
		b.state = billingCircuitHalfOpen
		b.halfOpenRemaining = b.halfOpenRequests
		log.Printf("ALERT: billing circuit breaker entering half-open state")
		fallthrough
	case billingCircuitHalfOpen:
		if b.halfOpenRemaining <= 0 {
			return false
		}
		b.halfOpenRemaining--
		return true
	default:
		return false
	}
}
if
!
subscription
.
CheckDailyLimit
(
group
,
0
)
{
return
ErrDailyLimitExceeded
// OnFailure records a failed billing check. Safe on a nil receiver (breaker
// disabled). While open it is a no-op; a failure during half-open re-opens the
// breaker immediately; while closed it counts failures and trips open once
// failureThreshold is reached.
func (b *billingCircuitBreaker) OnFailure(err error) {
	if b == nil {
		return
	}
	b.mu.Lock()
	defer b.mu.Unlock()
	switch b.state {
	case billingCircuitOpen:
		return
	case billingCircuitHalfOpen:
		// A probe failed: re-open and restart the cool-down window.
		b.state = billingCircuitOpen
		b.openedAt = time.Now()
		b.halfOpenRemaining = 0
		log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
		return
	default:
		b.failures++
		if b.failures >= b.failureThreshold {
			b.state = billingCircuitOpen
			b.openedAt = time.Now()
			b.halfOpenRemaining = 0
			log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
		}
	}
}
if
!
subscription
.
CheckMonthlyLimit
(
group
,
0
)
{
return
ErrMonthlyLimitExceeded
// OnSuccess records a successful billing check. Safe on a nil receiver.
// It closes the breaker and clears all failure/probe counters regardless of
// the previous state.
func (b *billingCircuitBreaker) OnSuccess() {
	if b == nil {
		return
	}
	b.mu.Lock()
	defer b.mu.Unlock()
	previousState := b.state
	previousFailures := b.failures
	b.state = billingCircuitClosed
	b.failures = 0
	b.halfOpenRemaining = 0
	// Only log when the state actually changed (or counters were reset),
	// to avoid log spam on the steady-state success path.
	if previousState != billingCircuitClosed {
		log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
	} else if previousFailures > 0 {
		log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
	}
}
func
circuitStateString
(
state
billingCircuitBreakerState
)
string
{
switch
state
{
case
billingCircuitClosed
:
return
"closed"
case
billingCircuitOpen
:
return
"open"
case
billingCircuitHalfOpen
:
return
"half-open"
default
:
return
"unknown"
}
}
backend/internal/service/crs_sync_service.go
View file @
195e227c
...
...
@@ -8,12 +8,13 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
type
CRSSyncService
struct
{
...
...
@@ -22,6 +23,7 @@ type CRSSyncService struct {
oauthService
*
OAuthService
openaiOAuthService
*
OpenAIOAuthService
geminiOAuthService
*
GeminiOAuthService
cfg
*
config
.
Config
}
func
NewCRSSyncService
(
...
...
@@ -30,6 +32,7 @@ func NewCRSSyncService(
oauthService
*
OAuthService
,
openaiOAuthService
*
OpenAIOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
cfg
*
config
.
Config
,
)
*
CRSSyncService
{
return
&
CRSSyncService
{
accountRepo
:
accountRepo
,
...
...
@@ -37,6 +40,7 @@ func NewCRSSyncService(
oauthService
:
oauthService
,
openaiOAuthService
:
openaiOAuthService
,
geminiOAuthService
:
geminiOAuthService
,
cfg
:
cfg
,
}
}
...
...
@@ -187,16 +191,31 @@ type crsGeminiAPIKeyAccount struct {
}
func
(
s
*
CRSSyncService
)
SyncFromCRS
(
ctx
context
.
Context
,
input
SyncFromCRSInput
)
(
*
SyncFromCRSResult
,
error
)
{
baseURL
,
err
:=
normalizeBaseURL
(
input
.
BaseURL
)
if
err
!=
nil
{
return
nil
,
err
if
s
.
cfg
==
nil
{
return
nil
,
errors
.
New
(
"config is not available"
)
}
baseURL
:=
strings
.
TrimSpace
(
input
.
BaseURL
)
if
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
{
normalized
,
err
:=
normalizeBaseURL
(
baseURL
,
s
.
cfg
.
Security
.
URLAllowlist
.
CRSHosts
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
)
if
err
!=
nil
{
return
nil
,
err
}
baseURL
=
normalized
}
else
{
normalized
,
err
:=
urlvalidator
.
ValidateURLFormat
(
baseURL
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
baseURL
=
normalized
}
if
strings
.
TrimSpace
(
input
.
Username
)
==
""
||
strings
.
TrimSpace
(
input
.
Password
)
==
""
{
return
nil
,
errors
.
New
(
"username and password are required"
)
}
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
20
*
time
.
Second
,
Timeout
:
20
*
time
.
Second
,
ValidateResolvedIP
:
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
,
AllowPrivateHosts
:
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
,
})
if
err
!=
nil
{
client
=
&
http
.
Client
{
Timeout
:
20
*
time
.
Second
}
...
...
@@ -1055,17 +1074,18 @@ func mapCRSStatus(isActive bool, status string) string {
return
"active"
}
func
normalizeBaseURL
(
raw
string
)
(
string
,
error
)
{
trimmed
:=
strings
.
TrimSpace
(
raw
)
if
trimmed
==
""
{
return
""
,
errors
.
New
(
"base_url is required"
)
}
u
,
err
:=
url
.
Parse
(
trimmed
)
if
err
!=
nil
||
u
.
Scheme
==
""
||
u
.
Host
==
""
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %s"
,
trimmed
)
func
normalizeBaseURL
(
raw
string
,
allowlist
[]
string
,
allowPrivate
bool
)
(
string
,
error
)
{
// 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
requireAllowlist
:=
len
(
allowlist
)
>
0
normalized
,
err
:=
urlvalidator
.
ValidateHTTPSURL
(
raw
,
urlvalidator
.
ValidationOptions
{
AllowedHosts
:
allowlist
,
RequireAllowlist
:
requireAllowlist
,
AllowPrivate
:
allowPrivate
,
})
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
u
.
Path
=
strings
.
TrimRight
(
u
.
Path
,
"/"
)
return
strings
.
TrimRight
(
u
.
String
(),
"/"
),
nil
return
normalized
,
nil
}
// cleanBaseURL removes trailing suffix from base_url in credentials
...
...
backend/internal/service/domain_constants.go
View file @
195e227c
...
...
@@ -101,6 +101,10 @@ const (
SettingKeyFallbackModelOpenAI
=
"fallback_model_openai"
SettingKeyFallbackModelGemini
=
"fallback_model_gemini"
SettingKeyFallbackModelAntigravity
=
"fallback_model_antigravity"
// Request identity patch (Claude -> Gemini systemInstruction injection)
SettingKeyEnableIdentityPatch
=
"enable_identity_patch"
SettingKeyIdentityPatchPrompt
=
"identity_patch_prompt"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
...
...
backend/internal/service/gateway_request.go
View file @
195e227c
...
...
@@ -84,25 +84,180 @@ func FilterThinkingBlocks(body []byte) []byte {
return
filterThinkingBlocksInternal
(
body
,
false
)
}
// FilterThinkingBlocksForRetry removes thinking blocks from HISTORICAL messages for retry scenarios.
// This is used when upstream returns signature-related 400 errors.
// FilterThinkingBlocksForRetry strips thinking-related constructs for retry scenarios.
//
// Key insight:
// - User's thinking.type = "enabled" should be PRESERVED (user's intent)
// - Only HISTORICAL assistant messages have thinking blocks with signatures
// - These signatures may be invalid when switching accounts/platforms
// - New responses will generate fresh thinking blocks without signature issues
// Why:
// - Upstreams may reject historical `thinking`/`redacted_thinking` blocks due to invalid/missing signatures.
// - Anthropic extended thinking has a structural constraint: when top-level `thinking` is enabled and the
// final message is an assistant prefill, the assistant content must start with a thinking block.
// - If we remove thinking blocks but keep top-level `thinking` enabled, we can trigger:
// "Expected `thinking` or `redacted_thinking`, but found `text`"
//
// Strategy:
// - Keep thinking.type = "enabled" (preserve user intent)
// - Remove thinking/redacted_thinking blocks from historical assistant messages
// - Ensure no message has empty content after filtering
// Strategy (B: preserve content as text):
// - Disable top-level `thinking` (remove `thinking` field).
// - Convert `thinking` blocks to `text` blocks (preserve the thinking content).
// - Remove `redacted_thinking` blocks (cannot be converted to text).
// - Ensure no message ends up with empty content.
// FilterThinkingBlocksForRetry strips thinking-related constructs so a request
// can be retried after upstream rejected historical thinking blocks (e.g.
// invalid/missing signatures, or the structural rule that an assistant prefill
// must start with a thinking block while top-level thinking is enabled).
//
// Strategy (preserve content as text):
//   - remove the top-level "thinking" field,
//   - convert "thinking" blocks into plain "text" blocks,
//   - drop "redacted_thinking" blocks (they cannot be rendered as text),
//   - ensure no message is left with empty content.
//
// On any parse/serialize problem the original body is returned unmodified.
func FilterThinkingBlocksForRetry(body []byte) []byte {
	// Cheap byte scans decide whether a full JSON round-trip is needed at all.
	hasThinking := bytes.Contains(body, []byte(`"type":"thinking"`)) ||
		bytes.Contains(body, []byte(`"type": "thinking"`)) ||
		bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) ||
		bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) ||
		bytes.Contains(body, []byte(`"thinking":`)) ||
		bytes.Contains(body, []byte(`"thinking" :`))

	// Empty content arrays need fixing even when no thinking blocks exist.
	// This is only a heuristic; the real handling happens below.
	hasEmpty := bytes.Contains(body, []byte(`"content":[]`)) ||
		bytes.Contains(body, []byte(`"content": []`)) ||
		bytes.Contains(body, []byte(`"content" : []`)) ||
		bytes.Contains(body, []byte(`"content" :[]`))

	if !hasThinking && !hasEmpty {
		return body
	}

	var payload map[string]any
	if err := json.Unmarshal(body, &payload); err != nil {
		return body
	}

	changed := false
	msgs, ok := payload["messages"].([]any)
	if !ok {
		return body
	}

	// Disabling top-level thinking avoids upstream structural/signature constraints.
	if _, exists := payload["thinking"]; exists {
		delete(payload, "thinking")
		changed = true
	}

	rebuilt := make([]any, 0, len(msgs))
	for _, raw := range msgs {
		mm, ok := raw.(map[string]any)
		if !ok {
			rebuilt = append(rebuilt, raw)
			continue
		}
		role, _ := mm["role"].(string)
		blocks, ok := mm["content"].([]any)
		if !ok {
			// String content (or any non-array form) passes through untouched.
			rebuilt = append(rebuilt, raw)
			continue
		}

		kept := make([]any, 0, len(blocks))
		msgChanged := false
		for _, rawBlock := range blocks {
			bm, ok := rawBlock.(map[string]any)
			if !ok {
				kept = append(kept, rawBlock)
				continue
			}
			kind, _ := bm["type"].(string)

			switch kind {
			case "thinking":
				// Keep the reasoning visible, but as an ordinary text block.
				msgChanged = true
				if txt, _ := bm["thinking"].(string); txt != "" {
					kept = append(kept, map[string]any{"type": "text", "text": txt})
				}
				continue
			case "redacted_thinking":
				msgChanged = true
				continue
			}

			// Blocks lacking a type discriminator but carrying a "thinking"
			// payload receive the same text conversion.
			if kind == "" {
				if rawThinking, has := bm["thinking"]; has {
					msgChanged = true
					switch v := rawThinking.(type) {
					case string:
						if v != "" {
							kept = append(kept, map[string]any{"type": "text", "text": v})
						}
					default:
						if b, err := json.Marshal(v); err == nil && len(b) > 0 {
							kept = append(kept, map[string]any{"type": "text", "text": string(b)})
						}
					}
					continue
				}
			}

			kept = append(kept, rawBlock)
		}

		if len(kept) == 0 {
			// Empty content (filtered away or empty to begin with) is rejected
			// upstream, so substitute a placeholder text block.
			changed = true
			placeholder := "(content removed)"
			if role == "assistant" {
				placeholder = "(assistant content removed)"
			}
			kept = append(kept, map[string]any{"type": "text", "text": placeholder})
			mm["content"] = kept
		} else if msgChanged {
			changed = true
			mm["content"] = kept
		}
		rebuilt = append(rebuilt, mm)
	}

	if !changed {
		// Avoid re-serialising JSON when nothing was touched.
		return body
	}
	payload["messages"] = rebuilt

	out, err := json.Marshal(payload)
	if err != nil {
		return body
	}
	return out
}
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate
// signature/thought_signature validation issues involving tool blocks.
//
// This performs everything in FilterThinkingBlocksForRetry, plus:
// - Convert `tool_use` blocks to text (name/id/input) so we stop sending structured tool calls.
// - Convert `tool_result` blocks to text so we keep tool results visible without tool semantics.
//
// Use this only when needed: converting tool blocks to text changes model behaviour and can increase the
// risk of prompt injection (tool output becomes plain conversation text).
func
FilterSignatureSensitiveBlocksForRetry
(
body
[]
byte
)
[]
byte
{
// Fast path: only run when we see likely relevant constructs.
if
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type":"thinking"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type": "thinking"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type":"redacted_thinking"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type": "redacted_thinking"`
))
{
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type": "redacted_thinking"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type":"tool_use"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type": "tool_use"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type":"tool_result"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type": "tool_result"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"thinking":`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"thinking" :`
))
{
return
body
}
...
...
@@ -111,15 +266,19 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
return
body
}
// DO NOT modify thinking.type - preserve user's intent to use thinking mode
// The issue is with historical message signatures, not the thinking mode itself
modified
:=
false
// Disable top-level thinking for retry to avoid structural/signature constraints upstream.
if
_
,
exists
:=
req
[
"thinking"
];
exists
{
delete
(
req
,
"thinking"
)
modified
=
true
}
messages
,
ok
:=
req
[
"messages"
]
.
([]
any
)
if
!
ok
{
return
body
}
modified
:=
false
newMessages
:=
make
([]
any
,
0
,
len
(
messages
))
for
_
,
msg
:=
range
messages
{
...
...
@@ -132,7 +291,6 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
role
,
_
:=
msgMap
[
"role"
]
.
(
string
)
content
,
ok
:=
msgMap
[
"content"
]
.
([]
any
)
if
!
ok
{
// String content or other format - keep as is
newMessages
=
append
(
newMessages
,
msg
)
continue
}
...
...
@@ -148,43 +306,96 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
}
blockType
,
_
:=
blockMap
[
"type"
]
.
(
string
)
// Remove thinking/redacted_thinking blocks from historical messages
// These have signatures that may be invalid across different accounts
if
blockType
==
"thinking"
||
blockType
==
"redacted_thinking"
{
switch
blockType
{
case
"thinking"
:
modifiedThisMsg
=
true
thinkingText
,
_
:=
blockMap
[
"thinking"
]
.
(
string
)
if
thinkingText
==
""
{
continue
}
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
thinkingText
})
continue
case
"redacted_thinking"
:
modifiedThisMsg
=
true
continue
case
"tool_use"
:
modifiedThisMsg
=
true
name
,
_
:=
blockMap
[
"name"
]
.
(
string
)
id
,
_
:=
blockMap
[
"id"
]
.
(
string
)
input
:=
blockMap
[
"input"
]
inputJSON
,
_
:=
json
.
Marshal
(
input
)
text
:=
"(tool_use)"
if
name
!=
""
{
text
+=
" name="
+
name
}
if
id
!=
""
{
text
+=
" id="
+
id
}
if
len
(
inputJSON
)
>
0
&&
string
(
inputJSON
)
!=
"null"
{
text
+=
" input="
+
string
(
inputJSON
)
}
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
text
})
continue
case
"tool_result"
:
modifiedThisMsg
=
true
toolUseID
,
_
:=
blockMap
[
"tool_use_id"
]
.
(
string
)
isError
,
_
:=
blockMap
[
"is_error"
]
.
(
bool
)
content
:=
blockMap
[
"content"
]
contentJSON
,
_
:=
json
.
Marshal
(
content
)
text
:=
"(tool_result)"
if
toolUseID
!=
""
{
text
+=
" tool_use_id="
+
toolUseID
}
if
isError
{
text
+=
" is_error=true"
}
if
len
(
contentJSON
)
>
0
&&
string
(
contentJSON
)
!=
"null"
{
text
+=
"
\n
"
+
string
(
contentJSON
)
}
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
text
})
continue
}
if
blockType
==
""
{
if
rawThinking
,
hasThinking
:=
blockMap
[
"thinking"
];
hasThinking
{
modifiedThisMsg
=
true
switch
v
:=
rawThinking
.
(
type
)
{
case
string
:
if
v
!=
""
{
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
v
})
}
default
:
if
b
,
err
:=
json
.
Marshal
(
v
);
err
==
nil
&&
len
(
b
)
>
0
{
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
string
(
b
)})
}
}
continue
}
}
newContent
=
append
(
newContent
,
block
)
}
if
modifiedThisMsg
{
modified
=
true
// Handle empty content after filtering
if
len
(
newContent
)
==
0
{
// For assistant messages, skip entirely (remove from conversation)
// For user messages, add placeholder to avoid empty content error
if
role
==
"user"
{
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"(content removed)"
,
})
msgMap
[
"content"
]
=
newContent
newMessages
=
append
(
newMessages
,
msgMap
)
placeholder
:=
"(content removed)"
if
role
==
"assistant"
{
placeholder
=
"(assistant content removed)"
}
// Skip assistant messages with empty content (don't append)
continue
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
placeholder
})
}
msgMap
[
"content"
]
=
newContent
}
newMessages
=
append
(
newMessages
,
msgMap
)
}
if
modified
{
re
q
[
"messages"
]
=
newMessages
if
!
modified
{
re
turn
body
}
req
[
"messages"
]
=
newMessages
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
...
...
backend/internal/service/gateway_request_test.go
View file @
195e227c
...
...
@@ -151,3 +151,148 @@ func TestFilterThinkingBlocks(t *testing.T) {
})
}
}
// Verifies that retry filtering removes top-level thinking and converts a
// historical thinking block into a plain text block with the same content.
func TestFilterThinkingBlocksForRetry_DisablesThinkingAndPreservesAsText(t *testing.T) {
	input := []byte(`{
"model":"claude-3-5-sonnet-20241022",
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"user","content":[{"type":"text","text":"Hi"}]},
{"role":"assistant","content":[
{"type":"thinking","thinking":"Let me think...","signature":"bad_sig"},
{"type":"text","text":"Answer"}
]}
]
}`)
	out := FilterThinkingBlocksForRetry(input)

	var req map[string]any
	require.NoError(t, json.Unmarshal(out, &req))

	// Top-level thinking must be dropped for the retry.
	_, hasThinking := req["thinking"]
	require.False(t, hasThinking)

	msgs, ok := req["messages"].([]any)
	require.True(t, ok)
	require.Len(t, msgs, 2)

	assistant, ok := msgs[1].(map[string]any)
	require.True(t, ok)
	content, ok := assistant["content"].([]any)
	require.True(t, ok)
	require.Len(t, content, 2)

	// The thinking block becomes a text block carrying the same content.
	first, ok := content[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "text", first["type"])
	require.Equal(t, "Let me think...", first["text"])
}
// Verifies that top-level thinking is disabled on retry even when no message
// actually contains a thinking block (assistant prefill structural constraint).
func TestFilterThinkingBlocksForRetry_DisablesThinkingEvenWithoutThinkingBlocks(t *testing.T) {
	input := []byte(`{
"model":"claude-3-5-sonnet-20241022",
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"user","content":[{"type":"text","text":"Hi"}]},
{"role":"assistant","content":[{"type":"text","text":"Prefill"}]}
]
}`)
	out := FilterThinkingBlocksForRetry(input)
	var req map[string]any
	require.NoError(t, json.Unmarshal(out, &req))
	_, hasThinking := req["thinking"]
	require.False(t, hasThinking)
}
// Verifies that redacted_thinking blocks are dropped (they cannot be converted
// to text) while sibling text blocks survive the retry filtering.
func TestFilterThinkingBlocksForRetry_RemovesRedactedThinkingAndKeepsValidContent(t *testing.T) {
	input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"assistant","content":[
{"type":"redacted_thinking","data":"..."},
{"type":"text","text":"Visible"}
]}
]
}`)
	out := FilterThinkingBlocksForRetry(input)
	var req map[string]any
	require.NoError(t, json.Unmarshal(out, &req))

	_, hasThinking := req["thinking"]
	require.False(t, hasThinking)

	msgs, ok := req["messages"].([]any)
	require.True(t, ok)
	msg0, ok := msgs[0].(map[string]any)
	require.True(t, ok)
	content, ok := msg0["content"].([]any)
	require.True(t, ok)
	// Only the visible text block remains.
	require.Len(t, content, 1)
	content0, ok := content[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "text", content0["type"])
	require.Equal(t, "Visible", content0["text"])
}
// Verifies that a message whose content becomes empty after filtering receives
// a non-empty placeholder text block (upstream rejects empty content arrays).
func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T) {
	input := []byte(`{
"thinking":{"type":"enabled"},
"messages":[
{"role":"assistant","content":[{"type":"redacted_thinking","data":"..."}]}
]
}`)
	out := FilterThinkingBlocksForRetry(input)
	var req map[string]any
	require.NoError(t, json.Unmarshal(out, &req))

	msgs, ok := req["messages"].([]any)
	require.True(t, ok)
	msg0, ok := msgs[0].(map[string]any)
	require.True(t, ok)
	content, ok := msg0["content"].([]any)
	require.True(t, ok)
	require.Len(t, content, 1)
	content0, ok := content[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "text", content0["type"])
	require.NotEmpty(t, content0["text"])
}
// Verifies the stronger retry filter: tool_use and tool_result blocks are
// downgraded to descriptive text blocks and top-level thinking is removed.
func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
	input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"assistant","content":[
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}},
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false}
]}
]
}`)
	out := FilterSignatureSensitiveBlocksForRetry(input)
	var req map[string]any
	require.NoError(t, json.Unmarshal(out, &req))

	_, hasThinking := req["thinking"]
	require.False(t, hasThinking)

	msgs, ok := req["messages"].([]any)
	require.True(t, ok)
	msg0, ok := msgs[0].(map[string]any)
	require.True(t, ok)
	content, ok := msg0["content"].([]any)
	require.True(t, ok)
	require.Len(t, content, 2)

	content0, ok := content[0].(map[string]any)
	require.True(t, ok)
	content1, ok := content[1].(map[string]any)
	require.True(t, ok)
	// Both tool blocks become text blocks that still mention their origin.
	require.Equal(t, "text", content0["type"])
	require.Equal(t, "text", content1["type"])
	require.Contains(t, content0["text"], "tool_use")
	require.Contains(t, content1["text"], "tool_result")
}
backend/internal/service/gateway_service.go
View file @
195e227c
...
...
@@ -15,11 +15,14 @@ import (
"regexp"
"sort"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
...
...
@@ -30,6 +33,7 @@ const (
claudeAPIURL
=
"https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL
=
"https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
defaultMaxLineSize
=
10
*
1024
*
1024
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
)
...
...
@@ -933,8 +937,16 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (s
// 重试相关常量
const
(
maxRetries
=
10
// 最大重试次数
retryDelay
=
3
*
time
.
Second
// 重试等待时间
// 最大尝试次数(包含首次请求)。过多重试会导致请求堆积与资源耗尽。
maxRetryAttempts
=
5
// 指数退避:第 N 次失败后的等待 = retryBaseDelay * 2^(N-1),并且上限为 retryMaxDelay。
retryBaseDelay
=
300
*
time
.
Millisecond
retryMaxDelay
=
3
*
time
.
Second
// 最大重试耗时(包含请求本身耗时 + 退避等待时间)。
// 用于防止极端情况下 goroutine 长时间堆积导致资源耗尽。
maxRetryElapsed
=
10
*
time
.
Second
)
func
(
s
*
GatewayService
)
shouldRetryUpstreamError
(
account
*
Account
,
statusCode
int
)
bool
{
...
...
@@ -957,6 +969,40 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
}
}
func
retryBackoffDelay
(
attempt
int
)
time
.
Duration
{
// attempt 从 1 开始,表示第 attempt 次请求刚失败,需要等待后进行第 attempt+1 次请求。
if
attempt
<=
0
{
return
retryBaseDelay
}
delay
:=
retryBaseDelay
*
time
.
Duration
(
1
<<
(
attempt
-
1
))
if
delay
>
retryMaxDelay
{
return
retryMaxDelay
}
return
delay
}
// sleepWithContext blocks for d or until ctx is cancelled, whichever comes
// first. It returns nil after a full sleep (or immediately when d <= 0), and
// ctx.Err() when the context is cancelled first.
func sleepWithContext(ctx context.Context, d time.Duration) error {
	if d <= 0 {
		return nil
	}
	t := time.NewTimer(d)
	defer func() {
		if !t.Stop() {
			// The timer already fired; drain the channel so nothing lingers.
			select {
			case <-t.C:
			default:
			}
		}
	}()
	select {
	case <-t.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
// 简化判断:User-Agent 匹配 + metadata.user_id 存在
func
isClaudeCodeClient
(
userAgent
string
,
metadataUserID
string
)
bool
{
...
...
@@ -1073,7 +1119,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 重试循环
var
resp
*
http
.
Response
for
attempt
:=
1
;
attempt
<=
maxRetries
;
attempt
++
{
retryStart
:=
time
.
Now
()
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
)
if
err
!=
nil
{
...
...
@@ -1083,6 +1130,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 发送请求
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
if
resp
!=
nil
&&
resp
.
Body
!=
nil
{
_
=
resp
.
Body
.
Close
()
}
return
nil
,
fmt
.
Errorf
(
"upstream request failed: %w"
,
err
)
}
...
...
@@ -1093,28 +1143,80 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_
=
resp
.
Body
.
Close
()
if
s
.
isThinkingBlockSignatureError
(
respBody
)
{
looksLikeToolSignatureError
:=
func
(
msg
string
)
bool
{
m
:=
strings
.
ToLower
(
msg
)
return
strings
.
Contains
(
m
,
"tool_use"
)
||
strings
.
Contains
(
m
,
"tool_result"
)
||
strings
.
Contains
(
m
,
"functioncall"
)
||
strings
.
Contains
(
m
,
"function_call"
)
||
strings
.
Contains
(
m
,
"functionresponse"
)
||
strings
.
Contains
(
m
,
"function_response"
)
}
// 避免在重试预算已耗尽时再发起额外请求
if
time
.
Since
(
retryStart
)
>=
maxRetryElapsed
{
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
break
}
log
.
Printf
(
"Account %d: detected thinking block signature error, retrying with filtered thinking blocks"
,
account
.
ID
)
// 过滤thinking blocks并重试(使用更激进的过滤)
// Conservative two-stage fallback:
// 1) Disable thinking + thinking->text (preserve content)
// 2) Only if upstream still errors AND error message points to tool/function signature issues:
// also downgrade tool_use/tool_result blocks to text.
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
)
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
retryErr
==
nil
{
// 使用重试后的响应,继续后续处理
if
retryResp
.
StatusCode
<
400
{
log
.
Printf
(
"Account %d: signature error retry succeeded"
,
account
.
ID
)
}
else
{
log
.
Printf
(
"Account %d: signature error retry returned status %d"
,
account
.
ID
,
retryResp
.
StatusCode
)
log
.
Printf
(
"Account %d: signature error retry succeeded (thinking downgraded)"
,
account
.
ID
)
resp
=
retryResp
break
}
retryRespBody
,
retryReadErr
:=
io
.
ReadAll
(
io
.
LimitReader
(
retryResp
.
Body
,
2
<<
20
))
_
=
retryResp
.
Body
.
Close
()
if
retryReadErr
==
nil
&&
retryResp
.
StatusCode
==
400
&&
s
.
isThinkingBlockSignatureError
(
retryRespBody
)
{
msg2
:=
extractUpstreamErrorMessage
(
retryRespBody
)
if
looksLikeToolSignatureError
(
msg2
)
&&
time
.
Since
(
retryStart
)
<
maxRetryElapsed
{
log
.
Printf
(
"Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded"
,
account
.
ID
)
filteredBody2
:=
FilterSignatureSensitiveBlocksForRetry
(
body
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
)
if
buildErr2
==
nil
{
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
Do
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
retryErr2
==
nil
{
resp
=
retryResp2
break
}
if
retryResp2
!=
nil
&&
retryResp2
.
Body
!=
nil
{
_
=
retryResp2
.
Body
.
Close
()
}
log
.
Printf
(
"Account %d: tool-downgrade signature retry failed: %v"
,
account
.
ID
,
retryErr2
)
}
else
{
log
.
Printf
(
"Account %d: tool-downgrade signature retry build failed: %v"
,
account
.
ID
,
buildErr2
)
}
}
}
// Fall back to the original retry response context.
resp
=
&
http
.
Response
{
StatusCode
:
retryResp
.
StatusCode
,
Header
:
retryResp
.
Header
.
Clone
(),
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
retryRespBody
)),
}
resp
=
retryResp
break
}
if
retryResp
!=
nil
&&
retryResp
.
Body
!=
nil
{
_
=
retryResp
.
Body
.
Close
()
}
log
.
Printf
(
"Account %d: signature error retry failed: %v"
,
account
.
ID
,
retryErr
)
}
else
{
log
.
Printf
(
"Account %d: signature error retry build request failed: %v"
,
account
.
ID
,
buildErr
)
}
// 重试失败,恢复原始响应体继续处理
// Retry failed: restore original response body and continue handling.
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
break
}
...
...
@@ -1125,11 +1227,27 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了)
if
resp
.
StatusCode
>=
400
&&
resp
.
StatusCode
!=
400
&&
s
.
shouldRetryUpstreamError
(
account
,
resp
.
StatusCode
)
{
if
attempt
<
maxRetries
{
log
.
Printf
(
"Account %d: upstream error %d, retry %d/%d after %v"
,
account
.
ID
,
resp
.
StatusCode
,
attempt
,
maxRetries
,
retryDelay
)
if
attempt
<
maxRetryAttempts
{
elapsed
:=
time
.
Since
(
retryStart
)
if
elapsed
>=
maxRetryElapsed
{
break
}
delay
:=
retryBackoffDelay
(
attempt
)
remaining
:=
maxRetryElapsed
-
elapsed
if
delay
>
remaining
{
delay
=
remaining
}
if
delay
<=
0
{
break
}
log
.
Printf
(
"Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)"
,
account
.
ID
,
resp
.
StatusCode
,
attempt
,
maxRetryAttempts
,
delay
,
elapsed
,
maxRetryElapsed
)
_
=
resp
.
Body
.
Close
()
time
.
Sleep
(
retryDelay
)
if
err
:=
sleepWithContext
(
ctx
,
delay
);
err
!=
nil
{
return
nil
,
err
}
continue
}
// 最后一次尝试也失败,跳出循环处理重试耗尽
...
...
@@ -1146,6 +1264,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
break
}
if
resp
==
nil
||
resp
.
Body
==
nil
{
return
nil
,
errors
.
New
(
"upstream request failed: empty response"
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
// 处理重试耗尽的情况
...
...
@@ -1229,7 +1350,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL
:=
claudeAPIURL
if
account
.
Type
==
AccountTypeAPIKey
{
baseURL
:=
account
.
GetBaseURL
()
targetURL
=
baseURL
+
"/v1/messages"
if
baseURL
!=
""
{
validatedURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
targetURL
=
validatedURL
+
"/v1/messages"
}
}
// OAuth账号:应用统一指纹
...
...
@@ -1537,10 +1664,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
// OAuth/Setup Token 账号的 403:标记账号异常
if
account
.
IsOAuth
()
&&
statusCode
==
403
{
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
statusCode
,
resp
.
Header
,
body
)
log
.
Printf
(
"Account %d: marked as error after %d retries for status %d"
,
account
.
ID
,
maxRetr
ie
s
,
statusCode
)
log
.
Printf
(
"Account %d: marked as error after %d retries for status %d"
,
account
.
ID
,
maxRetr
yAttempt
s
,
statusCode
)
}
else
{
// API Key 未配置错误码:不标记账号状态
log
.
Printf
(
"Account %d: upstream error %d after %d retries (not marking account)"
,
account
.
ID
,
statusCode
,
maxRetr
ie
s
)
log
.
Printf
(
"Account %d: upstream error %d after %d retries (not marking account)"
,
account
.
ID
,
statusCode
,
maxRetr
yAttempt
s
)
}
}
...
...
@@ -1577,6 +1704,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
if
s
.
cfg
!=
nil
{
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
}
// 设置SSE响应头
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
...
...
@@ -1598,51 +1729,133 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
var
firstTokenMs
*
int
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
// 设置更大的buffer以处理长行
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
1024
*
1024
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
type
scanEvent
struct
{
line
string
err
error
}
// 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理
events
:=
make
(
chan
scanEvent
,
16
)
done
:=
make
(
chan
struct
{})
sendEvent
:=
func
(
ev
scanEvent
)
bool
{
select
{
case
events
<-
ev
:
return
true
case
<-
done
:
return
false
}
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
if
!
sendEvent
(
scanEvent
{
line
:
scanner
.
Text
()})
{
return
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
defer
close
(
done
)
needModelReplace
:=
originalModel
!=
mappedModel
streamInterval
:=
time
.
Duration
(
0
)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
>
0
{
streamInterval
=
time
.
Duration
(
s
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
)
*
time
.
Second
}
// 仅监控上游数据间隔超时,避免下游写入阻塞导致误判
var
intervalTicker
*
time
.
Ticker
if
streamInterval
>
0
{
intervalTicker
=
time
.
NewTicker
(
streamInterval
)
defer
intervalTicker
.
Stop
()
}
var
intervalCh
<-
chan
time
.
Time
if
intervalTicker
!=
nil
{
intervalCh
=
intervalTicker
.
C
}
for
scanner
.
Scan
()
{
line
:=
scanner
.
Text
()
if
line
==
"event: error"
{
return
nil
,
errors
.
New
(
"have error in stream"
)
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
return
}
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
flusher
.
Flush
()
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
if
sseDataRe
.
MatchString
(
line
)
{
data
:=
sseDataRe
.
ReplaceAllString
(
line
,
""
)
needModelReplace
:=
originalModel
!=
mappedModel
// 如果有模型映射,替换响应中的model字段
if
needModelReplace
{
line
=
s
.
replaceModelInSSELine
(
line
,
mappedModel
,
originalModel
)
for
{
select
{
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
// 转发行
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
if
ev
.
err
!=
nil
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
ev
.
err
}
sendErrorEvent
(
"stream_read_error"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
line
:=
ev
.
line
if
line
==
"event: error"
{
return
nil
,
errors
.
New
(
"have error in stream"
)
}
flusher
.
Flush
()
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if
firstTokenMs
==
nil
&&
data
!=
""
&&
data
!=
"[DONE]"
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
// Extract data from SSE line (supports both "data: " and "data:" formats)
if
sseDataRe
.
MatchString
(
line
)
{
data
:=
sseDataRe
.
ReplaceAllString
(
line
,
""
)
// 如果有模型映射,替换响应中的model字段
if
needModelReplace
{
line
=
s
.
replaceModelInSSELine
(
line
,
mappedModel
,
originalModel
)
}
// 转发行
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if
firstTokenMs
==
nil
&&
data
!=
""
&&
data
!=
"[DONE]"
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
s
.
parseSSEUsage
(
data
,
usage
)
}
else
{
// 非 data 行直接转发
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
}
s
.
parseSSEUsage
(
data
,
usage
)
}
else
{
// 非 data 行直接转发
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
ni
l
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterva
l
{
continue
}
flusher
.
Flush
()
log
.
Printf
(
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
sendErrorEvent
(
"stream_timeout"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
err
)
}
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
// replaceModelInSSELine 替换SSE数据行中的model字段
...
...
@@ -1747,15 +1960,17 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
}
// 透传响应头
for
key
,
values
:=
range
resp
.
Header
{
for
_
,
value
:=
range
values
{
c
.
Header
(
key
,
value
)
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
contentType
:=
"application/json"
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
Security
.
ResponseHeaders
.
Enabled
{
if
upstreamType
:=
resp
.
Header
.
Get
(
"Content-Type"
);
upstreamType
!=
""
{
contentType
=
upstreamType
}
}
// 写入响应
c
.
Data
(
resp
.
StatusCode
,
"application/json"
,
body
)
c
.
Data
(
resp
.
StatusCode
,
contentType
,
body
)
return
&
response
.
Usage
,
nil
}
...
...
@@ -1989,7 +2204,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
if
resp
.
StatusCode
==
400
&&
s
.
isThinkingBlockSignatureError
(
respBody
)
{
log
.
Printf
(
"Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks"
,
account
.
ID
)
filteredBody
:=
FilterThinkingBlocks
(
body
)
filteredBody
:=
FilterThinkingBlocks
ForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
)
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
...
...
@@ -2045,7 +2260,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
targetURL
:=
claudeAPICountTokensURL
if
account
.
Type
==
AccountTypeAPIKey
{
baseURL
:=
account
.
GetBaseURL
()
targetURL
=
baseURL
+
"/v1/messages/count_tokens"
if
baseURL
!=
""
{
validatedURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
targetURL
=
validatedURL
+
"/v1/messages/count_tokens"
}
}
// OAuth 账号:应用统一指纹和重写 userID
...
...
@@ -2125,6 +2346,25 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
})
}
func
(
s
*
GatewayService
)
validateUpstreamBaseURL
(
raw
string
)
(
string
,
error
)
{
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
{
normalized
,
err
:=
urlvalidator
.
ValidateURLFormat
(
raw
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
return
normalized
,
nil
}
normalized
,
err
:=
urlvalidator
.
ValidateHTTPSURL
(
raw
,
urlvalidator
.
ValidationOptions
{
AllowedHosts
:
s
.
cfg
.
Security
.
URLAllowlist
.
UpstreamHosts
,
RequireAllowlist
:
true
,
AllowPrivate
:
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
,
})
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
return
normalized
,
nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func
(
s
*
GatewayService
)
GetAvailableModels
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
)
[]
string
{
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
195e227c
...
...
@@ -18,9 +18,12 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
...
...
@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
rateLimitService
*
RateLimitService
httpUpstream
HTTPUpstream
antigravityGatewayService
*
AntigravityGatewayService
cfg
*
config
.
Config
}
func
NewGeminiMessagesCompatService
(
...
...
@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
rateLimitService
*
RateLimitService
,
httpUpstream
HTTPUpstream
,
antigravityGatewayService
*
AntigravityGatewayService
,
cfg
*
config
.
Config
,
)
*
GeminiMessagesCompatService
{
return
&
GeminiMessagesCompatService
{
accountRepo
:
accountRepo
,
...
...
@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
rateLimitService
:
rateLimitService
,
httpUpstream
:
httpUpstream
,
antigravityGatewayService
:
antigravityGatewayService
,
cfg
:
cfg
,
}
}
...
...
@@ -230,6 +236,25 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
return
s
.
antigravityGatewayService
}
func
(
s
*
GeminiMessagesCompatService
)
validateUpstreamBaseURL
(
raw
string
)
(
string
,
error
)
{
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
{
normalized
,
err
:=
urlvalidator
.
ValidateURLFormat
(
raw
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
return
normalized
,
nil
}
normalized
,
err
:=
urlvalidator
.
ValidateHTTPSURL
(
raw
,
urlvalidator
.
ValidationOptions
{
AllowedHosts
:
s
.
cfg
.
Security
.
URLAllowlist
.
UpstreamHosts
,
RequireAllowlist
:
true
,
AllowPrivate
:
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
,
})
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
return
normalized
,
nil
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func
(
s
*
GeminiMessagesCompatService
)
HasAntigravityAccounts
(
ctx
context
.
Context
,
groupID
*
int64
)
(
bool
,
error
)
{
var
accounts
[]
Account
...
...
@@ -359,6 +384,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if
err
!=
nil
{
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
err
.
Error
())
}
originalClaudeBody
:=
body
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
...
...
@@ -381,16 +407,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return
nil
,
""
,
errors
.
New
(
"gemini api_key not configured"
)
}
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
action
:=
"generateContent"
if
req
.
Stream
{
action
=
"streamGenerateContent"
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
b
aseURL
,
"/"
),
mappedModel
,
action
)
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
normalizedB
aseURL
,
"/"
),
mappedModel
,
action
)
if
req
.
Stream
{
fullURL
+=
"?alt=sse"
}
...
...
@@ -427,7 +457,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if
projectID
!=
""
{
// Mode 1: Code Assist API
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
geminicli
.
GeminiCliBaseURL
,
action
)
baseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
geminicli
.
GeminiCliBaseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
strings
.
TrimRight
(
baseURL
,
"/"
),
action
)
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
}
...
...
@@ -453,12 +487,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return
upstreamReq
,
"x-request-id"
,
nil
}
else
{
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
baseURL
,
mappedModel
,
action
)
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
normalizedBaseURL
,
"/"
)
,
mappedModel
,
action
)
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
}
...
...
@@ -479,6 +517,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
var
resp
*
http
.
Response
signatureRetryStage
:=
0
for
attempt
:=
1
;
attempt
<=
geminiMaxRetries
;
attempt
++
{
upstreamReq
,
idHeader
,
err
:=
buildReq
(
ctx
)
if
err
!=
nil
{
...
...
@@ -503,6 +542,46 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed after retries: "
+
sanitizeUpstreamErrorMessage
(
err
.
Error
()))
}
// Special-case: signature/thought_signature validation errors are not transient, but may be fixed by
// downgrading Claude thinking/tool history to plain text (conservative two-stage retry).
if
resp
.
StatusCode
==
http
.
StatusBadRequest
&&
signatureRetryStage
<
2
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
if
isGeminiSignatureRelatedError
(
respBody
)
{
var
strippedClaudeBody
[]
byte
stageName
:=
""
switch
signatureRetryStage
{
case
0
:
// Stage 1: disable thinking + thinking->text
strippedClaudeBody
=
FilterThinkingBlocksForRetry
(
originalClaudeBody
)
stageName
=
"thinking-only"
signatureRetryStage
=
1
default
:
// Stage 2: additionally downgrade tool_use/tool_result blocks to text
strippedClaudeBody
=
FilterSignatureSensitiveBlocksForRetry
(
originalClaudeBody
)
stageName
=
"thinking+tools"
signatureRetryStage
=
2
}
retryGeminiReq
,
txErr
:=
convertClaudeMessagesToGeminiGenerateContent
(
strippedClaudeBody
)
if
txErr
==
nil
{
log
.
Printf
(
"Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)"
,
account
.
ID
,
stageName
)
geminiReq
=
retryGeminiReq
// Consume one retry budget attempt and continue with the updated request payload.
sleepGeminiBackoff
(
1
)
continue
}
}
// Restore body for downstream error handling.
resp
=
&
http
.
Response
{
StatusCode
:
http
.
StatusBadRequest
,
Header
:
resp
.
Header
.
Clone
(),
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
break
}
if
resp
.
StatusCode
>=
400
&&
s
.
shouldRetryGeminiUpstreamError
(
account
,
resp
.
StatusCode
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
...
...
@@ -600,6 +679,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
},
nil
}
func
isGeminiSignatureRelatedError
(
respBody
[]
byte
)
bool
{
msg
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
)))
if
msg
==
""
{
msg
=
strings
.
ToLower
(
string
(
respBody
))
}
return
strings
.
Contains
(
msg
,
"thought_signature"
)
||
strings
.
Contains
(
msg
,
"signature"
)
}
func
(
s
*
GeminiMessagesCompatService
)
ForwardNative
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
string
,
action
string
,
stream
bool
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
...
...
@@ -650,12 +737,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
nil
,
""
,
errors
.
New
(
"gemini api_key not configured"
)
}
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
b
aseURL
,
"/"
),
mappedModel
,
upstreamAction
)
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
normalizedB
aseURL
,
"/"
),
mappedModel
,
upstreamAction
)
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
}
...
...
@@ -687,7 +778,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if
projectID
!=
""
&&
!
forceAIStudio
{
// Mode 1: Code Assist API
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
geminicli
.
GeminiCliBaseURL
,
upstreamAction
)
baseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
geminicli
.
GeminiCliBaseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
strings
.
TrimRight
(
baseURL
,
"/"
),
upstreamAction
)
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
}
...
...
@@ -713,12 +808,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
upstreamReq
,
"x-request-id"
,
nil
}
else
{
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
baseURL
,
mappedModel
,
upstreamAction
)
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
normalizedBaseURL
,
"/"
)
,
mappedModel
,
upstreamAction
)
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
}
...
...
@@ -1652,6 +1751,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
_
=
json
.
Unmarshal
(
respBody
,
&
parsed
)
}
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
contentType
:=
resp
.
Header
.
Get
(
"Content-Type"
)
if
contentType
==
""
{
contentType
=
"application/json"
...
...
@@ -1676,6 +1777,10 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
}
log
.
Printf
(
"[GeminiAPI] ===================================================="
)
if
s
.
cfg
!=
nil
{
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
}
c
.
Status
(
resp
.
StatusCode
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
...
...
@@ -1773,11 +1878,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return
nil
,
errors
.
New
(
"invalid path"
)
}
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
fullURL
:=
strings
.
TrimRight
(
baseURL
,
"/"
)
+
path
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
fullURL
:=
strings
.
TrimRight
(
normalizedBaseURL
,
"/"
)
+
path
var
proxyURL
string
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
...
...
@@ -1816,9 +1925,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
8
<<
20
))
wwwAuthenticate
:=
resp
.
Header
.
Get
(
"Www-Authenticate"
)
filteredHeaders
:=
responseheaders
.
FilterHeaders
(
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
if
wwwAuthenticate
!=
""
{
filteredHeaders
.
Set
(
"Www-Authenticate"
,
wwwAuthenticate
)
}
return
&
UpstreamHTTPResult
{
StatusCode
:
resp
.
StatusCode
,
Headers
:
resp
.
Header
.
Clone
()
,
Headers
:
filteredHeaders
,
Body
:
body
,
},
nil
}
...
...
backend/internal/service/gemini_oauth_service.go
View file @
195e227c
...
...
@@ -1000,8 +1000,9 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
req
.
Header
.
Set
(
"User-Agent"
,
geminicli
.
GeminiCLIUserAgent
)
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
ProxyURL
:
strings
.
TrimSpace
(
proxyURL
),
Timeout
:
30
*
time
.
Second
,
ProxyURL
:
strings
.
TrimSpace
(
proxyURL
),
Timeout
:
30
*
time
.
Second
,
ValidateResolvedIP
:
true
,
})
if
err
!=
nil
{
client
=
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
}
...
...
backend/internal/service/openai_gateway_service.go
View file @
195e227c
...
...
@@ -16,9 +16,12 @@ import (
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
...
...
@@ -630,10 +633,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case
AccountTypeAPIKey
:
// API Key accounts use Platform API or custom base URL
baseURL
:=
account
.
GetOpenAIBaseURL
()
if
baseURL
!=
""
{
targetURL
=
baseURL
+
"/responses"
}
else
{
if
baseURL
==
""
{
targetURL
=
openaiPlatformAPIURL
}
else
{
validatedURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
targetURL
=
validatedURL
+
"/responses"
}
default
:
targetURL
=
openaiPlatformAPIURL
...
...
@@ -755,6 +762,10 @@ type openaiStreamingResult struct {
}
func
(
s
*
OpenAIGatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
,
mappedModel
string
)
(
*
openaiStreamingResult
,
error
)
{
if
s
.
cfg
!=
nil
{
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
}
// Set SSE response headers
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
...
...
@@ -775,48 +786,158 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
usage
:=
&
OpenAIUsage
{}
var
firstTokenMs
*
int
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
1024
*
1024
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
type
scanEvent
struct
{
line
string
err
error
}
// 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理
events
:=
make
(
chan
scanEvent
,
16
)
done
:=
make
(
chan
struct
{})
sendEvent
:=
func
(
ev
scanEvent
)
bool
{
select
{
case
events
<-
ev
:
return
true
case
<-
done
:
return
false
}
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
if
!
sendEvent
(
scanEvent
{
line
:
scanner
.
Text
()})
{
return
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
defer
close
(
done
)
streamInterval
:=
time
.
Duration
(
0
)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
>
0
{
streamInterval
=
time
.
Duration
(
s
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
)
*
time
.
Second
}
// 仅监控上游数据间隔超时,不被下游写入阻塞影响
var
intervalTicker
*
time
.
Ticker
if
streamInterval
>
0
{
intervalTicker
=
time
.
NewTicker
(
streamInterval
)
defer
intervalTicker
.
Stop
()
}
var
intervalCh
<-
chan
time
.
Time
if
intervalTicker
!=
nil
{
intervalCh
=
intervalTicker
.
C
}
keepaliveInterval
:=
time
.
Duration
(
0
)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
StreamKeepaliveInterval
>
0
{
keepaliveInterval
=
time
.
Duration
(
s
.
cfg
.
Gateway
.
StreamKeepaliveInterval
)
*
time
.
Second
}
// 下游 keepalive 仅用于防止代理空闲断开
var
keepaliveTicker
*
time
.
Ticker
if
keepaliveInterval
>
0
{
keepaliveTicker
=
time
.
NewTicker
(
keepaliveInterval
)
defer
keepaliveTicker
.
Stop
()
}
var
keepaliveCh
<-
chan
time
.
Time
if
keepaliveTicker
!=
nil
{
keepaliveCh
=
keepaliveTicker
.
C
}
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt
:=
time
.
Now
()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
return
}
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
flusher
.
Flush
()
}
needModelReplace
:=
originalModel
!=
mappedModel
for
scanner
.
Scan
()
{
line
:=
scanner
.
Text
()
for
{
select
{
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
if
ev
.
err
!=
nil
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
ev
.
err
}
sendErrorEvent
(
"stream_read_error"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
if
openaiSSEDataRe
.
MatchString
(
line
)
{
data
:=
openaiSSEDataRe
.
ReplaceAllString
(
line
,
""
)
line
:=
ev
.
line
lastDataAt
=
time
.
Now
()
// Extract data from SSE line (supports both "data: " and "data:" formats)
if
openaiSSEDataRe
.
MatchString
(
line
)
{
data
:=
openaiSSEDataRe
.
ReplaceAllString
(
line
,
""
)
// Replace model in response if needed
if
needModelReplace
{
line
=
s
.
replaceModelInSSELine
(
line
,
mappedModel
,
originalModel
)
}
// Replace model in response if needed
if
needModelReplace
{
line
=
s
.
replaceModelInSSELine
(
line
,
mappedModel
,
originalModel
)
// Forward line
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
// Record first token time
if
firstTokenMs
==
nil
&&
data
!=
""
&&
data
!=
"[DONE]"
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
}
s
.
parseSSEUsage
(
data
,
usage
)
}
else
{
// Forward non-data lines as-is
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
}
// Forward line
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
}
flusher
.
Flush
()
log
.
Printf
(
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
sendErrorEvent
(
"stream_timeout"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
// Record first token time
if
firstTokenMs
==
nil
&&
data
!=
""
&&
data
!=
"[DONE]"
{
ms
:=
int
(
time
.
Since
(
startTime
)
.
Milliseconds
())
firstTokenMs
=
&
ms
case
<-
keepaliveCh
:
if
time
.
Since
(
lastDataAt
)
<
keepaliveInterval
{
continue
}
s
.
parseSSEUsage
(
data
,
usage
)
}
else
{
// Forward non-data lines as-is
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
if
_
,
err
:=
fmt
.
Fprint
(
w
,
":
\n\n
"
);
err
!=
nil
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
err
)
}
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
func
(
s
*
OpenAIGatewayService
)
replaceModelInSSELine
(
line
,
fromModel
,
toModel
string
)
string
{
...
...
@@ -911,18 +1032,39 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
}
// Pass through headers
for
key
,
values
:=
range
resp
.
Header
{
for
_
,
value
:=
range
values
{
c
.
Header
(
key
,
value
)
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
contentType
:=
"application/json"
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
Security
.
ResponseHeaders
.
Enabled
{
if
upstreamType
:=
resp
.
Header
.
Get
(
"Content-Type"
);
upstreamType
!=
""
{
contentType
=
upstreamType
}
}
c
.
Data
(
resp
.
StatusCode
,
"application/json"
,
body
)
c
.
Data
(
resp
.
StatusCode
,
contentType
,
body
)
return
usage
,
nil
}
func
(
s
*
OpenAIGatewayService
)
validateUpstreamBaseURL
(
raw
string
)
(
string
,
error
)
{
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
{
normalized
,
err
:=
urlvalidator
.
ValidateURLFormat
(
raw
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
return
normalized
,
nil
}
normalized
,
err
:=
urlvalidator
.
ValidateHTTPSURL
(
raw
,
urlvalidator
.
ValidationOptions
{
AllowedHosts
:
s
.
cfg
.
Security
.
URLAllowlist
.
UpstreamHosts
,
RequireAllowlist
:
true
,
AllowPrivate
:
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
,
})
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
return
normalized
,
nil
}
func
(
s
*
OpenAIGatewayService
)
replaceModelInResponseBody
(
body
[]
byte
,
fromModel
,
toModel
string
)
[]
byte
{
var
resp
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
...
...
backend/internal/service/openai_gateway_service_test.go
0 → 100644
View file @
195e227c
package
service
import
(
"bufio"
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
// TestOpenAIStreamingTimeout verifies that a 1-second StreamDataIntervalTimeout
// with an upstream that never sends data makes handleStreamingResponse return a
// "stream data interval timeout" error and emit a stream_timeout SSE error event.
func TestOpenAIStreamingTimeout(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 1,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	// Pipe with no writer activity simulates a stalled upstream body.
	pr, pw := io.Pipe()
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       pr,
		Header:     http.Header{},
	}
	start := time.Now()
	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, start, "model", "model")
	_ = pw.Close()
	_ = pr.Close()
	if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
		t.Fatalf("expected stream timeout error, got %v", err)
	}
	if !strings.Contains(rec.Body.String(), "stream_timeout") {
		t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
	}
}
// TestOpenAIStreamingTooLong verifies that a single SSE line exceeding
// MaxLineSize surfaces bufio.ErrTooLong and emits a response_too_large event.
func TestOpenAIStreamingTooLong(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               64 * 1024,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	pr, pw := io.Pipe()
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       pr,
		Header:     http.Header{},
	}
	go func() {
		defer func() { _ = pw.Close() }()
		// Write a single line larger than MaxLineSize to trigger ErrTooLong.
		payload := "data: " + strings.Repeat("a", 128*1024) + "\n"
		_, _ = pw.Write([]byte(payload))
	}()
	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model")
	_ = pr.Close()
	if !errors.Is(err, bufio.ErrTooLong) {
		t.Fatalf("expected ErrTooLong, got %v", err)
	}
	if !strings.Contains(rec.Body.String(), "response_too_large") {
		t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
	}
}
// TestOpenAINonStreamingContentTypePassThrough verifies that when response
// header filtering is disabled, the upstream Content-Type is forwarded as-is.
func TestOpenAINonStreamingContentTypePassThrough(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Security: config.SecurityConfig{
			ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewReader(body)),
		Header:     http.Header{"Content-Type": []string{"application/vnd.test+json"}},
	}
	_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
	if err != nil {
		t.Fatalf("handleNonStreamingResponse error: %v", err)
	}
	if !strings.Contains(rec.Header().Get("Content-Type"), "application/vnd.test+json") {
		t.Fatalf("expected Content-Type passthrough, got %q", rec.Header().Get("Content-Type"))
	}
}
// TestOpenAINonStreamingContentTypeDefault verifies that when the upstream
// sends no Content-Type, the response falls back to application/json.
func TestOpenAINonStreamingContentTypeDefault(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Security: config.SecurityConfig{
			ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewReader(body)),
		Header:     http.Header{},
	}
	_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
	if err != nil {
		t.Fatalf("handleNonStreamingResponse error: %v", err)
	}
	if !strings.Contains(rec.Header().Get("Content-Type"), "application/json") {
		t.Fatalf("expected default Content-Type, got %q", rec.Header().Get("Content-Type"))
	}
}
// TestOpenAIStreamingHeadersOverride verifies that for SSE responses the
// gateway forces Cache-Control and Content-Type to streaming-safe values while
// still passing through benign upstream headers such as X-Request-Id.
func TestOpenAIStreamingHeadersOverride(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Security: config.SecurityConfig{
			ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
		},
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	pr, pw := io.Pipe()
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       pr,
		Header: http.Header{
			"Cache-Control": []string{"upstream"},
			"X-Request-Id":  []string{"req-123"},
			"Content-Type":  []string{"application/custom"},
		},
	}
	go func() {
		defer func() { _ = pw.Close() }()
		_, _ = pw.Write([]byte("data: {}\n\n"))
	}()
	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
	_ = pr.Close()
	if err != nil {
		t.Fatalf("handleStreamingResponse error: %v", err)
	}
	if rec.Header().Get("Cache-Control") != "no-cache" {
		t.Fatalf("expected Cache-Control override, got %q", rec.Header().Get("Cache-Control"))
	}
	if rec.Header().Get("Content-Type") != "text/event-stream" {
		t.Fatalf("expected Content-Type override, got %q", rec.Header().Get("Content-Type"))
	}
	if rec.Header().Get("X-Request-Id") != "req-123" {
		t.Fatalf("expected X-Request-Id passthrough, got %q", rec.Header().Get("X-Request-Id"))
	}
}
// TestOpenAIInvalidBaseURLWhenAllowlistDisabled verifies that a malformed
// account base_url is still rejected even when the URL allowlist is disabled.
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Security: config.SecurityConfig{
			URLAllowlist: config.URLAllowlistConfig{Enabled: false},
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	account := &Account{
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		// "://invalid-url" has no scheme and must fail format validation.
		Credentials: map[string]any{"base_url": "://invalid-url"},
	}
	_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
	if err == nil {
		t.Fatalf("expected error for invalid base_url when allowlist disabled")
	}
}
// TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS verifies that with the
// allowlist disabled and allow_insecure_http unset, http URLs are rejected
// while https URLs pass through unchanged.
func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS(t *testing.T) {
	cfg := &config.Config{
		Security: config.SecurityConfig{
			URLAllowlist: config.URLAllowlistConfig{Enabled: false},
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	if _, err := svc.validateUpstreamBaseURL("http://not-https.example.com"); err == nil {
		t.Fatalf("expected http to be rejected when allow_insecure_http is false")
	}
	normalized, err := svc.validateUpstreamBaseURL("https://example.com")
	if err != nil {
		t.Fatalf("expected https to be allowed when allowlist disabled, got %v", err)
	}
	if normalized != "https://example.com" {
		t.Fatalf("expected raw url passthrough, got %q", normalized)
	}
}
// TestOpenAIValidateUpstreamBaseURLDisabledAllowsHTTP verifies that setting
// allow_insecure_http permits plain http upstream URLs when the allowlist is
// disabled.
func TestOpenAIValidateUpstreamBaseURLDisabledAllowsHTTP(t *testing.T) {
	cfg := &config.Config{
		Security: config.SecurityConfig{
			URLAllowlist: config.URLAllowlistConfig{
				Enabled:           false,
				AllowInsecureHTTP: true,
			},
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	normalized, err := svc.validateUpstreamBaseURL("http://not-https.example.com")
	if err != nil {
		t.Fatalf("expected http allowed when allow_insecure_http is true, got %v", err)
	}
	if normalized != "http://not-https.example.com" {
		t.Fatalf("expected raw url passthrough, got %q", normalized)
	}
}
// TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist verifies that with
// the allowlist enabled, only hosts listed in UpstreamHosts are accepted.
func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
	cfg := &config.Config{
		Security: config.SecurityConfig{
			URLAllowlist: config.URLAllowlistConfig{
				Enabled:       true,
				UpstreamHosts: []string{"example.com"},
			},
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	if _, err := svc.validateUpstreamBaseURL("https://example.com"); err != nil {
		t.Fatalf("expected allowlisted host to pass, got %v", err)
	}
	if _, err := svc.validateUpstreamBaseURL("https://evil.com"); err == nil {
		t.Fatalf("expected non-allowlisted host to fail")
	}
}
backend/internal/service/pricing_service.go
View file @
195e227c
...
...
@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
var
(
...
...
@@ -213,16 +214,35 @@ func (s *PricingService) syncWithRemote() error {
// downloadPricingData 从远程下载价格数据
func
(
s
*
PricingService
)
downloadPricingData
()
error
{
log
.
Printf
(
"[Pricing] Downloading from %s"
,
s
.
cfg
.
Pricing
.
RemoteURL
)
remoteURL
,
err
:=
s
.
validatePricingURL
(
s
.
cfg
.
Pricing
.
RemoteURL
)
if
err
!=
nil
{
return
err
}
log
.
Printf
(
"[Pricing] Downloading from %s"
,
remoteURL
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
cancel
()
body
,
err
:=
s
.
remoteClient
.
FetchPricingJSON
(
ctx
,
s
.
cfg
.
Pricing
.
RemoteURL
)
var
expectedHash
string
if
strings
.
TrimSpace
(
s
.
cfg
.
Pricing
.
HashURL
)
!=
""
{
expectedHash
,
err
=
s
.
fetchRemoteHash
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"fetch remote hash: %w"
,
err
)
}
}
body
,
err
:=
s
.
remoteClient
.
FetchPricingJSON
(
ctx
,
remoteURL
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"download failed: %w"
,
err
)
}
if
expectedHash
!=
""
{
actualHash
:=
sha256
.
Sum256
(
body
)
if
!
strings
.
EqualFold
(
expectedHash
,
hex
.
EncodeToString
(
actualHash
[
:
]))
{
return
fmt
.
Errorf
(
"pricing hash mismatch"
)
}
}
// 解析JSON数据(使用灵活的解析方式)
data
,
err
:=
s
.
parsePricingData
(
body
)
if
err
!=
nil
{
...
...
@@ -378,10 +398,38 @@ func (s *PricingService) useFallbackPricing() error {
// fetchRemoteHash 从远程获取哈希值
func
(
s
*
PricingService
)
fetchRemoteHash
()
(
string
,
error
)
{
hashURL
,
err
:=
s
.
validatePricingURL
(
s
.
cfg
.
Pricing
.
HashURL
)
if
err
!=
nil
{
return
""
,
err
}
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
return
s
.
remoteClient
.
FetchHashText
(
ctx
,
s
.
cfg
.
Pricing
.
HashURL
)
hash
,
err
:=
s
.
remoteClient
.
FetchHashText
(
ctx
,
hashURL
)
if
err
!=
nil
{
return
""
,
err
}
return
strings
.
TrimSpace
(
hash
),
nil
}
func
(
s
*
PricingService
)
validatePricingURL
(
raw
string
)
(
string
,
error
)
{
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
{
normalized
,
err
:=
urlvalidator
.
ValidateURLFormat
(
raw
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid pricing url: %w"
,
err
)
}
return
normalized
,
nil
}
normalized
,
err
:=
urlvalidator
.
ValidateHTTPSURL
(
raw
,
urlvalidator
.
ValidationOptions
{
AllowedHosts
:
s
.
cfg
.
Security
.
URLAllowlist
.
PricingHosts
,
RequireAllowlist
:
true
,
AllowPrivate
:
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
,
})
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid pricing url: %w"
,
err
)
}
return
normalized
,
nil
}
// computeFileHash 计算文件哈希
...
...
backend/internal/service/setting_service.go
View file @
195e227c
...
...
@@ -130,6 +130,10 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeyFallbackModelGemini
]
=
settings
.
FallbackModelGemini
updates
[
SettingKeyFallbackModelAntigravity
]
=
settings
.
FallbackModelAntigravity
// Identity patch configuration (Claude -> Gemini)
updates
[
SettingKeyEnableIdentityPatch
]
=
strconv
.
FormatBool
(
settings
.
EnableIdentityPatch
)
updates
[
SettingKeyIdentityPatchPrompt
]
=
settings
.
IdentityPatchPrompt
return
s
.
settingRepo
.
SetMultiple
(
ctx
,
updates
)
}
...
...
@@ -213,6 +217,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyFallbackModelOpenAI
:
"gpt-4o"
,
SettingKeyFallbackModelGemini
:
"gemini-2.5-pro"
,
SettingKeyFallbackModelAntigravity
:
"gemini-2.5-pro"
,
// Identity patch defaults
SettingKeyEnableIdentityPatch
:
"true"
,
SettingKeyIdentityPatchPrompt
:
""
,
}
return
s
.
settingRepo
.
SetMultiple
(
ctx
,
defaults
)
...
...
@@ -221,21 +228,23 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体
func
(
s
*
SettingService
)
parseSettings
(
settings
map
[
string
]
string
)
*
SystemSettings
{
result
:=
&
SystemSettings
{
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
,
SMTPHost
:
settings
[
SettingKeySMTPHost
],
SMTPUsername
:
settings
[
SettingKeySMTPUsername
],
SMTPFrom
:
settings
[
SettingKeySMTPFrom
],
SMTPFromName
:
settings
[
SettingKeySMTPFromName
],
SMTPUseTLS
:
settings
[
SettingKeySMTPUseTLS
]
==
"true"
,
TurnstileEnabled
:
settings
[
SettingKeyTurnstileEnabled
]
==
"true"
,
TurnstileSiteKey
:
settings
[
SettingKeyTurnstileSiteKey
],
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
SiteLogo
:
settings
[
SettingKeySiteLogo
],
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
APIBaseURL
:
settings
[
SettingKeyAPIBaseURL
],
ContactInfo
:
settings
[
SettingKeyContactInfo
],
DocURL
:
settings
[
SettingKeyDocURL
],
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
settings
[
SettingKeyEmailVerifyEnabled
]
==
"true"
,
SMTPHost
:
settings
[
SettingKeySMTPHost
],
SMTPUsername
:
settings
[
SettingKeySMTPUsername
],
SMTPFrom
:
settings
[
SettingKeySMTPFrom
],
SMTPFromName
:
settings
[
SettingKeySMTPFromName
],
SMTPUseTLS
:
settings
[
SettingKeySMTPUseTLS
]
==
"true"
,
SMTPPasswordConfigured
:
settings
[
SettingKeySMTPPassword
]
!=
""
,
TurnstileEnabled
:
settings
[
SettingKeyTurnstileEnabled
]
==
"true"
,
TurnstileSiteKey
:
settings
[
SettingKeyTurnstileSiteKey
],
TurnstileSecretKeyConfigured
:
settings
[
SettingKeyTurnstileSecretKey
]
!=
""
,
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
SiteLogo
:
settings
[
SettingKeySiteLogo
],
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
APIBaseURL
:
settings
[
SettingKeyAPIBaseURL
],
ContactInfo
:
settings
[
SettingKeyContactInfo
],
DocURL
:
settings
[
SettingKeyDocURL
],
}
// 解析整数类型
...
...
@@ -269,6 +278,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result
.
FallbackModelGemini
=
s
.
getStringOrDefault
(
settings
,
SettingKeyFallbackModelGemini
,
"gemini-2.5-pro"
)
result
.
FallbackModelAntigravity
=
s
.
getStringOrDefault
(
settings
,
SettingKeyFallbackModelAntigravity
,
"gemini-2.5-pro"
)
// Identity patch settings (default: enabled, to preserve existing behavior)
if
v
,
ok
:=
settings
[
SettingKeyEnableIdentityPatch
];
ok
&&
v
!=
""
{
result
.
EnableIdentityPatch
=
v
==
"true"
}
else
{
result
.
EnableIdentityPatch
=
true
}
result
.
IdentityPatchPrompt
=
settings
[
SettingKeyIdentityPatchPrompt
]
return
result
}
...
...
@@ -298,6 +315,25 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return
value
}
// IsIdentityPatchEnabled 检查是否启用身份补丁(Claude -> Gemini systemInstruction 注入)
func
(
s
*
SettingService
)
IsIdentityPatchEnabled
(
ctx
context
.
Context
)
bool
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyEnableIdentityPatch
)
if
err
!=
nil
{
// 默认开启,保持兼容
return
true
}
return
value
==
"true"
}
// GetIdentityPatchPrompt 获取自定义身份补丁提示词(为空表示使用内置默认模板)
func
(
s
*
SettingService
)
GetIdentityPatchPrompt
(
ctx
context
.
Context
)
string
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyIdentityPatchPrompt
)
if
err
!=
nil
{
return
""
}
return
value
}
// GenerateAdminAPIKey 生成新的管理员 API Key
func
(
s
*
SettingService
)
GenerateAdminAPIKey
(
ctx
context
.
Context
)
(
string
,
error
)
{
// 生成 32 字节随机数 = 64 位十六进制字符
...
...
backend/internal/service/settings_view.go
View file @
195e227c
...
...
@@ -4,17 +4,19 @@ type SystemSettings struct {
RegistrationEnabled
bool
EmailVerifyEnabled
bool
SMTPHost
string
SMTPPort
int
SMTPUsername
string
SMTPPassword
string
SMTPFrom
string
SMTPFromName
string
SMTPUseTLS
bool
TurnstileEnabled
bool
TurnstileSiteKey
string
TurnstileSecretKey
string
SMTPHost
string
SMTPPort
int
SMTPUsername
string
SMTPPassword
string
SMTPPasswordConfigured
bool
SMTPFrom
string
SMTPFromName
string
SMTPUseTLS
bool
TurnstileEnabled
bool
TurnstileSiteKey
string
TurnstileSecretKey
string
TurnstileSecretKeyConfigured
bool
SiteName
string
SiteLogo
string
...
...
@@ -32,6 +34,10 @@ type SystemSettings struct {
FallbackModelOpenAI
string
`json:"fallback_model_openai"`
FallbackModelGemini
string
`json:"fallback_model_gemini"`
FallbackModelAntigravity
string
`json:"fallback_model_antigravity"`
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch
bool
`json:"enable_identity_patch"`
IdentityPatchPrompt
string
`json:"identity_patch_prompt"`
}
type
PublicSettings
struct
{
...
...
backend/internal/setup/setup.go
View file @
195e227c
...
...
@@ -21,10 +21,44 @@ import (
// Config paths
const
(
ConfigFile
=
"config.yaml"
Env
File
=
".
env
"
ConfigFile
Name
=
"config.yaml"
InstallLock
File
=
".
installed
"
)
// GetDataDir returns the directory used for storing config and lock files.
// Resolution order: the DATA_DIR environment variable, then /app/data when it
// exists and is writable (the Docker volume case), then the current directory.
func GetDataDir() string {
	if env := os.Getenv("DATA_DIR"); env != "" {
		return env
	}
	const dockerDataDir = "/app/data"
	info, err := os.Stat(dockerDataDir)
	if err != nil || !info.IsDir() {
		return "."
	}
	// Probe writability by creating and immediately removing a marker file.
	probe := dockerDataDir + "/.write_test"
	f, err := os.Create(probe)
	if err != nil {
		return "."
	}
	_ = f.Close()
	_ = os.Remove(probe)
	return dockerDataDir
}
// GetConfigFilePath returns the full path to config.yaml
func
GetConfigFilePath
()
string
{
return
GetDataDir
()
+
"/"
+
ConfigFileName
}
// GetInstallLockPath returns the full path to .installed lock file
func
GetInstallLockPath
()
string
{
return
GetDataDir
()
+
"/"
+
InstallLockFile
}
// SetupConfig holds the setup configuration
type
SetupConfig
struct
{
Database
DatabaseConfig
`json:"database" yaml:"database"`
...
...
@@ -71,13 +105,12 @@ type JWTConfig struct {
// Uses multiple checks to prevent attackers from forcing re-setup by deleting config
func
NeedsSetup
()
bool
{
// Check 1: Config file must not exist
if
_
,
err
:=
os
.
Stat
(
ConfigFile
);
!
os
.
IsNotExist
(
err
)
{
if
_
,
err
:=
os
.
Stat
(
Get
ConfigFile
Path
()
);
!
os
.
IsNotExist
(
err
)
{
return
false
// Config exists, no setup needed
}
// Check 2: Installation lock file (harder to bypass)
lockFile
:=
".installed"
if
_
,
err
:=
os
.
Stat
(
lockFile
);
!
os
.
IsNotExist
(
err
)
{
if
_
,
err
:=
os
.
Stat
(
GetInstallLockPath
());
!
os
.
IsNotExist
(
err
)
{
return
false
// Lock file exists, already installed
}
...
...
@@ -201,6 +234,7 @@ func Install(cfg *SetupConfig) error {
return
fmt
.
Errorf
(
"failed to generate jwt secret: %w"
,
err
)
}
cfg
.
JWT
.
Secret
=
secret
log
.
Println
(
"Warning: JWT secret auto-generated. Consider setting a fixed secret for production."
)
}
// Test connections
...
...
@@ -237,9 +271,8 @@ func Install(cfg *SetupConfig) error {
// createInstallLock creates a lock file to prevent re-installation attacks
func
createInstallLock
()
error
{
lockFile
:=
".installed"
content
:=
fmt
.
Sprintf
(
"installed_at=%s
\n
"
,
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
))
return
os
.
WriteFile
(
lockFile
,
[]
byte
(
content
),
0400
)
// Read-only for owner
return
os
.
WriteFile
(
GetInstallLockPath
()
,
[]
byte
(
content
),
0400
)
// Read-only for owner
}
func
initializeDatabase
(
cfg
*
SetupConfig
)
error
{
...
...
@@ -390,7 +423,7 @@ func writeConfigFile(cfg *SetupConfig) error {
return
err
}
return
os
.
WriteFile
(
ConfigFile
,
data
,
0600
)
return
os
.
WriteFile
(
Get
ConfigFile
Path
()
,
data
,
0600
)
}
func
generateSecret
(
length
int
)
(
string
,
error
)
{
...
...
@@ -433,6 +466,7 @@ func getEnvIntOrDefault(key string, defaultValue int) int {
// This is designed for Docker deployment where all config is passed via env vars
func
AutoSetupFromEnv
()
error
{
log
.
Println
(
"Auto setup enabled, configuring from environment variables..."
)
log
.
Printf
(
"Data directory: %s"
,
GetDataDir
())
// Get timezone from TZ or TIMEZONE env var (TZ is standard for Docker)
tz
:=
getEnvOrDefault
(
"TZ"
,
""
)
...
...
@@ -479,7 +513,7 @@ func AutoSetupFromEnv() error {
return
fmt
.
Errorf
(
"failed to generate jwt secret: %w"
,
err
)
}
cfg
.
JWT
.
Secret
=
secret
log
.
Println
(
"
Generated
JWT secret auto
matically
"
)
log
.
Println
(
"
Warning:
JWT secret auto
-generated. Consider setting a fixed secret for production.
"
)
}
// Generate admin password if not provided
...
...
@@ -489,8 +523,8 @@ func AutoSetupFromEnv() error {
return
fmt
.
Errorf
(
"failed to generate admin password: %w"
,
err
)
}
cfg
.
Admin
.
Password
=
password
log
.
Printf
(
"Generated admin password: %s"
,
cfg
.
Admin
.
Password
)
log
.
Println
(
"IMPORTANT: Save this password! It will not be shown again."
)
fmt
.
Printf
(
"Generated admin password
(one-time)
: %s
\n
"
,
cfg
.
Admin
.
Password
)
fmt
.
Println
(
"IMPORTANT: Save this password! It will not be shown again."
)
}
// Test database connection
...
...
backend/internal/util/logredact/redact.go
0 → 100644
View file @
195e227c
package
logredact
import
(
"encoding/json"
"strings"
)
// maxRedactDepth caps recursion depth to guard against stack overflow on
// deeply nested (possibly adversarial) JSON structures.
const maxRedactDepth = 32

// defaultSensitiveKeys lists the lowercase key names whose values are always
// masked, in addition to any caller-supplied extra keys.
var defaultSensitiveKeys = map[string]struct{}{
	"authorization_code": {},
	"code":               {},
	"code_verifier":      {},
	"access_token":       {},
	"refresh_token":      {},
	"id_token":           {},
	"client_secret":      {},
	"password":           {},
}
func
RedactMap
(
input
map
[
string
]
any
,
extraKeys
...
string
)
map
[
string
]
any
{
if
input
==
nil
{
return
map
[
string
]
any
{}
}
keys
:=
buildKeySet
(
extraKeys
)
redacted
,
ok
:=
redactValueWithDepth
(
input
,
keys
,
0
)
.
(
map
[
string
]
any
)
if
!
ok
{
return
map
[
string
]
any
{}
}
return
redacted
}
func
RedactJSON
(
raw
[]
byte
,
extraKeys
...
string
)
string
{
if
len
(
raw
)
==
0
{
return
""
}
var
value
any
if
err
:=
json
.
Unmarshal
(
raw
,
&
value
);
err
!=
nil
{
return
"<non-json payload redacted>"
}
keys
:=
buildKeySet
(
extraKeys
)
redacted
:=
redactValueWithDepth
(
value
,
keys
,
0
)
encoded
,
err
:=
json
.
Marshal
(
redacted
)
if
err
!=
nil
{
return
"<redacted>"
}
return
string
(
encoded
)
}
func
buildKeySet
(
extraKeys
[]
string
)
map
[
string
]
struct
{}
{
keys
:=
make
(
map
[
string
]
struct
{},
len
(
defaultSensitiveKeys
)
+
len
(
extraKeys
))
for
k
:=
range
defaultSensitiveKeys
{
keys
[
k
]
=
struct
{}{}
}
for
_
,
key
:=
range
extraKeys
{
normalized
:=
normalizeKey
(
key
)
if
normalized
==
""
{
continue
}
keys
[
normalized
]
=
struct
{}{}
}
return
keys
}
func
redactValueWithDepth
(
value
any
,
keys
map
[
string
]
struct
{},
depth
int
)
any
{
if
depth
>
maxRedactDepth
{
return
"<depth limit exceeded>"
}
switch
v
:=
value
.
(
type
)
{
case
map
[
string
]
any
:
out
:=
make
(
map
[
string
]
any
,
len
(
v
))
for
k
,
val
:=
range
v
{
if
isSensitiveKey
(
k
,
keys
)
{
out
[
k
]
=
"***"
continue
}
out
[
k
]
=
redactValueWithDepth
(
val
,
keys
,
depth
+
1
)
}
return
out
case
[]
any
:
out
:=
make
([]
any
,
len
(
v
))
for
i
,
item
:=
range
v
{
out
[
i
]
=
redactValueWithDepth
(
item
,
keys
,
depth
+
1
)
}
return
out
default
:
return
value
}
}
func
isSensitiveKey
(
key
string
,
keys
map
[
string
]
struct
{})
bool
{
_
,
ok
:=
keys
[
normalizeKey
(
key
)]
return
ok
}
// normalizeKey strips surrounding whitespace and lowercases a key so that
// lookups in the sensitive-key set are case- and padding-insensitive.
func normalizeKey(key string) string {
	trimmed := strings.TrimSpace(key)
	return strings.ToLower(trimmed)
}
backend/internal/util/responseheaders/responseheaders.go
0 → 100644
View file @
195e227c
package
responseheaders
import
(
"net/http"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// defaultAllowed is the whitelist of upstream response headers that may be
// passed through to clients.
// Note: the following headers are managed by Go's HTTP stack and must not be
// set manually:
//   - content-length: set by the ResponseWriter from the bytes actually written
//   - transfer-encoding: added/removed by the HTTP library as needed
//   - connection: managed by the HTTP library for connection reuse
var defaultAllowed = map[string]struct{}{
	"content-type":                   {},
	"content-encoding":               {},
	"content-language":               {},
	"cache-control":                  {},
	"etag":                           {},
	"last-modified":                  {},
	"expires":                        {},
	"vary":                           {},
	"date":                           {},
	"x-request-id":                   {},
	"x-ratelimit-limit-requests":     {},
	"x-ratelimit-limit-tokens":       {},
	"x-ratelimit-remaining-requests": {},
	"x-ratelimit-remaining-tokens":   {},
	"x-ratelimit-reset-requests":     {},
	"x-ratelimit-reset-tokens":       {},
	"retry-after":                    {},
	"location":                       {},
	"www-authenticate":               {},
}

// hopByHopHeaders are hop-by-hop headers that are always skipped because the
// HTTP library handles them automatically.
var hopByHopHeaders = map[string]struct{}{
	"content-length":    {},
	"transfer-encoding": {},
	"connection":        {},
}
func
FilterHeaders
(
src
http
.
Header
,
cfg
config
.
ResponseHeaderConfig
)
http
.
Header
{
allowed
:=
make
(
map
[
string
]
struct
{},
len
(
defaultAllowed
)
+
len
(
cfg
.
AdditionalAllowed
))
for
key
:=
range
defaultAllowed
{
allowed
[
key
]
=
struct
{}{}
}
// 关闭时只使用默认白名单,additional/force_remove 不生效
if
cfg
.
Enabled
{
for
_
,
key
:=
range
cfg
.
AdditionalAllowed
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
key
))
if
normalized
==
""
{
continue
}
allowed
[
normalized
]
=
struct
{}{}
}
}
forceRemove
:=
map
[
string
]
struct
{}{}
if
cfg
.
Enabled
{
forceRemove
=
make
(
map
[
string
]
struct
{},
len
(
cfg
.
ForceRemove
))
for
_
,
key
:=
range
cfg
.
ForceRemove
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
key
))
if
normalized
==
""
{
continue
}
forceRemove
[
normalized
]
=
struct
{}{}
}
}
filtered
:=
make
(
http
.
Header
,
len
(
src
))
for
key
,
values
:=
range
src
{
lower
:=
strings
.
ToLower
(
key
)
if
_
,
blocked
:=
forceRemove
[
lower
];
blocked
{
continue
}
if
_
,
ok
:=
allowed
[
lower
];
!
ok
{
continue
}
// 跳过 hop-by-hop 头部,这些由 HTTP 库自动处理
if
_
,
isHopByHop
:=
hopByHopHeaders
[
lower
];
isHopByHop
{
continue
}
for
_
,
value
:=
range
values
{
filtered
.
Add
(
key
,
value
)
}
}
return
filtered
}
func
WriteFilteredHeaders
(
dst
http
.
Header
,
src
http
.
Header
,
cfg
config
.
ResponseHeaderConfig
)
{
filtered
:=
FilterHeaders
(
src
,
cfg
)
for
key
,
values
:=
range
filtered
{
for
_
,
value
:=
range
values
{
dst
.
Add
(
key
,
value
)
}
}
}
Prev
1
2
3
4
5
6
7
8
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment