Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
195e227c
Commit
195e227c
authored
Jan 06, 2026
by
song
Browse files
merge: 合并 upstream/main 并保留本地图片计费功能
parents
6fa704d6
752882a0
Changes
187
Show whitespace changes
Inline
Side-by-side
backend/internal/service/admin_service.go
View file @
195e227c
...
...
@@ -123,6 +123,7 @@ type UpdateGroupInput struct {
type
CreateAccountInput
struct
{
Name
string
Notes
*
string
Platform
string
Type
string
Credentials
map
[
string
]
any
...
...
@@ -138,6 +139,7 @@ type CreateAccountInput struct {
type
UpdateAccountInput
struct
{
Name
string
Notes
*
string
Type
string
// Account type: oauth, setup-token, apikey
Credentials
map
[
string
]
any
Extra
map
[
string
]
any
...
...
@@ -687,6 +689,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
account
:=
&
Account
{
Name
:
input
.
Name
,
Notes
:
normalizeAccountNotes
(
input
.
Notes
),
Platform
:
input
.
Platform
,
Type
:
input
.
Type
,
Credentials
:
input
.
Credentials
,
...
...
@@ -723,6 +726,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if
input
.
Type
!=
""
{
account
.
Type
=
input
.
Type
}
if
input
.
Notes
!=
nil
{
account
.
Notes
=
normalizeAccountNotes
(
input
.
Notes
)
}
if
len
(
input
.
Credentials
)
>
0
{
account
.
Credentials
=
input
.
Credentials
}
...
...
@@ -730,7 +736,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account
.
Extra
=
input
.
Extra
}
if
input
.
ProxyID
!=
nil
{
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
if
*
input
.
ProxyID
==
0
{
account
.
ProxyID
=
nil
}
else
{
account
.
ProxyID
=
input
.
ProxyID
}
account
.
Proxy
=
nil
// 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID
}
// 只在指针非 nil 时更新 Concurrency(支持设置为 0)
...
...
backend/internal/service/antigravity_gateway_service.go
View file @
195e227c
...
...
@@ -9,8 +9,10 @@ import (
"fmt"
"io"
"log"
mathrand
"math/rand"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
...
...
@@ -255,6 +257,16 @@ func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedMode
return
antigravity
.
TransformClaudeToGemini
(
claudeReq
,
projectID
,
mappedModel
)
}
func
(
s
*
AntigravityGatewayService
)
getClaudeTransformOptions
(
ctx
context
.
Context
)
antigravity
.
TransformOptions
{
opts
:=
antigravity
.
DefaultTransformOptions
()
if
s
.
settingService
==
nil
{
return
opts
}
opts
.
EnableIdentityPatch
=
s
.
settingService
.
IsIdentityPatchEnabled
(
ctx
)
opts
.
IdentityPatch
=
s
.
settingService
.
GetIdentityPatchPrompt
(
ctx
)
return
opts
}
// extractGeminiResponseText 从 Gemini 响应中提取文本
func
extractGeminiResponseText
(
respBody
[]
byte
)
string
{
var
resp
map
[
string
]
any
...
...
@@ -380,7 +392,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
// 转换 Claude 请求为 Gemini 格式
geminiBody
,
err
:=
antigravity
.
TransformClaudeToGemini
(
&
claudeReq
,
projectID
,
mappedModel
)
geminiBody
,
err
:=
antigravity
.
TransformClaudeToGemini
WithOptions
(
&
claudeReq
,
projectID
,
mappedModel
,
s
.
getClaudeTransformOptions
(
ctx
)
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"transform request: %w"
,
err
)
}
...
...
@@ -394,6 +406,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 重试循环
var
resp
*
http
.
Response
for
attempt
:=
1
;
attempt
<=
antigravityMaxRetries
;
attempt
++
{
// 检查 context 是否已取消(客户端断开连接)
select
{
case
<-
ctx
.
Done
()
:
log
.
Printf
(
"%s status=context_canceled error=%v"
,
prefix
,
ctx
.
Err
())
return
nil
,
ctx
.
Err
()
default
:
}
upstreamReq
,
err
:=
antigravity
.
NewAPIRequest
(
ctx
,
action
,
accessToken
,
geminiBody
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -403,7 +423,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if
err
!=
nil
{
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"%s status=request_failed retry=%d/%d error=%v"
,
prefix
,
attempt
,
antigravityMaxRetries
,
err
)
sleepAntigravityBackoff
(
attempt
)
if
!
sleepAntigravityBackoffWithContext
(
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
prefix
)
return
nil
,
ctx
.
Err
()
}
continue
}
log
.
Printf
(
"%s status=request_failed retries_exhausted error=%v"
,
prefix
,
err
)
...
...
@@ -416,7 +439,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"%s status=%d retry=%d/%d"
,
prefix
,
resp
.
StatusCode
,
attempt
,
antigravityMaxRetries
)
sleepAntigravityBackoff
(
attempt
)
if
!
sleepAntigravityBackoffWithContext
(
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
prefix
)
return
nil
,
ctx
.
Err
()
}
continue
}
// 所有重试都失败,标记限流状态
...
...
@@ -443,35 +469,70 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
if
resp
.
StatusCode
==
http
.
StatusBadRequest
&&
isSignatureRelatedError
(
respBody
)
{
// Conservative two-stage fallback:
// 1) Disable top-level thinking + thinking->text
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
retryStages
:=
[]
struct
{
name
string
strip
func
(
*
antigravity
.
ClaudeRequest
)
(
bool
,
error
)
}{
{
name
:
"thinking-only"
,
strip
:
stripThinkingFromClaudeRequest
},
{
name
:
"thinking+tools"
,
strip
:
stripSignatureSensitiveBlocksFromClaudeRequest
},
}
for
_
,
stage
:=
range
retryStages
{
retryClaudeReq
:=
claudeReq
retryClaudeReq
.
Messages
=
append
([]
antigravity
.
ClaudeMessage
(
nil
),
claudeReq
.
Messages
...
)
stripped
,
stripErr
:=
stripThinkingFromClaudeRequest
(
&
retryClaudeReq
)
if
stripErr
==
nil
&&
stripped
{
log
.
Printf
(
"Antigravity account %d: detected signature-related 400, retrying once without thinking blocks"
,
account
.
ID
)
stripped
,
stripErr
:=
stage
.
strip
(
&
retryClaudeReq
)
if
stripErr
!=
nil
||
!
stripped
{
continue
}
log
.
Printf
(
"Antigravity account %d: detected signature-related 400, retrying once (%s)"
,
account
.
ID
,
stage
.
name
)
retryGeminiBody
,
txErr
:=
antigravity
.
TransformClaudeToGemini
(
&
retryClaudeReq
,
projectID
,
mappedModel
)
if
txErr
==
nil
{
retryGeminiBody
,
txErr
:=
antigravity
.
TransformClaudeToGeminiWithOptions
(
&
retryClaudeReq
,
projectID
,
mappedModel
,
s
.
getClaudeTransformOptions
(
ctx
))
if
txErr
!=
nil
{
continue
}
retryReq
,
buildErr
:=
antigravity
.
NewAPIRequest
(
ctx
,
action
,
accessToken
,
retryGeminiBody
)
if
buildErr
==
nil
{
if
buildErr
!=
nil
{
continue
}
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
retryErr
==
nil
{
// Retry success: continue normal success flow with the new response.
if
retryErr
!=
nil
{
log
.
Printf
(
"Antigravity account %d: signature retry request failed (%s): %v"
,
account
.
ID
,
stage
.
name
,
retryErr
)
continue
}
if
retryResp
.
StatusCode
<
400
{
_
=
resp
.
Body
.
Close
()
resp
=
retryResp
respBody
=
nil
}
else
{
// Retry still errored: replace error context with retry response.
break
}
retryBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
retryResp
.
Body
,
2
<<
20
))
_
=
retryResp
.
Body
.
Close
()
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
if
retryResp
.
StatusCode
!=
http
.
StatusBadRequest
||
!
isSignatureRelatedError
(
retryBody
)
{
respBody
=
retryBody
resp
=
retryResp
}
}
else
{
log
.
Printf
(
"Antigravity account %d: signature retry request failed: %v"
,
account
.
ID
,
retryErr
)
resp
=
&
http
.
Response
{
StatusCode
:
retryResp
.
StatusCode
,
Header
:
retryResp
.
Header
.
Clone
(),
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
retryBody
)),
}
break
}
// Still signature-related; capture context and allow next stage.
respBody
=
retryBody
resp
=
&
http
.
Response
{
StatusCode
:
retryResp
.
StatusCode
,
Header
:
retryResp
.
Header
.
Clone
(),
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
retryBody
)),
}
}
}
...
...
@@ -528,7 +589,17 @@ func isSignatureRelatedError(respBody []byte) bool {
}
// Keep this intentionally broad: different upstreams may use "signature" or "thought_signature".
return
strings
.
Contains
(
msg
,
"thought_signature"
)
||
strings
.
Contains
(
msg
,
"signature"
)
if
strings
.
Contains
(
msg
,
"thought_signature"
)
||
strings
.
Contains
(
msg
,
"signature"
)
{
return
true
}
// Also detect thinking block structural errors:
// "Expected `thinking` or `redacted_thinking`, but found `text`"
if
strings
.
Contains
(
msg
,
"expected"
)
&&
(
strings
.
Contains
(
msg
,
"thinking"
)
||
strings
.
Contains
(
msg
,
"redacted_thinking"
))
{
return
true
}
return
false
}
func
extractAntigravityErrorMessage
(
body
[]
byte
)
string
{
...
...
@@ -555,7 +626,7 @@ func extractAntigravityErrorMessage(body []byte) string {
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
// This preserves the thinking content while avoiding signature validation errors.
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
// It also disables top-level `thinking` to
prevent dummy-thought injection during retry
.
// It also disables top-level `thinking` to
avoid upstream structural constraints for thinking mode
.
func
stripThinkingFromClaudeRequest
(
req
*
antigravity
.
ClaudeRequest
)
(
bool
,
error
)
{
if
req
==
nil
{
return
false
,
nil
...
...
@@ -585,6 +656,92 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
continue
}
filtered
:=
make
([]
map
[
string
]
any
,
0
,
len
(
blocks
))
modifiedAny
:=
false
for
_
,
block
:=
range
blocks
{
t
,
_
:=
block
[
"type"
]
.
(
string
)
switch
t
{
case
"thinking"
:
thinkingText
,
_
:=
block
[
"thinking"
]
.
(
string
)
if
thinkingText
!=
""
{
filtered
=
append
(
filtered
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
thinkingText
,
})
}
modifiedAny
=
true
case
"redacted_thinking"
:
modifiedAny
=
true
case
""
:
if
thinkingText
,
hasThinking
:=
block
[
"thinking"
]
.
(
string
);
hasThinking
{
if
thinkingText
!=
""
{
filtered
=
append
(
filtered
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
thinkingText
,
})
}
modifiedAny
=
true
}
else
{
filtered
=
append
(
filtered
,
block
)
}
default
:
filtered
=
append
(
filtered
,
block
)
}
}
if
!
modifiedAny
{
continue
}
if
len
(
filtered
)
==
0
{
filtered
=
append
(
filtered
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"(content removed)"
,
})
}
newRaw
,
err
:=
json
.
Marshal
(
filtered
)
if
err
!=
nil
{
return
changed
,
err
}
req
.
Messages
[
i
]
.
Content
=
newRaw
changed
=
true
}
return
changed
,
nil
}
// stripSignatureSensitiveBlocksFromClaudeRequest is a stronger retry degradation that additionally converts
// tool blocks to plain text. Use this only after a thinking-only retry still fails with signature errors.
func
stripSignatureSensitiveBlocksFromClaudeRequest
(
req
*
antigravity
.
ClaudeRequest
)
(
bool
,
error
)
{
if
req
==
nil
{
return
false
,
nil
}
changed
:=
false
if
req
.
Thinking
!=
nil
{
req
.
Thinking
=
nil
changed
=
true
}
for
i
:=
range
req
.
Messages
{
raw
:=
req
.
Messages
[
i
]
.
Content
if
len
(
raw
)
==
0
{
continue
}
// If content is a string, nothing to strip.
var
str
string
if
json
.
Unmarshal
(
raw
,
&
str
)
==
nil
{
continue
}
// Otherwise treat as an array of blocks and convert signature-sensitive blocks to text.
var
blocks
[]
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
raw
,
&
blocks
);
err
!=
nil
{
continue
}
filtered
:=
make
([]
map
[
string
]
any
,
0
,
len
(
blocks
))
modifiedAny
:=
false
for
_
,
block
:=
range
blocks
{
...
...
@@ -603,6 +760,49 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
case
"redacted_thinking"
:
// Remove redacted_thinking (cannot convert encrypted content)
modifiedAny
=
true
case
"tool_use"
:
// Convert tool_use to text to avoid upstream signature/thought_signature validation errors.
// This is a retry-only degradation path, so we prioritise request validity over tool semantics.
name
,
_
:=
block
[
"name"
]
.
(
string
)
id
,
_
:=
block
[
"id"
]
.
(
string
)
input
:=
block
[
"input"
]
inputJSON
,
_
:=
json
.
Marshal
(
input
)
text
:=
"(tool_use)"
if
name
!=
""
{
text
+=
" name="
+
name
}
if
id
!=
""
{
text
+=
" id="
+
id
}
if
len
(
inputJSON
)
>
0
&&
string
(
inputJSON
)
!=
"null"
{
text
+=
" input="
+
string
(
inputJSON
)
}
filtered
=
append
(
filtered
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
text
,
})
modifiedAny
=
true
case
"tool_result"
:
// Convert tool_result to text so it stays consistent when tool_use is downgraded.
toolUseID
,
_
:=
block
[
"tool_use_id"
]
.
(
string
)
isError
,
_
:=
block
[
"is_error"
]
.
(
bool
)
content
:=
block
[
"content"
]
contentJSON
,
_
:=
json
.
Marshal
(
content
)
text
:=
"(tool_result)"
if
toolUseID
!=
""
{
text
+=
" tool_use_id="
+
toolUseID
}
if
isError
{
text
+=
" is_error=true"
}
if
len
(
contentJSON
)
>
0
&&
string
(
contentJSON
)
!=
"null"
{
text
+=
"
\n
"
+
string
(
contentJSON
)
}
filtered
=
append
(
filtered
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
text
,
})
modifiedAny
=
true
case
""
:
// Handle untyped block with "thinking" field
if
thinkingText
,
hasThinking
:=
block
[
"thinking"
]
.
(
string
);
hasThinking
{
...
...
@@ -625,6 +825,14 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
continue
}
if
len
(
filtered
)
==
0
{
// Keep request valid: upstream rejects empty content arrays.
filtered
=
append
(
filtered
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"(content removed)"
,
})
}
newRaw
,
err
:=
json
.
Marshal
(
filtered
)
if
err
!=
nil
{
return
changed
,
err
...
...
@@ -711,6 +919,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 重试循环
var
resp
*
http
.
Response
for
attempt
:=
1
;
attempt
<=
antigravityMaxRetries
;
attempt
++
{
// 检查 context 是否已取消(客户端断开连接)
select
{
case
<-
ctx
.
Done
()
:
log
.
Printf
(
"%s status=context_canceled error=%v"
,
prefix
,
ctx
.
Err
())
return
nil
,
ctx
.
Err
()
default
:
}
upstreamReq
,
err
:=
antigravity
.
NewAPIRequest
(
ctx
,
upstreamAction
,
accessToken
,
wrappedBody
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -720,7 +936,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if
err
!=
nil
{
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"%s status=request_failed retry=%d/%d error=%v"
,
prefix
,
attempt
,
antigravityMaxRetries
,
err
)
sleepAntigravityBackoff
(
attempt
)
if
!
sleepAntigravityBackoffWithContext
(
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
prefix
)
return
nil
,
ctx
.
Err
()
}
continue
}
log
.
Printf
(
"%s status=request_failed retries_exhausted error=%v"
,
prefix
,
err
)
...
...
@@ -733,7 +952,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if
attempt
<
antigravityMaxRetries
{
log
.
Printf
(
"%s status=%d retry=%d/%d"
,
prefix
,
resp
.
StatusCode
,
attempt
,
antigravityMaxRetries
)
sleepAntigravityBackoff
(
attempt
)
if
!
sleepAntigravityBackoffWithContext
(
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
prefix
)
return
nil
,
ctx
.
Err
()
}
continue
}
// 所有重试都失败,标记限流状态
...
...
@@ -750,11 +972,18 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
break
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
defer
func
()
{
if
resp
!=
nil
&&
resp
.
Body
!=
nil
{
_
=
resp
.
Body
.
Close
()
}
}()
// 处理错误响应
if
resp
.
StatusCode
>=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
// 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsModelFallbackEnabled
(
ctx
)
&&
...
...
@@ -763,15 +992,13 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if
fallbackModel
!=
""
&&
fallbackModel
!=
mappedModel
{
log
.
Printf
(
"[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)"
,
mappedModel
,
fallbackModel
,
account
.
Name
)
// 关闭原始响应,释放连接(respBody 已读取到内存)
_
=
resp
.
Body
.
Close
()
fallbackWrapped
,
err
:=
s
.
wrapV1InternalRequest
(
projectID
,
fallbackModel
,
body
)
if
err
==
nil
{
fallbackReq
,
err
:=
antigravity
.
NewAPIRequest
(
ctx
,
upstreamAction
,
accessToken
,
fallbackWrapped
)
if
err
==
nil
{
fallbackResp
,
err
:=
s
.
httpUpstream
.
Do
(
fallbackReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
==
nil
&&
fallbackResp
.
StatusCode
<
400
{
_
=
resp
.
Body
.
Close
()
resp
=
fallbackResp
}
else
if
fallbackResp
!=
nil
{
_
=
fallbackResp
.
Body
.
Close
()
...
...
@@ -872,8 +1099,28 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
}
}
func
sleepAntigravityBackoff
(
attempt
int
)
{
sleepGeminiBackoff
(
attempt
)
// 复用 Gemini 的退避逻辑
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
// 返回 true 表示正常完成等待,false 表示 context 已取消
func
sleepAntigravityBackoffWithContext
(
ctx
context
.
Context
,
attempt
int
)
bool
{
delay
:=
geminiRetryBaseDelay
*
time
.
Duration
(
1
<<
uint
(
attempt
-
1
))
if
delay
>
geminiRetryMaxDelay
{
delay
=
geminiRetryMaxDelay
}
// +/- 20% jitter
r
:=
mathrand
.
New
(
mathrand
.
NewSource
(
time
.
Now
()
.
UnixNano
()))
jitter
:=
time
.
Duration
(
float64
(
delay
)
*
0.2
*
(
r
.
Float64
()
*
2
-
1
))
sleepFor
:=
delay
+
jitter
if
sleepFor
<
0
{
sleepFor
=
0
}
select
{
case
<-
ctx
.
Done
()
:
return
false
case
<-
time
.
After
(
sleepFor
)
:
return
true
}
}
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
)
{
...
...
@@ -928,20 +1175,102 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
return
nil
,
errors
.
New
(
"streaming not supported"
)
}
reader
:=
bufio
.
NewReader
(
resp
.
Body
)
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
usage
:=
&
ClaudeUsage
{}
var
firstTokenMs
*
int
type
scanEvent
struct
{
line
string
err
error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events
:=
make
(
chan
scanEvent
,
16
)
done
:=
make
(
chan
struct
{})
sendEvent
:=
func
(
ev
scanEvent
)
bool
{
select
{
case
events
<-
ev
:
return
true
case
<-
done
:
return
false
}
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
if
!
sendEvent
(
scanEvent
{
line
:
scanner
.
Text
()})
{
return
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
defer
close
(
done
)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval
:=
time
.
Duration
(
0
)
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
>
0
{
streamInterval
=
time
.
Duration
(
s
.
settingService
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
)
*
time
.
Second
}
var
intervalTicker
*
time
.
Ticker
if
streamInterval
>
0
{
intervalTicker
=
time
.
NewTicker
(
streamInterval
)
defer
intervalTicker
.
Stop
()
}
var
intervalCh
<-
chan
time
.
Time
if
intervalTicker
!=
nil
{
intervalCh
=
intervalTicker
.
C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
return
}
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
c
.
Writer
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
flusher
.
Flush
()
}
for
{
line
,
err
:=
reader
.
ReadString
(
'\n'
)
if
len
(
line
)
>
0
{
select
{
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
if
ev
.
err
!=
nil
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long (antigravity): max_size=%d error=%v"
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
ev
.
err
}
sendErrorEvent
(
"stream_read_error"
)
return
nil
,
ev
.
err
}
line
:=
ev
.
line
trimmed
:=
strings
.
TrimRight
(
line
,
"
\r\n
"
)
if
strings
.
HasPrefix
(
trimmed
,
"data:"
)
{
payload
:=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmed
,
"data:"
))
if
payload
==
""
||
payload
==
"[DONE]"
{
_
,
_
=
io
.
WriteString
(
c
.
Writer
,
line
)
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
}
else
{
continue
}
// 解包 v1internal 响应
inner
,
parseErr
:=
s
.
unwrapV1InternalResponse
([]
byte
(
payload
))
if
parseErr
==
nil
&&
inner
!=
nil
{
...
...
@@ -961,24 +1290,30 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
firstTokenMs
=
&
ms
}
_
,
_
=
fmt
.
Fprintf
(
c
.
Writer
,
"data: %s
\n\n
"
,
payload
)
flusher
.
Flush
()
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"data: %s
\n\n
"
,
payload
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
}
else
{
_
,
_
=
io
.
WriteString
(
c
.
Writer
,
line
)
flusher
.
Flush
()
continue
}
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
if
errors
.
Is
(
err
,
io
.
EOF
)
{
break
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
}
if
err
!=
nil
{
return
nil
,
err
log
.
Printf
(
"Stream data interval timeout (antigravity)"
)
sendErrorEvent
(
"stream_timeout"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
}
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
func
(
s
*
AntigravityGatewayService
)
handleGeminiNonStreamingResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
)
(
*
ClaudeUsage
,
error
)
{
...
...
@@ -1117,7 +1452,13 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
processor
:=
antigravity
.
NewStreamingProcessor
(
originalModel
)
var
firstTokenMs
*
int
reader
:=
bufio
.
NewReader
(
resp
.
Body
)
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage
:=
func
(
agUsage
*
antigravity
.
ClaudeUsage
)
*
ClaudeUsage
{
...
...
@@ -1132,13 +1473,85 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
}
type
scanEvent
struct
{
line
string
err
error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events
:=
make
(
chan
scanEvent
,
16
)
done
:=
make
(
chan
struct
{})
sendEvent
:=
func
(
ev
scanEvent
)
bool
{
select
{
case
events
<-
ev
:
return
true
case
<-
done
:
return
false
}
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
if
!
sendEvent
(
scanEvent
{
line
:
scanner
.
Text
()})
{
return
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
defer
close
(
done
)
streamInterval
:=
time
.
Duration
(
0
)
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
>
0
{
streamInterval
=
time
.
Duration
(
s
.
settingService
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
)
*
time
.
Second
}
var
intervalTicker
*
time
.
Ticker
if
streamInterval
>
0
{
intervalTicker
=
time
.
NewTicker
(
streamInterval
)
defer
intervalTicker
.
Stop
()
}
var
intervalCh
<-
chan
time
.
Time
if
intervalTicker
!=
nil
{
intervalCh
=
intervalTicker
.
C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
return
}
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
c
.
Writer
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
flusher
.
Flush
()
}
for
{
line
,
err
:=
reader
.
ReadString
(
'\n'
)
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
io
.
EOF
)
{
return
nil
,
fmt
.
Errorf
(
"stream read error: %w"
,
err
)
select
{
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
// 发送结束事件
finalEvents
,
agUsage
:=
processor
.
Finish
()
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
flusher
.
Flush
()
}
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
nil
}
if
ev
.
err
!=
nil
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long (antigravity): max_size=%d error=%v"
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
nil
),
firstTokenMs
:
firstTokenMs
},
ev
.
err
}
sendErrorEvent
(
"stream_read_error"
)
return
nil
,
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
if
len
(
line
)
>
0
{
line
:=
ev
.
line
// 处理 SSE 行,转换为 Claude 格式
claudeEvents
:=
processor
.
ProcessLine
(
strings
.
TrimRight
(
line
,
"
\r\n
"
))
...
...
@@ -1153,25 +1566,23 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
}
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
writeErr
}
flusher
.
Flush
()
}
}
if
errors
.
Is
(
err
,
io
.
EOF
)
{
break
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
}
log
.
Printf
(
"Stream data interval timeout (antigravity)"
)
sendErrorEvent
(
"stream_timeout"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
nil
),
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
// 发送结束事件
finalEvents
,
agUsage
:=
processor
.
Finish
()
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
flusher
.
Flush
()
}
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
nil
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
...
...
backend/internal/service/antigravity_gateway_service_test.go
0 → 100644
View file @
195e227c
package
service
import
(
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
func
TestStripSignatureSensitiveBlocksFromClaudeRequest
(
t
*
testing
.
T
)
{
req
:=
&
antigravity
.
ClaudeRequest
{
Model
:
"claude-sonnet-4-5"
,
Thinking
:
&
antigravity
.
ThinkingConfig
{
Type
:
"enabled"
,
BudgetTokens
:
1024
,
},
Messages
:
[]
antigravity
.
ClaudeMessage
{
{
Role
:
"assistant"
,
Content
:
json
.
RawMessage
(
`[
{"type":"thinking","thinking":"secret plan","signature":""},
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}
]`
),
},
{
Role
:
"user"
,
Content
:
json
.
RawMessage
(
`[
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false},
{"type":"redacted_thinking","data":"..."}
]`
),
},
},
}
changed
,
err
:=
stripSignatureSensitiveBlocksFromClaudeRequest
(
req
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
changed
)
require
.
Nil
(
t
,
req
.
Thinking
)
require
.
Len
(
t
,
req
.
Messages
,
2
)
var
blocks0
[]
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
req
.
Messages
[
0
]
.
Content
,
&
blocks0
))
require
.
Len
(
t
,
blocks0
,
2
)
require
.
Equal
(
t
,
"text"
,
blocks0
[
0
][
"type"
])
require
.
Equal
(
t
,
"secret plan"
,
blocks0
[
0
][
"text"
])
require
.
Equal
(
t
,
"text"
,
blocks0
[
1
][
"type"
])
var
blocks1
[]
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
req
.
Messages
[
1
]
.
Content
,
&
blocks1
))
require
.
Len
(
t
,
blocks1
,
1
)
require
.
Equal
(
t
,
"text"
,
blocks1
[
0
][
"type"
])
require
.
NotEmpty
(
t
,
blocks1
[
0
][
"text"
])
}
func
TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools
(
t
*
testing
.
T
)
{
req
:=
&
antigravity
.
ClaudeRequest
{
Model
:
"claude-sonnet-4-5"
,
Thinking
:
&
antigravity
.
ThinkingConfig
{
Type
:
"enabled"
,
BudgetTokens
:
1024
,
},
Messages
:
[]
antigravity
.
ClaudeMessage
{
{
Role
:
"assistant"
,
Content
:
json
.
RawMessage
(
`[{"type":"thinking","thinking":"secret plan"},{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}]`
),
},
},
}
changed
,
err
:=
stripThinkingFromClaudeRequest
(
req
)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
changed
)
require
.
Nil
(
t
,
req
.
Thinking
)
var
blocks
[]
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
req
.
Messages
[
0
]
.
Content
,
&
blocks
))
require
.
Len
(
t
,
blocks
,
2
)
require
.
Equal
(
t
,
"text"
,
blocks
[
0
][
"type"
])
require
.
Equal
(
t
,
"secret plan"
,
blocks
[
0
][
"text"
])
require
.
Equal
(
t
,
"tool_use"
,
blocks
[
1
][
"type"
])
}
backend/internal/service/auth_service.go
View file @
195e227c
...
...
@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// VerifyTurnstile 验证Turnstile token
func
(
s
*
AuthService
)
VerifyTurnstile
(
ctx
context
.
Context
,
token
string
,
remoteIP
string
)
error
{
required
:=
s
.
cfg
!=
nil
&&
s
.
cfg
.
Server
.
Mode
==
"release"
&&
s
.
cfg
.
Turnstile
.
Required
if
required
{
if
s
.
settingService
==
nil
{
log
.
Println
(
"[Auth] Turnstile required but settings service is not configured"
)
return
ErrTurnstileNotConfigured
}
enabled
:=
s
.
settingService
.
IsTurnstileEnabled
(
ctx
)
secretConfigured
:=
s
.
settingService
.
GetTurnstileSecretKey
(
ctx
)
!=
""
if
!
enabled
||
!
secretConfigured
{
log
.
Printf
(
"[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)"
,
enabled
,
secretConfigured
)
return
ErrTurnstileNotConfigured
}
}
if
s
.
turnstileService
==
nil
{
if
required
{
log
.
Println
(
"[Auth] Turnstile required but service not configured"
)
return
ErrTurnstileNotConfigured
}
return
nil
// 服务未配置则跳过验证
}
if
!
required
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsTurnstileEnabled
(
ctx
)
&&
s
.
settingService
.
GetTurnstileSecretKey
(
ctx
)
==
""
{
log
.
Println
(
"[Auth] Turnstile enabled but secret key not configured"
)
}
return
s
.
turnstileService
.
VerifyToken
(
ctx
,
token
,
remoteIP
)
}
...
...
backend/internal/service/billing_cache_service.go
View file @
195e227c
...
...
@@ -17,6 +17,7 @@ import (
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var
(
ErrSubscriptionInvalid
=
infraerrors
.
Forbidden
(
"SUBSCRIPTION_INVALID"
,
"subscription is invalid or expired"
)
ErrBillingServiceUnavailable
=
infraerrors
.
ServiceUnavailable
(
"BILLING_SERVICE_ERROR"
,
"Billing service temporarily unavailable. Please retry later."
)
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
...
...
@@ -76,6 +77,7 @@ type BillingCacheService struct {
userRepo
UserRepository
subRepo
UserSubscriptionRepository
cfg
*
config
.
Config
circuitBreaker
*
billingCircuitBreaker
cacheWriteChan
chan
cacheWriteTask
cacheWriteWg
sync
.
WaitGroup
...
...
@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
subRepo
:
subRepo
,
cfg
:
cfg
,
}
svc
.
circuitBreaker
=
newBillingCircuitBreaker
(
cfg
.
Billing
.
CircuitBreaker
)
svc
.
startCacheWriteWorkers
()
return
svc
}
...
...
@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
if
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
return
nil
}
if
s
.
circuitBreaker
!=
nil
&&
!
s
.
circuitBreaker
.
Allow
()
{
return
ErrBillingServiceUnavailable
}
// 判断计费模式
isSubscriptionMode
:=
group
!=
nil
&&
group
.
IsSubscriptionType
()
&&
subscription
!=
nil
...
...
@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
func
(
s
*
BillingCacheService
)
checkBalanceEligibility
(
ctx
context
.
Context
,
userID
int64
)
error
{
balance
,
err
:=
s
.
GetUserBalance
(
ctx
,
userID
)
if
err
!=
nil
{
// 缓存/数据库错误,允许通过(降级处理)
log
.
Printf
(
"Warning: get user balance failed, allowing request: %v"
,
err
)
return
nil
if
s
.
circuitBreaker
!=
nil
{
s
.
circuitBreaker
.
OnFailure
(
err
)
}
log
.
Printf
(
"ALERT: billing balance check failed for user %d: %v"
,
userID
,
err
)
return
ErrBillingServiceUnavailable
.
WithCause
(
err
)
}
if
s
.
circuitBreaker
!=
nil
{
s
.
circuitBreaker
.
OnSuccess
()
}
if
balance
<=
0
{
...
...
@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
// 获取订阅缓存数据
subData
,
err
:=
s
.
GetSubscriptionStatus
(
ctx
,
userID
,
group
.
ID
)
if
err
!=
nil
{
// 缓存/数据库错误,降级使用传入的subscription进行检查
log
.
Printf
(
"Warning: get subscription cache failed, using fallback: %v"
,
err
)
return
s
.
checkSubscriptionLimitsFallback
(
subscription
,
group
)
if
s
.
circuitBreaker
!=
nil
{
s
.
circuitBreaker
.
OnFailure
(
err
)
}
log
.
Printf
(
"ALERT: billing subscription check failed for user %d group %d: %v"
,
userID
,
group
.
ID
,
err
)
return
ErrBillingServiceUnavailable
.
WithCause
(
err
)
}
if
s
.
circuitBreaker
!=
nil
{
s
.
circuitBreaker
.
OnSuccess
()
}
// 检查订阅状态
...
...
@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
return
nil
}
// checkSubscriptionLimitsFallback 降级检查订阅限额
func
(
s
*
BillingCacheService
)
checkSubscriptionLimitsFallback
(
subscription
*
UserSubscription
,
group
*
Group
)
error
{
if
subscription
==
nil
{
return
ErrSubscriptionInvalid
// billingCircuitBreakerState enumerates the lifecycle states of the billing
// circuit breaker.
type billingCircuitBreakerState int

const (
	// billingCircuitClosed: normal operation, all requests pass through.
	billingCircuitClosed billingCircuitBreakerState = iota
	// billingCircuitOpen: failing fast, requests are rejected until the
	// reset timeout elapses.
	billingCircuitOpen
	// billingCircuitHalfOpen: probing with a limited number of requests to
	// decide whether to close again.
	billingCircuitHalfOpen
)

// billingCircuitBreaker is a small mutex-guarded circuit breaker protecting
// billing eligibility checks. A nil *billingCircuitBreaker means the breaker
// is disabled (see newBillingCircuitBreaker).
type billingCircuitBreaker struct {
	mu    sync.Mutex
	state billingCircuitBreakerState
	// failures counts consecutive failures observed while closed.
	failures int
	// openedAt records when the breaker last transitioned to open.
	openedAt time.Time
	// failureThreshold is the consecutive-failure count that trips the breaker.
	failureThreshold int
	// resetTimeout is how long to stay open before allowing half-open probes.
	resetTimeout time.Duration
	// halfOpenRequests is the probe budget granted on entering half-open.
	halfOpenRequests int
	// halfOpenRemaining is the number of probes still allowed in half-open.
	halfOpenRemaining int
}
func
newBillingCircuitBreaker
(
cfg
config
.
CircuitBreakerConfig
)
*
billingCircuitBreaker
{
if
!
cfg
.
Enabled
{
return
nil
}
resetTimeout
:=
time
.
Duration
(
cfg
.
ResetTimeoutSeconds
)
*
time
.
Second
if
resetTimeout
<=
0
{
resetTimeout
=
30
*
time
.
Second
}
halfOpen
:=
cfg
.
HalfOpenRequests
if
halfOpen
<=
0
{
halfOpen
=
1
}
threshold
:=
cfg
.
FailureThreshold
if
threshold
<=
0
{
threshold
=
5
}
return
&
billingCircuitBreaker
{
state
:
billingCircuitClosed
,
failureThreshold
:
threshold
,
resetTimeout
:
resetTimeout
,
halfOpenRequests
:
halfOpen
,
}
}
if
!
subscription
.
IsActive
()
{
return
ErrSubscriptionInvalid
func
(
b
*
billingCircuitBreaker
)
Allow
()
bool
{
b
.
mu
.
Lock
()
defer
b
.
mu
.
Unlock
()
switch
b
.
state
{
case
billingCircuitClosed
:
return
true
case
billingCircuitOpen
:
if
time
.
Since
(
b
.
openedAt
)
<
b
.
resetTimeout
{
return
false
}
b
.
state
=
billingCircuitHalfOpen
b
.
halfOpenRemaining
=
b
.
halfOpenRequests
log
.
Printf
(
"ALERT: billing circuit breaker entering half-open state"
)
fallthrough
case
billingCircuitHalfOpen
:
if
b
.
halfOpenRemaining
<=
0
{
return
false
}
b
.
halfOpenRemaining
--
return
true
default
:
return
false
}
}
if
!
subscription
.
CheckDailyLimit
(
group
,
0
)
{
return
ErrDailyLimitExceeded
func
(
b
*
billingCircuitBreaker
)
OnFailure
(
err
error
)
{
if
b
==
nil
{
return
}
b
.
mu
.
Lock
()
defer
b
.
mu
.
Unlock
()
if
!
subscription
.
CheckWeeklyLimit
(
group
,
0
)
{
return
ErrWeeklyLimitExceeded
switch
b
.
state
{
case
billingCircuitOpen
:
return
case
billingCircuitHalfOpen
:
b
.
state
=
billingCircuitOpen
b
.
openedAt
=
time
.
Now
()
b
.
halfOpenRemaining
=
0
log
.
Printf
(
"ALERT: billing circuit breaker opened after half-open failure: %v"
,
err
)
return
default
:
b
.
failures
++
if
b
.
failures
>=
b
.
failureThreshold
{
b
.
state
=
billingCircuitOpen
b
.
openedAt
=
time
.
Now
()
b
.
halfOpenRemaining
=
0
log
.
Printf
(
"ALERT: billing circuit breaker opened after %d failures: %v"
,
b
.
failures
,
err
)
}
}
}
if
!
subscription
.
CheckMonthlyLimit
(
group
,
0
)
{
return
ErrMonthlyLimitExceeded
func
(
b
*
billingCircuitBreaker
)
OnSuccess
()
{
if
b
==
nil
{
return
}
b
.
mu
.
Lock
()
defer
b
.
mu
.
Unlock
()
return
nil
previousState
:=
b
.
state
previousFailures
:=
b
.
failures
b
.
state
=
billingCircuitClosed
b
.
failures
=
0
b
.
halfOpenRemaining
=
0
// 只有状态真正发生变化时才记录日志
if
previousState
!=
billingCircuitClosed
{
log
.
Printf
(
"ALERT: billing circuit breaker closed (was %s)"
,
circuitStateString
(
previousState
))
}
else
if
previousFailures
>
0
{
log
.
Printf
(
"INFO: billing circuit breaker failures reset from %d"
,
previousFailures
)
}
}
func
circuitStateString
(
state
billingCircuitBreakerState
)
string
{
switch
state
{
case
billingCircuitClosed
:
return
"closed"
case
billingCircuitOpen
:
return
"open"
case
billingCircuitHalfOpen
:
return
"half-open"
default
:
return
"unknown"
}
}
backend/internal/service/crs_sync_service.go
View file @
195e227c
...
...
@@ -8,12 +8,13 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
type
CRSSyncService
struct
{
...
...
@@ -22,6 +23,7 @@ type CRSSyncService struct {
oauthService
*
OAuthService
openaiOAuthService
*
OpenAIOAuthService
geminiOAuthService
*
GeminiOAuthService
cfg
*
config
.
Config
}
func
NewCRSSyncService
(
...
...
@@ -30,6 +32,7 @@ func NewCRSSyncService(
oauthService
*
OAuthService
,
openaiOAuthService
*
OpenAIOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
cfg
*
config
.
Config
,
)
*
CRSSyncService
{
return
&
CRSSyncService
{
accountRepo
:
accountRepo
,
...
...
@@ -37,6 +40,7 @@ func NewCRSSyncService(
oauthService
:
oauthService
,
openaiOAuthService
:
openaiOAuthService
,
geminiOAuthService
:
geminiOAuthService
,
cfg
:
cfg
,
}
}
...
...
@@ -187,16 +191,31 @@ type crsGeminiAPIKeyAccount struct {
}
func
(
s
*
CRSSyncService
)
SyncFromCRS
(
ctx
context
.
Context
,
input
SyncFromCRSInput
)
(
*
SyncFromCRSResult
,
error
)
{
baseURL
,
err
:=
normalizeBaseURL
(
input
.
BaseURL
)
if
s
.
cfg
==
nil
{
return
nil
,
errors
.
New
(
"config is not available"
)
}
baseURL
:=
strings
.
TrimSpace
(
input
.
BaseURL
)
if
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
{
normalized
,
err
:=
normalizeBaseURL
(
baseURL
,
s
.
cfg
.
Security
.
URLAllowlist
.
CRSHosts
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
)
if
err
!=
nil
{
return
nil
,
err
}
baseURL
=
normalized
}
else
{
normalized
,
err
:=
urlvalidator
.
ValidateURLFormat
(
baseURL
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
baseURL
=
normalized
}
if
strings
.
TrimSpace
(
input
.
Username
)
==
""
||
strings
.
TrimSpace
(
input
.
Password
)
==
""
{
return
nil
,
errors
.
New
(
"username and password are required"
)
}
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
20
*
time
.
Second
,
ValidateResolvedIP
:
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
,
AllowPrivateHosts
:
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
,
})
if
err
!=
nil
{
client
=
&
http
.
Client
{
Timeout
:
20
*
time
.
Second
}
...
...
@@ -1055,17 +1074,18 @@ func mapCRSStatus(isActive bool, status string) string {
return
"active"
}
func
normalizeBaseURL
(
raw
string
)
(
string
,
error
)
{
trimmed
:=
strings
.
TrimSpace
(
raw
)
if
trimmed
==
""
{
return
""
,
errors
.
New
(
"base_url is required"
)
}
u
,
err
:=
url
.
Parse
(
trimmed
)
if
err
!=
nil
||
u
.
Scheme
==
""
||
u
.
Host
==
""
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %s"
,
trimmed
)
func
normalizeBaseURL
(
raw
string
,
allowlist
[]
string
,
allowPrivate
bool
)
(
string
,
error
)
{
// 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
requireAllowlist
:=
len
(
allowlist
)
>
0
normalized
,
err
:=
urlvalidator
.
ValidateHTTPSURL
(
raw
,
urlvalidator
.
ValidationOptions
{
AllowedHosts
:
allowlist
,
RequireAllowlist
:
requireAllowlist
,
AllowPrivate
:
allowPrivate
,
})
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
u
.
Path
=
strings
.
TrimRight
(
u
.
Path
,
"/"
)
return
strings
.
TrimRight
(
u
.
String
(),
"/"
),
nil
return
normalized
,
nil
}
// cleanBaseURL removes trailing suffix from base_url in credentials
...
...
backend/internal/service/domain_constants.go
View file @
195e227c
...
...
@@ -101,6 +101,10 @@ const (
SettingKeyFallbackModelOpenAI
=
"fallback_model_openai"
SettingKeyFallbackModelGemini
=
"fallback_model_gemini"
SettingKeyFallbackModelAntigravity
=
"fallback_model_antigravity"
// Request identity patch (Claude -> Gemini systemInstruction injection)
SettingKeyEnableIdentityPatch
=
"enable_identity_patch"
SettingKeyIdentityPatchPrompt
=
"identity_patch_prompt"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
...
...
backend/internal/service/gateway_request.go
View file @
195e227c
...
...
@@ -84,25 +84,37 @@ func FilterThinkingBlocks(body []byte) []byte {
return
filterThinkingBlocksInternal
(
body
,
false
)
}
// FilterThinkingBlocksForRetry removes thinking blocks from HISTORICAL messages for retry scenarios.
// This is used when upstream returns signature-related 400 errors.
// FilterThinkingBlocksForRetry strips thinking-related constructs for retry scenarios.
//
// Key insight:
// - User's thinking.type = "enabled" should be PRESERVED (user's intent)
// - Only HISTORICAL assistant messages have thinking blocks with signatures
// - These signatures may be invalid when switching accounts/platforms
// - New responses will generate fresh thinking blocks without signature issues
// Why:
// - Upstreams may reject historical `thinking`/`redacted_thinking` blocks due to invalid/missing signatures.
// - Anthropic extended thinking has a structural constraint: when top-level `thinking` is enabled and the
// final message is an assistant prefill, the assistant content must start with a thinking block.
// - If we remove thinking blocks but keep top-level `thinking` enabled, we can trigger:
// "Expected `thinking` or `redacted_thinking`, but found `text`"
//
// Strategy:
// - Keep thinking.type = "enabled" (preserve user intent)
// - Remove thinking/redacted_thinking blocks from historical assistant messages
// - Ensure no message has empty content after filtering
// Strategy (B: preserve content as text):
// - Disable top-level `thinking` (remove `thinking` field).
// - Convert `thinking` blocks to `text` blocks (preserve the thinking content).
// - Remove `redacted_thinking` blocks (cannot be converted to text).
// - Ensure no message ends up with empty content.
func
FilterThinkingBlocksForRetry
(
body
[]
byte
)
[]
byte
{
// Fast path: check for presence of thinking-related keys in messages
if
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type":"thinking"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type": "thinking"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type":"redacted_thinking"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type": "redacted_thinking"`
))
{
hasThinkingContent
:=
bytes
.
Contains
(
body
,
[]
byte
(
`"type":"thinking"`
))
||
bytes
.
Contains
(
body
,
[]
byte
(
`"type": "thinking"`
))
||
bytes
.
Contains
(
body
,
[]
byte
(
`"type":"redacted_thinking"`
))
||
bytes
.
Contains
(
body
,
[]
byte
(
`"type": "redacted_thinking"`
))
||
bytes
.
Contains
(
body
,
[]
byte
(
`"thinking":`
))
||
bytes
.
Contains
(
body
,
[]
byte
(
`"thinking" :`
))
// Also check for empty content arrays that need fixing.
// Note: This is a heuristic check; the actual empty content handling is done below.
hasEmptyContent
:=
bytes
.
Contains
(
body
,
[]
byte
(
`"content":[]`
))
||
bytes
.
Contains
(
body
,
[]
byte
(
`"content": []`
))
||
bytes
.
Contains
(
body
,
[]
byte
(
`"content" : []`
))
||
bytes
.
Contains
(
body
,
[]
byte
(
`"content" :[]`
))
// Fast path: nothing to process
if
!
hasThinkingContent
&&
!
hasEmptyContent
{
return
body
}
...
...
@@ -111,15 +123,19 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
return
body
}
// DO NOT modify thinking.type - preserve user's intent to use thinking mode
// The issue is with historical message signatures, not the thinking mode itself
modified
:=
false
messages
,
ok
:=
req
[
"messages"
]
.
([]
any
)
if
!
ok
{
return
body
}
modified
:=
false
// Disable top-level thinking mode for retry to avoid structural/signature constraints upstream.
if
_
,
exists
:=
req
[
"thinking"
];
exists
{
delete
(
req
,
"thinking"
)
modified
=
true
}
newMessages
:=
make
([]
any
,
0
,
len
(
messages
))
for
_
,
msg
:=
range
messages
{
...
...
@@ -149,42 +165,237 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
blockType
,
_
:=
blockMap
[
"type"
]
.
(
string
)
// Remove thinking/redacted_thinking blocks from historical messages
// These have signatures that may be invalid across different accounts
if
blockType
==
"thinking"
||
blockType
==
"redacted_thinking"
{
// Convert thinking blocks to text (preserve content) and drop redacted_thinking.
switch
blockType
{
case
"thinking"
:
modifiedThisMsg
=
true
thinkingText
,
_
:=
blockMap
[
"thinking"
]
.
(
string
)
if
thinkingText
==
""
{
continue
}
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
thinkingText
,
})
continue
case
"redacted_thinking"
:
modifiedThisMsg
=
true
continue
}
// Handle blocks without type discriminator but with a "thinking" field.
if
blockType
==
""
{
if
rawThinking
,
hasThinking
:=
blockMap
[
"thinking"
];
hasThinking
{
modifiedThisMsg
=
true
switch
v
:=
rawThinking
.
(
type
)
{
case
string
:
if
v
!=
""
{
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
v
})
}
default
:
if
b
,
err
:=
json
.
Marshal
(
v
);
err
==
nil
&&
len
(
b
)
>
0
{
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
string
(
b
)})
}
}
continue
}
}
newContent
=
append
(
newContent
,
block
)
}
if
modifiedThisMsg
{
modified
=
true
// Handle empty content after filtering
// Handle empty content: either from filtering or originally empty
if
len
(
newContent
)
==
0
{
// For assistant messages, skip entirely (remove from conversation)
// For user messages, add placeholder to avoid empty content error
if
role
==
"user"
{
modified
=
true
placeholder
:=
"(content removed)"
if
role
==
"assistant"
{
placeholder
=
"(assistant content removed)"
}
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
"(content removed)"
,
"text"
:
placeholder
,
})
msgMap
[
"content"
]
=
newContent
}
else
if
modifiedThisMsg
{
modified
=
true
msgMap
[
"content"
]
=
newContent
}
newMessages
=
append
(
newMessages
,
msgMap
)
}
// Skip assistant messages with empty content (don't append)
if
modified
{
req
[
"messages"
]
=
newMessages
}
else
{
// Avoid rewriting JSON when no changes are needed.
return
body
}
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
}
return
newBody
}
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate
// signature/thought_signature validation issues involving tool blocks.
//
// This performs everything in FilterThinkingBlocksForRetry, plus:
// - Convert `tool_use` blocks to text (name/id/input) so we stop sending structured tool calls.
// - Convert `tool_result` blocks to text so we keep tool results visible without tool semantics.
//
// Use this only when needed: converting tool blocks to text changes model behaviour and can increase the
// risk of prompt injection (tool output becomes plain conversation text).
func
FilterSignatureSensitiveBlocksForRetry
(
body
[]
byte
)
[]
byte
{
// Fast path: only run when we see likely relevant constructs.
if
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type":"thinking"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type": "thinking"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type":"redacted_thinking"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type": "redacted_thinking"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type":"tool_use"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type": "tool_use"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type":"tool_result"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"type": "tool_result"`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"thinking":`
))
&&
!
bytes
.
Contains
(
body
,
[]
byte
(
`"thinking" :`
))
{
return
body
}
var
req
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
body
}
modified
:=
false
// Disable top-level thinking for retry to avoid structural/signature constraints upstream.
if
_
,
exists
:=
req
[
"thinking"
];
exists
{
delete
(
req
,
"thinking"
)
modified
=
true
}
messages
,
ok
:=
req
[
"messages"
]
.
([]
any
)
if
!
ok
{
return
body
}
newMessages
:=
make
([]
any
,
0
,
len
(
messages
))
for
_
,
msg
:=
range
messages
{
msgMap
,
ok
:=
msg
.
(
map
[
string
]
any
)
if
!
ok
{
newMessages
=
append
(
newMessages
,
msg
)
continue
}
role
,
_
:=
msgMap
[
"role"
]
.
(
string
)
content
,
ok
:=
msgMap
[
"content"
]
.
([]
any
)
if
!
ok
{
newMessages
=
append
(
newMessages
,
msg
)
continue
}
newContent
:=
make
([]
any
,
0
,
len
(
content
))
modifiedThisMsg
:=
false
for
_
,
block
:=
range
content
{
blockMap
,
ok
:=
block
.
(
map
[
string
]
any
)
if
!
ok
{
newContent
=
append
(
newContent
,
block
)
continue
}
blockType
,
_
:=
blockMap
[
"type"
]
.
(
string
)
switch
blockType
{
case
"thinking"
:
modifiedThisMsg
=
true
thinkingText
,
_
:=
blockMap
[
"thinking"
]
.
(
string
)
if
thinkingText
==
""
{
continue
}
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
thinkingText
})
continue
case
"redacted_thinking"
:
modifiedThisMsg
=
true
continue
case
"tool_use"
:
modifiedThisMsg
=
true
name
,
_
:=
blockMap
[
"name"
]
.
(
string
)
id
,
_
:=
blockMap
[
"id"
]
.
(
string
)
input
:=
blockMap
[
"input"
]
inputJSON
,
_
:=
json
.
Marshal
(
input
)
text
:=
"(tool_use)"
if
name
!=
""
{
text
+=
" name="
+
name
}
if
id
!=
""
{
text
+=
" id="
+
id
}
if
len
(
inputJSON
)
>
0
&&
string
(
inputJSON
)
!=
"null"
{
text
+=
" input="
+
string
(
inputJSON
)
}
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
text
})
continue
case
"tool_result"
:
modifiedThisMsg
=
true
toolUseID
,
_
:=
blockMap
[
"tool_use_id"
]
.
(
string
)
isError
,
_
:=
blockMap
[
"is_error"
]
.
(
bool
)
content
:=
blockMap
[
"content"
]
contentJSON
,
_
:=
json
.
Marshal
(
content
)
text
:=
"(tool_result)"
if
toolUseID
!=
""
{
text
+=
" tool_use_id="
+
toolUseID
}
if
isError
{
text
+=
" is_error=true"
}
if
len
(
contentJSON
)
>
0
&&
string
(
contentJSON
)
!=
"null"
{
text
+=
"
\n
"
+
string
(
contentJSON
)
}
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
text
})
continue
}
if
blockType
==
""
{
if
rawThinking
,
hasThinking
:=
blockMap
[
"thinking"
];
hasThinking
{
modifiedThisMsg
=
true
switch
v
:=
rawThinking
.
(
type
)
{
case
string
:
if
v
!=
""
{
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
v
})
}
default
:
if
b
,
err
:=
json
.
Marshal
(
v
);
err
==
nil
&&
len
(
b
)
>
0
{
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
string
(
b
)})
}
}
continue
}
}
newContent
=
append
(
newContent
,
block
)
}
if
modifiedThisMsg
{
modified
=
true
if
len
(
newContent
)
==
0
{
placeholder
:=
"(content removed)"
if
role
==
"assistant"
{
placeholder
=
"(assistant content removed)"
}
newContent
=
append
(
newContent
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
placeholder
})
}
msgMap
[
"content"
]
=
newContent
}
newMessages
=
append
(
newMessages
,
msgMap
)
}
if
modified
{
re
q
[
"messages"
]
=
newMessages
if
!
modified
{
re
turn
body
}
req
[
"messages"
]
=
newMessages
newBody
,
err
:=
json
.
Marshal
(
req
)
if
err
!=
nil
{
return
body
...
...
backend/internal/service/gateway_request_test.go
View file @
195e227c
...
...
@@ -151,3 +151,148 @@ func TestFilterThinkingBlocks(t *testing.T) {
})
}
}
// Verifies that FilterThinkingBlocksForRetry removes the top-level "thinking"
// field and converts a historical assistant "thinking" block into a plain
// "text" block, preserving the thinking content verbatim.
func TestFilterThinkingBlocksForRetry_DisablesThinkingAndPreservesAsText(t *testing.T) {
	input := []byte(`{
		"model":"claude-3-5-sonnet-20241022",
		"thinking":{"type":"enabled","budget_tokens":1024},
		"messages":[
			{"role":"user","content":[{"type":"text","text":"Hi"}]},
			{"role":"assistant","content":[
				{"type":"thinking","thinking":"Let me think...","signature":"bad_sig"},
				{"type":"text","text":"Answer"}
			]}
		]
	}`)

	out := FilterThinkingBlocksForRetry(input)

	var req map[string]any
	require.NoError(t, json.Unmarshal(out, &req))

	// Top-level thinking must be removed for the retry.
	_, hasThinking := req["thinking"]
	require.False(t, hasThinking)

	// Both messages survive the rewrite.
	msgs, ok := req["messages"].([]any)
	require.True(t, ok)
	require.Len(t, msgs, 2)

	assistant, ok := msgs[1].(map[string]any)
	require.True(t, ok)
	content, ok := assistant["content"].([]any)
	require.True(t, ok)
	require.Len(t, content, 2)

	// The thinking block is downgraded to a text block with the same content.
	first, ok := content[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "text", first["type"])
	require.Equal(t, "Let me think...", first["text"])
}
// Verifies that the top-level "thinking" field is removed on retry even when
// no message actually contains thinking blocks (avoids the upstream
// structural constraint on assistant prefills when thinking is enabled).
func TestFilterThinkingBlocksForRetry_DisablesThinkingEvenWithoutThinkingBlocks(t *testing.T) {
	input := []byte(`{
		"model":"claude-3-5-sonnet-20241022",
		"thinking":{"type":"enabled","budget_tokens":1024},
		"messages":[
			{"role":"user","content":[{"type":"text","text":"Hi"}]},
			{"role":"assistant","content":[{"type":"text","text":"Prefill"}]}
		]
	}`)

	out := FilterThinkingBlocksForRetry(input)

	var req map[string]any
	require.NoError(t, json.Unmarshal(out, &req))

	// The thinking field must be gone from the rewritten request.
	_, hasThinking := req["thinking"]
	require.False(t, hasThinking)
}
// Verifies that redacted_thinking blocks are dropped entirely (they cannot be
// converted to text) while sibling non-thinking blocks are preserved.
func TestFilterThinkingBlocksForRetry_RemovesRedactedThinkingAndKeepsValidContent(t *testing.T) {
	input := []byte(`{
		"thinking":{"type":"enabled","budget_tokens":1024},
		"messages":[
			{"role":"assistant","content":[
				{"type":"redacted_thinking","data":"..."},
				{"type":"text","text":"Visible"}
			]}
		]
	}`)

	out := FilterThinkingBlocksForRetry(input)

	var req map[string]any
	require.NoError(t, json.Unmarshal(out, &req))

	// Top-level thinking is removed.
	_, hasThinking := req["thinking"]
	require.False(t, hasThinking)

	msgs, ok := req["messages"].([]any)
	require.True(t, ok)
	msg0, ok := msgs[0].(map[string]any)
	require.True(t, ok)
	content, ok := msg0["content"].([]any)
	require.True(t, ok)

	// Only the visible text block survives.
	require.Len(t, content, 1)
	content0, ok := content[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "text", content0["type"])
	require.Equal(t, "Visible", content0["text"])
}
// Verifies that a message whose content becomes empty after filtering (all
// blocks were redacted_thinking) receives a non-empty placeholder text block
// instead of an invalid empty content array.
func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T) {
	input := []byte(`{
		"thinking":{"type":"enabled"},
		"messages":[
			{"role":"assistant","content":[{"type":"redacted_thinking","data":"..."}]}
		]
	}`)

	out := FilterThinkingBlocksForRetry(input)

	var req map[string]any
	require.NoError(t, json.Unmarshal(out, &req))

	msgs, ok := req["messages"].([]any)
	require.True(t, ok)
	msg0, ok := msgs[0].(map[string]any)
	require.True(t, ok)
	content, ok := msg0["content"].([]any)
	require.True(t, ok)

	// Exactly one placeholder text block replaces the emptied content.
	require.Len(t, content, 1)
	content0, ok := content[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "text", content0["type"])
	require.NotEmpty(t, content0["text"])
}
// Verifies the stronger retry filter: tool_use and tool_result blocks are
// downgraded to text blocks (their identifiers noted in the text), and the
// top-level thinking field is removed.
func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
	input := []byte(`{
		"thinking":{"type":"enabled","budget_tokens":1024},
		"messages":[
			{"role":"assistant","content":[
				{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}},
				{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false}
			]}
		]
	}`)

	out := FilterSignatureSensitiveBlocksForRetry(input)

	var req map[string]any
	require.NoError(t, json.Unmarshal(out, &req))

	// Top-level thinking must be removed.
	_, hasThinking := req["thinking"]
	require.False(t, hasThinking)

	msgs, ok := req["messages"].([]any)
	require.True(t, ok)
	msg0, ok := msgs[0].(map[string]any)
	require.True(t, ok)
	content, ok := msg0["content"].([]any)
	require.True(t, ok)
	require.Len(t, content, 2)

	content0, ok := content[0].(map[string]any)
	require.True(t, ok)
	content1, ok := content[1].(map[string]any)
	require.True(t, ok)

	// Both tool blocks are now plain text mentioning their original type.
	require.Equal(t, "text", content0["type"])
	require.Equal(t, "text", content1["type"])
	require.Contains(t, content0["text"], "tool_use")
	require.Contains(t, content1["text"], "tool_result")
}
backend/internal/service/gateway_service.go
View file @
195e227c
...
...
@@ -15,11 +15,14 @@ import (
"regexp"
"sort"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
...
...
@@ -30,6 +33,7 @@ const (
claudeAPIURL
=
"https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL
=
"https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
defaultMaxLineSize
=
10
*
1024
*
1024
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
)
...
...
@@ -933,8 +937,16 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (s
// Retry tuning for upstream requests.
const (
	// maxRetryAttempts is the maximum number of attempts, including the
	// initial request. Excessive retries cause request pile-up and resource
	// exhaustion.
	maxRetryAttempts = 5

	// Exponential backoff: the wait after the Nth failure is
	// retryBaseDelay * 2^(N-1), capped at retryMaxDelay.
	retryBaseDelay = 300 * time.Millisecond
	retryMaxDelay  = 3 * time.Second

	// maxRetryElapsed bounds the total retry time (request time plus backoff
	// waits), preventing goroutine pile-up in extreme cases.
	maxRetryElapsed = 10 * time.Second
)
func
(
s
*
GatewayService
)
shouldRetryUpstreamError
(
account
*
Account
,
statusCode
int
)
bool
{
...
...
@@ -957,6 +969,40 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
}
}
func
retryBackoffDelay
(
attempt
int
)
time
.
Duration
{
// attempt 从 1 开始,表示第 attempt 次请求刚失败,需要等待后进行第 attempt+1 次请求。
if
attempt
<=
0
{
return
retryBaseDelay
}
delay
:=
retryBaseDelay
*
time
.
Duration
(
1
<<
(
attempt
-
1
))
if
delay
>
retryMaxDelay
{
return
retryMaxDelay
}
return
delay
}
func
sleepWithContext
(
ctx
context
.
Context
,
d
time
.
Duration
)
error
{
if
d
<=
0
{
return
nil
}
timer
:=
time
.
NewTimer
(
d
)
defer
func
()
{
if
!
timer
.
Stop
()
{
select
{
case
<-
timer
.
C
:
default
:
}
}
}()
select
{
case
<-
ctx
.
Done
()
:
return
ctx
.
Err
()
case
<-
timer
.
C
:
return
nil
}
}
// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
// 简化判断:User-Agent 匹配 + metadata.user_id 存在
func
isClaudeCodeClient
(
userAgent
string
,
metadataUserID
string
)
bool
{
...
...
@@ -1073,7 +1119,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 重试循环
var
resp
*
http
.
Response
for
attempt
:=
1
;
attempt
<=
maxRetries
;
attempt
++
{
retryStart
:=
time
.
Now
()
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
)
if
err
!=
nil
{
...
...
@@ -1083,6 +1130,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 发送请求
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
err
!=
nil
{
if
resp
!=
nil
&&
resp
.
Body
!=
nil
{
_
=
resp
.
Body
.
Close
()
}
return
nil
,
fmt
.
Errorf
(
"upstream request failed: %w"
,
err
)
}
...
...
@@ -1093,28 +1143,80 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_
=
resp
.
Body
.
Close
()
if
s
.
isThinkingBlockSignatureError
(
respBody
)
{
looksLikeToolSignatureError
:=
func
(
msg
string
)
bool
{
m
:=
strings
.
ToLower
(
msg
)
return
strings
.
Contains
(
m
,
"tool_use"
)
||
strings
.
Contains
(
m
,
"tool_result"
)
||
strings
.
Contains
(
m
,
"functioncall"
)
||
strings
.
Contains
(
m
,
"function_call"
)
||
strings
.
Contains
(
m
,
"functionresponse"
)
||
strings
.
Contains
(
m
,
"function_response"
)
}
// 避免在重试预算已耗尽时再发起额外请求
if
time
.
Since
(
retryStart
)
>=
maxRetryElapsed
{
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
break
}
log
.
Printf
(
"Account %d: detected thinking block signature error, retrying with filtered thinking blocks"
,
account
.
ID
)
// 过滤thinking blocks并重试(使用更激进的过滤)
// Conservative two-stage fallback:
// 1) Disable thinking + thinking->text (preserve content)
// 2) Only if upstream still errors AND error message points to tool/function signature issues:
// also downgrade tool_use/tool_result blocks to text.
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
)
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
retryErr
==
nil
{
// 使用重试后的响应,继续后续处理
if
retryResp
.
StatusCode
<
400
{
log
.
Printf
(
"Account %d: signature error retry succeeded"
,
account
.
ID
)
log
.
Printf
(
"Account %d: signature error retry succeeded (thinking downgraded)"
,
account
.
ID
)
resp
=
retryResp
break
}
retryRespBody
,
retryReadErr
:=
io
.
ReadAll
(
io
.
LimitReader
(
retryResp
.
Body
,
2
<<
20
))
_
=
retryResp
.
Body
.
Close
()
if
retryReadErr
==
nil
&&
retryResp
.
StatusCode
==
400
&&
s
.
isThinkingBlockSignatureError
(
retryRespBody
)
{
msg2
:=
extractUpstreamErrorMessage
(
retryRespBody
)
if
looksLikeToolSignatureError
(
msg2
)
&&
time
.
Since
(
retryStart
)
<
maxRetryElapsed
{
log
.
Printf
(
"Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded"
,
account
.
ID
)
filteredBody2
:=
FilterSignatureSensitiveBlocksForRetry
(
body
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
)
if
buildErr2
==
nil
{
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
Do
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
retryErr2
==
nil
{
resp
=
retryResp2
break
}
if
retryResp2
!=
nil
&&
retryResp2
.
Body
!=
nil
{
_
=
retryResp2
.
Body
.
Close
()
}
log
.
Printf
(
"Account %d: tool-downgrade signature retry failed: %v"
,
account
.
ID
,
retryErr2
)
}
else
{
log
.
Printf
(
"Account %d: signature error retry returned status %d"
,
account
.
ID
,
retryResp
.
StatusCode
)
log
.
Printf
(
"Account %d: tool-downgrade signature retry build failed: %v"
,
account
.
ID
,
buildErr2
)
}
}
}
// Fall back to the original retry response context.
resp
=
&
http
.
Response
{
StatusCode
:
retryResp
.
StatusCode
,
Header
:
retryResp
.
Header
.
Clone
(),
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
retryRespBody
)),
}
resp
=
retryResp
break
}
if
retryResp
!=
nil
&&
retryResp
.
Body
!=
nil
{
_
=
retryResp
.
Body
.
Close
()
}
log
.
Printf
(
"Account %d: signature error retry failed: %v"
,
account
.
ID
,
retryErr
)
}
else
{
log
.
Printf
(
"Account %d: signature error retry build request failed: %v"
,
account
.
ID
,
buildErr
)
}
// 重试失败,恢复原始响应体继续处理
// Retry failed: restore original response body and continue handling.
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
break
}
...
...
@@ -1125,11 +1227,27 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了)
if
resp
.
StatusCode
>=
400
&&
resp
.
StatusCode
!=
400
&&
s
.
shouldRetryUpstreamError
(
account
,
resp
.
StatusCode
)
{
if
attempt
<
maxRetries
{
log
.
Printf
(
"Account %d: upstream error %d, retry %d/%d after %v"
,
account
.
ID
,
resp
.
StatusCode
,
attempt
,
maxRetries
,
retryDelay
)
if
attempt
<
maxRetryAttempts
{
elapsed
:=
time
.
Since
(
retryStart
)
if
elapsed
>=
maxRetryElapsed
{
break
}
delay
:=
retryBackoffDelay
(
attempt
)
remaining
:=
maxRetryElapsed
-
elapsed
if
delay
>
remaining
{
delay
=
remaining
}
if
delay
<=
0
{
break
}
log
.
Printf
(
"Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)"
,
account
.
ID
,
resp
.
StatusCode
,
attempt
,
maxRetryAttempts
,
delay
,
elapsed
,
maxRetryElapsed
)
_
=
resp
.
Body
.
Close
()
time
.
Sleep
(
retryDelay
)
if
err
:=
sleepWithContext
(
ctx
,
delay
);
err
!=
nil
{
return
nil
,
err
}
continue
}
// 最后一次尝试也失败,跳出循环处理重试耗尽
...
...
@@ -1146,6 +1264,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
break
}
if
resp
==
nil
||
resp
.
Body
==
nil
{
return
nil
,
errors
.
New
(
"upstream request failed: empty response"
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
// 处理重试耗尽的情况
...
...
@@ -1229,7 +1350,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL
:=
claudeAPIURL
if
account
.
Type
==
AccountTypeAPIKey
{
baseURL
:=
account
.
GetBaseURL
()
targetURL
=
baseURL
+
"/v1/messages"
if
baseURL
!=
""
{
validatedURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
targetURL
=
validatedURL
+
"/v1/messages"
}
}
// OAuth账号:应用统一指纹
...
...
@@ -1537,10 +1664,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
// OAuth/Setup Token 账号的 403:标记账号异常
if
account
.
IsOAuth
()
&&
statusCode
==
403
{
s
.
rateLimitService
.
HandleUpstreamError
(
ctx
,
account
,
statusCode
,
resp
.
Header
,
body
)
log
.
Printf
(
"Account %d: marked as error after %d retries for status %d"
,
account
.
ID
,
maxRetr
ie
s
,
statusCode
)
log
.
Printf
(
"Account %d: marked as error after %d retries for status %d"
,
account
.
ID
,
maxRetr
yAttempt
s
,
statusCode
)
}
else
{
// API Key 未配置错误码:不标记账号状态
log
.
Printf
(
"Account %d: upstream error %d after %d retries (not marking account)"
,
account
.
ID
,
statusCode
,
maxRetr
ie
s
)
log
.
Printf
(
"Account %d: upstream error %d after %d retries (not marking account)"
,
account
.
ID
,
statusCode
,
maxRetr
yAttempt
s
)
}
}
...
...
@@ -1577,6 +1704,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// 更新5h窗口状态
s
.
rateLimitService
.
UpdateSessionWindow
(
ctx
,
account
,
resp
.
Header
)
if
s
.
cfg
!=
nil
{
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
}
// 设置SSE响应头
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
...
...
@@ -1598,12 +1729,87 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
var
firstTokenMs
*
int
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
// 设置更大的buffer以处理长行
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
1024
*
1024
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
type
scanEvent
struct
{
line
string
err
error
}
// 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理
events
:=
make
(
chan
scanEvent
,
16
)
done
:=
make
(
chan
struct
{})
sendEvent
:=
func
(
ev
scanEvent
)
bool
{
select
{
case
events
<-
ev
:
return
true
case
<-
done
:
return
false
}
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
if
!
sendEvent
(
scanEvent
{
line
:
scanner
.
Text
()})
{
return
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
defer
close
(
done
)
streamInterval
:=
time
.
Duration
(
0
)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
>
0
{
streamInterval
=
time
.
Duration
(
s
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
)
*
time
.
Second
}
// 仅监控上游数据间隔超时,避免下游写入阻塞导致误判
var
intervalTicker
*
time
.
Ticker
if
streamInterval
>
0
{
intervalTicker
=
time
.
NewTicker
(
streamInterval
)
defer
intervalTicker
.
Stop
()
}
var
intervalCh
<-
chan
time
.
Time
if
intervalTicker
!=
nil
{
intervalCh
=
intervalTicker
.
C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
return
}
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
flusher
.
Flush
()
}
needModelReplace
:=
originalModel
!=
mappedModel
for
scanner
.
Scan
()
{
line
:=
scanner
.
Text
()
for
{
select
{
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
if
ev
.
err
!=
nil
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
ev
.
err
}
sendErrorEvent
(
"stream_read_error"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
line
:=
ev
.
line
if
line
==
"event: error"
{
return
nil
,
errors
.
New
(
"have error in stream"
)
}
...
...
@@ -1619,6 +1825,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// 转发行
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
...
...
@@ -1632,17 +1839,23 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
else
{
// 非 data 行直接转发
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
err
)
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
}
log
.
Printf
(
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
sendErrorEvent
(
"stream_timeout"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
}
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
// replaceModelInSSELine 替换SSE数据行中的model字段
...
...
@@ -1747,15 +1960,17 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
}
// 透传响应头
for
key
,
values
:=
range
resp
.
Header
{
for
_
,
value
:=
range
values
{
c
.
Header
(
key
,
value
)
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
contentType
:=
"application/json"
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
Security
.
ResponseHeaders
.
Enabled
{
if
upstreamType
:=
resp
.
Header
.
Get
(
"Content-Type"
);
upstreamType
!=
""
{
contentType
=
upstreamType
}
}
// 写入响应
c
.
Data
(
resp
.
StatusCode
,
"application/json"
,
body
)
c
.
Data
(
resp
.
StatusCode
,
contentType
,
body
)
return
&
response
.
Usage
,
nil
}
...
...
@@ -1989,7 +2204,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
if
resp
.
StatusCode
==
400
&&
s
.
isThinkingBlockSignatureError
(
respBody
)
{
log
.
Printf
(
"Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks"
,
account
.
ID
)
filteredBody
:=
FilterThinkingBlocks
(
body
)
filteredBody
:=
FilterThinkingBlocks
ForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
)
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
...
...
@@ -2045,7 +2260,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
targetURL
:=
claudeAPICountTokensURL
if
account
.
Type
==
AccountTypeAPIKey
{
baseURL
:=
account
.
GetBaseURL
()
targetURL
=
baseURL
+
"/v1/messages/count_tokens"
if
baseURL
!=
""
{
validatedURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
targetURL
=
validatedURL
+
"/v1/messages/count_tokens"
}
}
// OAuth 账号:应用统一指纹和重写 userID
...
...
@@ -2125,6 +2346,25 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
})
}
func
(
s
*
GatewayService
)
validateUpstreamBaseURL
(
raw
string
)
(
string
,
error
)
{
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
{
normalized
,
err
:=
urlvalidator
.
ValidateURLFormat
(
raw
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
return
normalized
,
nil
}
normalized
,
err
:=
urlvalidator
.
ValidateHTTPSURL
(
raw
,
urlvalidator
.
ValidationOptions
{
AllowedHosts
:
s
.
cfg
.
Security
.
URLAllowlist
.
UpstreamHosts
,
RequireAllowlist
:
true
,
AllowPrivate
:
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
,
})
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
return
normalized
,
nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func
(
s
*
GatewayService
)
GetAvailableModels
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
)
[]
string
{
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
195e227c
...
...
@@ -18,9 +18,12 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
...
...
@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
rateLimitService
*
RateLimitService
httpUpstream
HTTPUpstream
antigravityGatewayService
*
AntigravityGatewayService
cfg
*
config
.
Config
}
func
NewGeminiMessagesCompatService
(
...
...
@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
rateLimitService
*
RateLimitService
,
httpUpstream
HTTPUpstream
,
antigravityGatewayService
*
AntigravityGatewayService
,
cfg
*
config
.
Config
,
)
*
GeminiMessagesCompatService
{
return
&
GeminiMessagesCompatService
{
accountRepo
:
accountRepo
,
...
...
@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
rateLimitService
:
rateLimitService
,
httpUpstream
:
httpUpstream
,
antigravityGatewayService
:
antigravityGatewayService
,
cfg
:
cfg
,
}
}
...
...
@@ -230,6 +236,25 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
return
s
.
antigravityGatewayService
}
func
(
s
*
GeminiMessagesCompatService
)
validateUpstreamBaseURL
(
raw
string
)
(
string
,
error
)
{
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
{
normalized
,
err
:=
urlvalidator
.
ValidateURLFormat
(
raw
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
return
normalized
,
nil
}
normalized
,
err
:=
urlvalidator
.
ValidateHTTPSURL
(
raw
,
urlvalidator
.
ValidationOptions
{
AllowedHosts
:
s
.
cfg
.
Security
.
URLAllowlist
.
UpstreamHosts
,
RequireAllowlist
:
true
,
AllowPrivate
:
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
,
})
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
return
normalized
,
nil
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func
(
s
*
GeminiMessagesCompatService
)
HasAntigravityAccounts
(
ctx
context
.
Context
,
groupID
*
int64
)
(
bool
,
error
)
{
var
accounts
[]
Account
...
...
@@ -359,6 +384,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if
err
!=
nil
{
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
err
.
Error
())
}
originalClaudeBody
:=
body
proxyURL
:=
""
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
...
...
@@ -381,16 +407,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return
nil
,
""
,
errors
.
New
(
"gemini api_key not configured"
)
}
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
action
:=
"generateContent"
if
req
.
Stream
{
action
=
"streamGenerateContent"
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
b
aseURL
,
"/"
),
mappedModel
,
action
)
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
normalizedB
aseURL
,
"/"
),
mappedModel
,
action
)
if
req
.
Stream
{
fullURL
+=
"?alt=sse"
}
...
...
@@ -427,7 +457,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if
projectID
!=
""
{
// Mode 1: Code Assist API
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
geminicli
.
GeminiCliBaseURL
,
action
)
baseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
geminicli
.
GeminiCliBaseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
strings
.
TrimRight
(
baseURL
,
"/"
),
action
)
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
}
...
...
@@ -453,12 +487,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return
upstreamReq
,
"x-request-id"
,
nil
}
else
{
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
baseURL
,
mappedModel
,
action
)
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
normalizedBaseURL
,
"/"
)
,
mappedModel
,
action
)
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
}
...
...
@@ -479,6 +517,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
var
resp
*
http
.
Response
signatureRetryStage
:=
0
for
attempt
:=
1
;
attempt
<=
geminiMaxRetries
;
attempt
++
{
upstreamReq
,
idHeader
,
err
:=
buildReq
(
ctx
)
if
err
!=
nil
{
...
...
@@ -503,6 +542,46 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return
nil
,
s
.
writeClaudeError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed after retries: "
+
sanitizeUpstreamErrorMessage
(
err
.
Error
()))
}
// Special-case: signature/thought_signature validation errors are not transient, but may be fixed by
// downgrading Claude thinking/tool history to plain text (conservative two-stage retry).
if
resp
.
StatusCode
==
http
.
StatusBadRequest
&&
signatureRetryStage
<
2
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
if
isGeminiSignatureRelatedError
(
respBody
)
{
var
strippedClaudeBody
[]
byte
stageName
:=
""
switch
signatureRetryStage
{
case
0
:
// Stage 1: disable thinking + thinking->text
strippedClaudeBody
=
FilterThinkingBlocksForRetry
(
originalClaudeBody
)
stageName
=
"thinking-only"
signatureRetryStage
=
1
default
:
// Stage 2: additionally downgrade tool_use/tool_result blocks to text
strippedClaudeBody
=
FilterSignatureSensitiveBlocksForRetry
(
originalClaudeBody
)
stageName
=
"thinking+tools"
signatureRetryStage
=
2
}
retryGeminiReq
,
txErr
:=
convertClaudeMessagesToGeminiGenerateContent
(
strippedClaudeBody
)
if
txErr
==
nil
{
log
.
Printf
(
"Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)"
,
account
.
ID
,
stageName
)
geminiReq
=
retryGeminiReq
// Consume one retry budget attempt and continue with the updated request payload.
sleepGeminiBackoff
(
1
)
continue
}
}
// Restore body for downstream error handling.
resp
=
&
http
.
Response
{
StatusCode
:
http
.
StatusBadRequest
,
Header
:
resp
.
Header
.
Clone
(),
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
}
break
}
if
resp
.
StatusCode
>=
400
&&
s
.
shouldRetryGeminiUpstreamError
(
account
,
resp
.
StatusCode
)
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
_
=
resp
.
Body
.
Close
()
...
...
@@ -600,6 +679,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
},
nil
}
func
isGeminiSignatureRelatedError
(
respBody
[]
byte
)
bool
{
msg
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
)))
if
msg
==
""
{
msg
=
strings
.
ToLower
(
string
(
respBody
))
}
return
strings
.
Contains
(
msg
,
"thought_signature"
)
||
strings
.
Contains
(
msg
,
"signature"
)
}
func
(
s
*
GeminiMessagesCompatService
)
ForwardNative
(
ctx
context
.
Context
,
c
*
gin
.
Context
,
account
*
Account
,
originalModel
string
,
action
string
,
stream
bool
,
body
[]
byte
)
(
*
ForwardResult
,
error
)
{
startTime
:=
time
.
Now
()
...
...
@@ -650,12 +737,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
nil
,
""
,
errors
.
New
(
"gemini api_key not configured"
)
}
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
b
aseURL
,
"/"
),
mappedModel
,
upstreamAction
)
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
normalizedB
aseURL
,
"/"
),
mappedModel
,
upstreamAction
)
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
}
...
...
@@ -687,7 +778,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if
projectID
!=
""
&&
!
forceAIStudio
{
// Mode 1: Code Assist API
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
geminicli
.
GeminiCliBaseURL
,
upstreamAction
)
baseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
geminicli
.
GeminiCliBaseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
strings
.
TrimRight
(
baseURL
,
"/"
),
upstreamAction
)
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
}
...
...
@@ -713,12 +808,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
upstreamReq
,
"x-request-id"
,
nil
}
else
{
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
baseURL
,
mappedModel
,
upstreamAction
)
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
normalizedBaseURL
,
"/"
)
,
mappedModel
,
upstreamAction
)
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
}
...
...
@@ -1652,6 +1751,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
_
=
json
.
Unmarshal
(
respBody
,
&
parsed
)
}
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
contentType
:=
resp
.
Header
.
Get
(
"Content-Type"
)
if
contentType
==
""
{
contentType
=
"application/json"
...
...
@@ -1676,6 +1777,10 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
}
log
.
Printf
(
"[GeminiAPI] ===================================================="
)
if
s
.
cfg
!=
nil
{
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
}
c
.
Status
(
resp
.
StatusCode
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
c
.
Header
(
"Connection"
,
"keep-alive"
)
...
...
@@ -1773,11 +1878,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return
nil
,
errors
.
New
(
"invalid path"
)
}
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
}
fullURL
:=
strings
.
TrimRight
(
baseURL
,
"/"
)
+
path
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
fullURL
:=
strings
.
TrimRight
(
normalizedBaseURL
,
"/"
)
+
path
var
proxyURL
string
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
...
...
@@ -1816,9 +1925,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
8
<<
20
))
wwwAuthenticate
:=
resp
.
Header
.
Get
(
"Www-Authenticate"
)
filteredHeaders
:=
responseheaders
.
FilterHeaders
(
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
if
wwwAuthenticate
!=
""
{
filteredHeaders
.
Set
(
"Www-Authenticate"
,
wwwAuthenticate
)
}
return
&
UpstreamHTTPResult
{
StatusCode
:
resp
.
StatusCode
,
Headers
:
resp
.
Header
.
Clone
()
,
Headers
:
filteredHeaders
,
Body
:
body
,
},
nil
}
...
...
backend/internal/service/gemini_oauth_service.go
View file @
195e227c
...
...
@@ -1002,6 +1002,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
ProxyURL
:
strings
.
TrimSpace
(
proxyURL
),
Timeout
:
30
*
time
.
Second
,
ValidateResolvedIP
:
true
,
})
if
err
!=
nil
{
client
=
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
}
...
...
backend/internal/service/openai_gateway_service.go
View file @
195e227c
...
...
@@ -16,9 +16,12 @@ import (
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
)
...
...
@@ -630,10 +633,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case
AccountTypeAPIKey
:
// API Key accounts use Platform API or custom base URL
baseURL
:=
account
.
GetOpenAIBaseURL
()
if
baseURL
!=
""
{
targetURL
=
baseURL
+
"/responses"
}
else
{
if
baseURL
==
""
{
targetURL
=
openaiPlatformAPIURL
}
else
{
validatedURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
targetURL
=
validatedURL
+
"/responses"
}
default
:
targetURL
=
openaiPlatformAPIURL
...
...
@@ -755,6 +762,10 @@ type openaiStreamingResult struct {
}
func
(
s
*
OpenAIGatewayService
)
handleStreamingResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
,
startTime
time
.
Time
,
originalModel
,
mappedModel
string
)
(
*
openaiStreamingResult
,
error
)
{
if
s
.
cfg
!=
nil
{
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
}
// Set SSE response headers
c
.
Header
(
"Content-Type"
,
"text/event-stream"
)
c
.
Header
(
"Cache-Control"
,
"no-cache"
)
...
...
@@ -775,12 +786,106 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
usage
:=
&
OpenAIUsage
{}
var
firstTokenMs
*
int
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
1024
*
1024
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
type
scanEvent
struct
{
line
string
err
error
}
// 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理
events
:=
make
(
chan
scanEvent
,
16
)
done
:=
make
(
chan
struct
{})
sendEvent
:=
func
(
ev
scanEvent
)
bool
{
select
{
case
events
<-
ev
:
return
true
case
<-
done
:
return
false
}
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
if
!
sendEvent
(
scanEvent
{
line
:
scanner
.
Text
()})
{
return
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
defer
close
(
done
)
streamInterval
:=
time
.
Duration
(
0
)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
>
0
{
streamInterval
=
time
.
Duration
(
s
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
)
*
time
.
Second
}
// 仅监控上游数据间隔超时,不被下游写入阻塞影响
var
intervalTicker
*
time
.
Ticker
if
streamInterval
>
0
{
intervalTicker
=
time
.
NewTicker
(
streamInterval
)
defer
intervalTicker
.
Stop
()
}
var
intervalCh
<-
chan
time
.
Time
if
intervalTicker
!=
nil
{
intervalCh
=
intervalTicker
.
C
}
keepaliveInterval
:=
time
.
Duration
(
0
)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
StreamKeepaliveInterval
>
0
{
keepaliveInterval
=
time
.
Duration
(
s
.
cfg
.
Gateway
.
StreamKeepaliveInterval
)
*
time
.
Second
}
// 下游 keepalive 仅用于防止代理空闲断开
var
keepaliveTicker
*
time
.
Ticker
if
keepaliveInterval
>
0
{
keepaliveTicker
=
time
.
NewTicker
(
keepaliveInterval
)
defer
keepaliveTicker
.
Stop
()
}
var
keepaliveCh
<-
chan
time
.
Time
if
keepaliveTicker
!=
nil
{
keepaliveCh
=
keepaliveTicker
.
C
}
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt
:=
time
.
Now
()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
return
}
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
flusher
.
Flush
()
}
needModelReplace
:=
originalModel
!=
mappedModel
for
scanner
.
Scan
()
{
line
:=
scanner
.
Text
()
for
{
select
{
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
if
ev
.
err
!=
nil
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
ev
.
err
}
sendErrorEvent
(
"stream_read_error"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
line
:=
ev
.
line
lastDataAt
=
time
.
Now
()
// Extract data from SSE line (supports both "data: " and "data:" formats)
if
openaiSSEDataRe
.
MatchString
(
line
)
{
...
...
@@ -793,6 +898,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// Forward line
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
...
...
@@ -806,17 +912,32 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
else
{
// Forward non-data lines as-is
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
}
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
}
log
.
Printf
(
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
sendErrorEvent
(
"stream_timeout"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
if
err
:=
scanner
.
Err
();
err
!=
nil
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
err
)
case
<-
keepaliveCh
:
if
time
.
Since
(
lastDataAt
)
<
keepaliveInterval
{
continue
}
if
_
,
err
:=
fmt
.
Fprint
(
w
,
":
\n\n
"
);
err
!=
nil
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
}
}
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
func
(
s
*
OpenAIGatewayService
)
replaceModelInSSELine
(
line
,
fromModel
,
toModel
string
)
string
{
...
...
@@ -911,18 +1032,39 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
}
// Pass through headers
for
key
,
values
:=
range
resp
.
Header
{
for
_
,
value
:=
range
values
{
c
.
Header
(
key
,
value
)
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
contentType
:=
"application/json"
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
Security
.
ResponseHeaders
.
Enabled
{
if
upstreamType
:=
resp
.
Header
.
Get
(
"Content-Type"
);
upstreamType
!=
""
{
contentType
=
upstreamType
}
}
c
.
Data
(
resp
.
StatusCode
,
"application/json"
,
body
)
c
.
Data
(
resp
.
StatusCode
,
contentType
,
body
)
return
usage
,
nil
}
func
(
s
*
OpenAIGatewayService
)
validateUpstreamBaseURL
(
raw
string
)
(
string
,
error
)
{
if
s
.
cfg
!=
nil
&&
!
s
.
cfg
.
Security
.
URLAllowlist
.
Enabled
{
normalized
,
err
:=
urlvalidator
.
ValidateURLFormat
(
raw
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowInsecureHTTP
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
return
normalized
,
nil
}
normalized
,
err
:=
urlvalidator
.
ValidateHTTPSURL
(
raw
,
urlvalidator
.
ValidationOptions
{
AllowedHosts
:
s
.
cfg
.
Security
.
URLAllowlist
.
UpstreamHosts
,
RequireAllowlist
:
true
,
AllowPrivate
:
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
,
})
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
return
normalized
,
nil
}
func
(
s
*
OpenAIGatewayService
)
replaceModelInResponseBody
(
body
[]
byte
,
fromModel
,
toModel
string
)
[]
byte
{
var
resp
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
...
...
backend/internal/service/openai_gateway_service_test.go
0 → 100644
View file @
195e227c
package
service
import
(
"bufio"
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
func
TestOpenAIStreamingTimeout
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
1
,
StreamKeepaliveInterval
:
0
,
MaxLineSize
:
defaultMaxLineSize
,
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{},
}
start
:=
time
.
Now
()
_
,
err
:=
svc
.
handleStreamingResponse
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
start
,
"model"
,
"model"
)
_
=
pw
.
Close
()
_
=
pr
.
Close
()
if
err
==
nil
||
!
strings
.
Contains
(
err
.
Error
(),
"stream data interval timeout"
)
{
t
.
Fatalf
(
"expected stream timeout error, got %v"
,
err
)
}
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"stream_timeout"
)
{
t
.
Fatalf
(
"expected stream_timeout SSE error, got %q"
,
rec
.
Body
.
String
())
}
}
// TestOpenAIStreamingTooLong verifies that a single SSE line longer than the
// configured MaxLineSize makes handleStreamingResponse fail with
// bufio.ErrTooLong and emit a "response_too_large" SSE error to the client.
func TestOpenAIStreamingTooLong(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0, // no interval timeout — only the line-size limit should trip
			StreamKeepaliveInterval:   0,
			MaxLineSize:               64 * 1024,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	pr, pw := io.Pipe()
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       pr,
		Header:     http.Header{},
	}
	go func() {
		defer func() { _ = pw.Close() }()
		// Write a single line exceeding MaxLineSize (128 KiB > 64 KiB) to trigger ErrTooLong.
		payload := "data: " + strings.Repeat("a", 128*1024) + "\n"
		_, _ = pw.Write([]byte(payload))
	}()
	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model")
	_ = pr.Close()
	if !errors.Is(err, bufio.ErrTooLong) {
		t.Fatalf("expected ErrTooLong, got %v", err)
	}
	if !strings.Contains(rec.Body.String(), "response_too_large") {
		t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
	}
}
// TestOpenAINonStreamingContentTypePassThrough verifies that a non-streaming
// response keeps the upstream Content-Type ("application/vnd.test+json")
// instead of rewriting it, with the response-header filter disabled.
func TestOpenAINonStreamingContentTypePassThrough(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Security: config.SecurityConfig{
			ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	// Minimal body with a usage block so response post-processing has something to parse.
	body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewReader(body)),
		Header:     http.Header{"Content-Type": []string{"application/vnd.test+json"}},
	}
	_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
	if err != nil {
		t.Fatalf("handleNonStreamingResponse error: %v", err)
	}
	if !strings.Contains(rec.Header().Get("Content-Type"), "application/vnd.test+json") {
		t.Fatalf("expected Content-Type passthrough, got %q", rec.Header().Get("Content-Type"))
	}
}
// TestOpenAINonStreamingContentTypeDefault verifies that when the upstream
// response carries no Content-Type header, the non-streaming handler falls
// back to "application/json" for the client response.
func TestOpenAINonStreamingContentTypeDefault(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Security: config.SecurityConfig{
			ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	// Minimal body with a usage block so response post-processing has something to parse.
	body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewReader(body)),
		Header:     http.Header{}, // deliberately no Content-Type
	}
	_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
	if err != nil {
		t.Fatalf("handleNonStreamingResponse error: %v", err)
	}
	if !strings.Contains(rec.Header().Get("Content-Type"), "application/json") {
		t.Fatalf("expected default Content-Type, got %q", rec.Header().Get("Content-Type"))
	}
}
// TestOpenAIStreamingHeadersOverride verifies the streaming handler's header
// policy: SSE-critical headers (Cache-Control, Content-Type) override whatever
// the upstream sent, while benign headers such as X-Request-Id pass through.
func TestOpenAIStreamingHeadersOverride(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Security: config.SecurityConfig{
			ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
		},
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	pr, pw := io.Pipe()
	// Upstream sends conflicting values for the SSE-managed headers plus a
	// passthrough header; the handler must override the former, keep the latter.
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       pr,
		Header: http.Header{
			"Cache-Control": []string{"upstream"},
			"X-Request-Id":  []string{"req-123"},
			"Content-Type":  []string{"application/custom"},
		},
	}
	go func() {
		defer func() { _ = pw.Close() }()
		_, _ = pw.Write([]byte("data: {}\n\n"))
	}()
	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
	_ = pr.Close()
	if err != nil {
		t.Fatalf("handleStreamingResponse error: %v", err)
	}
	if rec.Header().Get("Cache-Control") != "no-cache" {
		t.Fatalf("expected Cache-Control override, got %q", rec.Header().Get("Cache-Control"))
	}
	if rec.Header().Get("Content-Type") != "text/event-stream" {
		t.Fatalf("expected Content-Type override, got %q", rec.Header().Get("Content-Type"))
	}
	if rec.Header().Get("X-Request-Id") != "req-123" {
		t.Fatalf("expected X-Request-Id passthrough, got %q", rec.Header().Get("X-Request-Id"))
	}
}
// TestOpenAIInvalidBaseURLWhenAllowlistDisabled verifies that a syntactically
// invalid base_url in the account credentials is rejected by
// buildUpstreamRequest even when the URL allowlist feature is disabled
// (i.e. disabling the allowlist must not disable basic URL validation).
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Security: config.SecurityConfig{
			URLAllowlist: config.URLAllowlistConfig{Enabled: false},
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	account := &Account{
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
		// "://invalid-url" has no scheme and must fail URL parsing/validation.
		Credentials: map[string]any{"base_url": "://invalid-url"},
	}
	_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
	if err == nil {
		t.Fatalf("expected error for invalid base_url when allowlist disabled")
	}
}
func
TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Security
:
config
.
SecurityConfig
{
URLAllowlist
:
config
.
URLAllowlistConfig
{
Enabled
:
false
},
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
if
_
,
err
:=
svc
.
validateUpstreamBaseURL
(
"http://not-https.example.com"
);
err
==
nil
{
t
.
Fatalf
(
"expected http to be rejected when allow_insecure_http is false"
)
}
normalized
,
err
:=
svc
.
validateUpstreamBaseURL
(
"https://example.com"
)
if
err
!=
nil
{
t
.
Fatalf
(
"expected https to be allowed when allowlist disabled, got %v"
,
err
)
}
if
normalized
!=
"https://example.com"
{
t
.
Fatalf
(
"expected raw url passthrough, got %q"
,
normalized
)
}
}
func
TestOpenAIValidateUpstreamBaseURLDisabledAllowsHTTP
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Security
:
config
.
SecurityConfig
{
URLAllowlist
:
config
.
URLAllowlistConfig
{
Enabled
:
false
,
AllowInsecureHTTP
:
true
,
},
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
normalized
,
err
:=
svc
.
validateUpstreamBaseURL
(
"http://not-https.example.com"
)
if
err
!=
nil
{
t
.
Fatalf
(
"expected http allowed when allow_insecure_http is true, got %v"
,
err
)
}
if
normalized
!=
"http://not-https.example.com"
{
t
.
Fatalf
(
"expected raw url passthrough, got %q"
,
normalized
)
}
}
func
TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Security
:
config
.
SecurityConfig
{
URLAllowlist
:
config
.
URLAllowlistConfig
{
Enabled
:
true
,
UpstreamHosts
:
[]
string
{
"example.com"
},
},
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
if
_
,
err
:=
svc
.
validateUpstreamBaseURL
(
"https://example.com"
);
err
!=
nil
{
t
.
Fatalf
(
"expected allowlisted host to pass, got %v"
,
err
)
}
if
_
,
err
:=
svc
.
validateUpstreamBaseURL
(
"https://evil.com"
);
err
==
nil
{
t
.
Fatalf
(
"expected non-allowlisted host to fail"
)
}
}
backend/internal/service/pricing_service.go
View file @
195e227c
...
...
@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
var
(
...
...
@@ -213,16 +214,35 @@ func (s *PricingService) syncWithRemote() error {
// downloadPricingData 从远程下载价格数据
func
(
s
*
PricingService
)
downloadPricingData
()
error
{
log
.
Printf
(
"[Pricing] Downloading from %s"
,
s
.
cfg
.
Pricing
.
RemoteURL
)
remoteURL
,
err
:=
s
.
validatePricingURL
(
s
.
cfg
.
Pricing
.
RemoteURL
)
if
err
!=
nil
{
return
err
}
log
.
Printf
(
"[Pricing] Downloading from %s"
,
remoteURL
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
30
*
time
.
Second
)
defer
cancel
()
body
,
err
:=
s
.
remoteClient
.
FetchPricingJSON
(
ctx
,
s
.
cfg
.
Pricing
.
RemoteURL
)
var
expectedHash
string
if
strings
.
TrimSpace
(
s
.
cfg
.
Pricing
.
HashURL
)
!=
""
{
expectedHash
,
err
=
s
.
fetchRemoteHash
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"fetch remote hash: %w"
,
err
)
}
}
body
,
err
:=
s
.
remoteClient
.
FetchPricingJSON
(
ctx
,
remoteURL
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"download failed: %w"
,
err
)
}
if
expectedHash
!=
""
{
actualHash
:=
sha256
.
Sum256
(
body
)
if
!
strings
.
EqualFold
(
expectedHash
,
hex
.
EncodeToString
(
actualHash
[
:
]))
{
return
fmt
.
Errorf
(
"pricing hash mismatch"
)
}
}
// 解析JSON数据(使用灵活的解析方式)
data
,
err
:=
s
.
parsePricingData
(
body
)
if
err
!=
nil
{
...
...
@@ -378,10 +398,38 @@ func (s *PricingService) useFallbackPricing() error {
// fetchRemoteHash 从远程获取哈希值
func
(
s
*
PricingService
)
fetchRemoteHash
()
(
string
,
error
)
{
hashURL
,
err
:=
s
.
validatePricingURL
(
s
.
cfg
.
Pricing
.
HashURL
)
if
err
!=
nil
{
return
""
,
err
}
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
return
s
.
remoteClient
.
FetchHashText
(
ctx
,
s
.
cfg
.
Pricing
.
HashURL
)
hash
,
err
:=
s
.
remoteClient
.
FetchHashText
(
ctx
,
hashURL
)
if
err
!=
nil
{
return
""
,
err
}
return
strings
.
TrimSpace
(
hash
),
nil
}
// validatePricingURL validates and normalizes a pricing-related URL (remote
// data or hash URL) before it is fetched.
//
// Two modes:
//   - Allowlist disabled: only the URL format is checked (scheme per
//     AllowInsecureHTTP), no host restriction.
//   - Allowlist enabled: the URL must be HTTPS and its host must appear in
//     Security.URLAllowlist.PricingHosts (RequireAllowlist is always true here).
//
// Returns the normalized URL, or an error wrapped as "invalid pricing url".
//
// NOTE(review): the first branch guards s.cfg != nil, but the allowlist branch
// below dereferences s.cfg unconditionally — a nil cfg would panic there.
// Confirm whether PricingService is ever constructed without a config.
func (s *PricingService) validatePricingURL(raw string) (string, error) {
	if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
		normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
		if err != nil {
			return "", fmt.Errorf("invalid pricing url: %w", err)
		}
		return normalized, nil
	}
	normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
		AllowedHosts:     s.cfg.Security.URLAllowlist.PricingHosts,
		RequireAllowlist: true,
		AllowPrivate:     s.cfg.Security.URLAllowlist.AllowPrivateHosts,
	})
	if err != nil {
		return "", fmt.Errorf("invalid pricing url: %w", err)
	}
	return normalized, nil
}
// computeFileHash 计算文件哈希
...
...
backend/internal/service/setting_service.go
View file @
195e227c
...
...
@@ -130,6 +130,10 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeyFallbackModelGemini
]
=
settings
.
FallbackModelGemini
updates
[
SettingKeyFallbackModelAntigravity
]
=
settings
.
FallbackModelAntigravity
// Identity patch configuration (Claude -> Gemini)
updates
[
SettingKeyEnableIdentityPatch
]
=
strconv
.
FormatBool
(
settings
.
EnableIdentityPatch
)
updates
[
SettingKeyIdentityPatchPrompt
]
=
settings
.
IdentityPatchPrompt
return
s
.
settingRepo
.
SetMultiple
(
ctx
,
updates
)
}
...
...
@@ -213,6 +217,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyFallbackModelOpenAI
:
"gpt-4o"
,
SettingKeyFallbackModelGemini
:
"gemini-2.5-pro"
,
SettingKeyFallbackModelAntigravity
:
"gemini-2.5-pro"
,
// Identity patch defaults
SettingKeyEnableIdentityPatch
:
"true"
,
SettingKeyIdentityPatchPrompt
:
""
,
}
return
s
.
settingRepo
.
SetMultiple
(
ctx
,
defaults
)
...
...
@@ -228,8 +235,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
SMTPFrom
:
settings
[
SettingKeySMTPFrom
],
SMTPFromName
:
settings
[
SettingKeySMTPFromName
],
SMTPUseTLS
:
settings
[
SettingKeySMTPUseTLS
]
==
"true"
,
SMTPPasswordConfigured
:
settings
[
SettingKeySMTPPassword
]
!=
""
,
TurnstileEnabled
:
settings
[
SettingKeyTurnstileEnabled
]
==
"true"
,
TurnstileSiteKey
:
settings
[
SettingKeyTurnstileSiteKey
],
TurnstileSecretKeyConfigured
:
settings
[
SettingKeyTurnstileSecretKey
]
!=
""
,
SiteName
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteName
,
"Sub2API"
),
SiteLogo
:
settings
[
SettingKeySiteLogo
],
SiteSubtitle
:
s
.
getStringOrDefault
(
settings
,
SettingKeySiteSubtitle
,
"Subscription to API Conversion Platform"
),
...
...
@@ -269,6 +278,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result
.
FallbackModelGemini
=
s
.
getStringOrDefault
(
settings
,
SettingKeyFallbackModelGemini
,
"gemini-2.5-pro"
)
result
.
FallbackModelAntigravity
=
s
.
getStringOrDefault
(
settings
,
SettingKeyFallbackModelAntigravity
,
"gemini-2.5-pro"
)
// Identity patch settings (default: enabled, to preserve existing behavior)
if
v
,
ok
:=
settings
[
SettingKeyEnableIdentityPatch
];
ok
&&
v
!=
""
{
result
.
EnableIdentityPatch
=
v
==
"true"
}
else
{
result
.
EnableIdentityPatch
=
true
}
result
.
IdentityPatchPrompt
=
settings
[
SettingKeyIdentityPatchPrompt
]
return
result
}
...
...
@@ -298,6 +315,25 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return
value
}
// IsIdentityPatchEnabled 检查是否启用身份补丁(Claude -> Gemini systemInstruction 注入)
func
(
s
*
SettingService
)
IsIdentityPatchEnabled
(
ctx
context
.
Context
)
bool
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyEnableIdentityPatch
)
if
err
!=
nil
{
// 默认开启,保持兼容
return
true
}
return
value
==
"true"
}
// GetIdentityPatchPrompt 获取自定义身份补丁提示词(为空表示使用内置默认模板)
func
(
s
*
SettingService
)
GetIdentityPatchPrompt
(
ctx
context
.
Context
)
string
{
value
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyIdentityPatchPrompt
)
if
err
!=
nil
{
return
""
}
return
value
}
// GenerateAdminAPIKey 生成新的管理员 API Key
func
(
s
*
SettingService
)
GenerateAdminAPIKey
(
ctx
context
.
Context
)
(
string
,
error
)
{
// 生成 32 字节随机数 = 64 位十六进制字符
...
...
backend/internal/service/settings_view.go
View file @
195e227c
...
...
@@ -8,6 +8,7 @@ type SystemSettings struct {
SMTPPort
int
SMTPUsername
string
SMTPPassword
string
SMTPPasswordConfigured
bool
SMTPFrom
string
SMTPFromName
string
SMTPUseTLS
bool
...
...
@@ -15,6 +16,7 @@ type SystemSettings struct {
TurnstileEnabled
bool
TurnstileSiteKey
string
TurnstileSecretKey
string
TurnstileSecretKeyConfigured
bool
SiteName
string
SiteLogo
string
...
...
@@ -32,6 +34,10 @@ type SystemSettings struct {
FallbackModelOpenAI
string
`json:"fallback_model_openai"`
FallbackModelGemini
string
`json:"fallback_model_gemini"`
FallbackModelAntigravity
string
`json:"fallback_model_antigravity"`
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch
bool
`json:"enable_identity_patch"`
IdentityPatchPrompt
string
`json:"identity_patch_prompt"`
}
type
PublicSettings
struct
{
...
...
backend/internal/setup/setup.go
View file @
195e227c
...
...
@@ -21,10 +21,44 @@ import (
// Config paths
const
(
ConfigFile
=
"config.yaml"
Env
File
=
".
env
"
ConfigFile
Name
=
"config.yaml"
InstallLock
File
=
".
installed
"
)
// GetDataDir returns the data directory for storing config and lock files.
// Priority: DATA_DIR env > /app/data (if it exists and is writable) > current directory.
func GetDataDir() string {
	// An explicit DATA_DIR override always wins.
	if dir := os.Getenv("DATA_DIR"); dir != "" {
		return dir
	}
	// Docker convention: prefer /app/data when it is a writable directory.
	const dockerDataDir = "/app/data"
	info, err := os.Stat(dockerDataDir)
	if err != nil || !info.IsDir() {
		return "."
	}
	// Probe writability by creating (then removing) a scratch file.
	probe := dockerDataDir + "/.write_test"
	f, err := os.Create(probe)
	if err != nil {
		return "."
	}
	_ = f.Close()
	_ = os.Remove(probe)
	return dockerDataDir
}
// GetConfigFilePath returns the full path to config.yaml
func
GetConfigFilePath
()
string
{
return
GetDataDir
()
+
"/"
+
ConfigFileName
}
// GetInstallLockPath returns the full path to .installed lock file
func
GetInstallLockPath
()
string
{
return
GetDataDir
()
+
"/"
+
InstallLockFile
}
// SetupConfig holds the setup configuration
type
SetupConfig
struct
{
Database
DatabaseConfig
`json:"database" yaml:"database"`
...
...
@@ -71,13 +105,12 @@ type JWTConfig struct {
// Uses multiple checks to prevent attackers from forcing re-setup by deleting config
func
NeedsSetup
()
bool
{
// Check 1: Config file must not exist
if
_
,
err
:=
os
.
Stat
(
ConfigFile
);
!
os
.
IsNotExist
(
err
)
{
if
_
,
err
:=
os
.
Stat
(
Get
ConfigFile
Path
()
);
!
os
.
IsNotExist
(
err
)
{
return
false
// Config exists, no setup needed
}
// Check 2: Installation lock file (harder to bypass)
lockFile
:=
".installed"
if
_
,
err
:=
os
.
Stat
(
lockFile
);
!
os
.
IsNotExist
(
err
)
{
if
_
,
err
:=
os
.
Stat
(
GetInstallLockPath
());
!
os
.
IsNotExist
(
err
)
{
return
false
// Lock file exists, already installed
}
...
...
@@ -201,6 +234,7 @@ func Install(cfg *SetupConfig) error {
return
fmt
.
Errorf
(
"failed to generate jwt secret: %w"
,
err
)
}
cfg
.
JWT
.
Secret
=
secret
log
.
Println
(
"Warning: JWT secret auto-generated. Consider setting a fixed secret for production."
)
}
// Test connections
...
...
@@ -237,9 +271,8 @@ func Install(cfg *SetupConfig) error {
// createInstallLock creates a lock file to prevent re-installation attacks
func
createInstallLock
()
error
{
lockFile
:=
".installed"
content
:=
fmt
.
Sprintf
(
"installed_at=%s
\n
"
,
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
))
return
os
.
WriteFile
(
lockFile
,
[]
byte
(
content
),
0400
)
// Read-only for owner
return
os
.
WriteFile
(
GetInstallLockPath
()
,
[]
byte
(
content
),
0400
)
// Read-only for owner
}
func
initializeDatabase
(
cfg
*
SetupConfig
)
error
{
...
...
@@ -390,7 +423,7 @@ func writeConfigFile(cfg *SetupConfig) error {
return
err
}
return
os
.
WriteFile
(
ConfigFile
,
data
,
0600
)
return
os
.
WriteFile
(
Get
ConfigFile
Path
()
,
data
,
0600
)
}
func
generateSecret
(
length
int
)
(
string
,
error
)
{
...
...
@@ -433,6 +466,7 @@ func getEnvIntOrDefault(key string, defaultValue int) int {
// This is designed for Docker deployment where all config is passed via env vars
func
AutoSetupFromEnv
()
error
{
log
.
Println
(
"Auto setup enabled, configuring from environment variables..."
)
log
.
Printf
(
"Data directory: %s"
,
GetDataDir
())
// Get timezone from TZ or TIMEZONE env var (TZ is standard for Docker)
tz
:=
getEnvOrDefault
(
"TZ"
,
""
)
...
...
@@ -479,7 +513,7 @@ func AutoSetupFromEnv() error {
return
fmt
.
Errorf
(
"failed to generate jwt secret: %w"
,
err
)
}
cfg
.
JWT
.
Secret
=
secret
log
.
Println
(
"
Generated
JWT secret auto
matically
"
)
log
.
Println
(
"
Warning:
JWT secret auto
-generated. Consider setting a fixed secret for production.
"
)
}
// Generate admin password if not provided
...
...
@@ -489,8 +523,8 @@ func AutoSetupFromEnv() error {
return
fmt
.
Errorf
(
"failed to generate admin password: %w"
,
err
)
}
cfg
.
Admin
.
Password
=
password
log
.
Printf
(
"Generated admin password: %s"
,
cfg
.
Admin
.
Password
)
log
.
Println
(
"IMPORTANT: Save this password! It will not be shown again."
)
fmt
.
Printf
(
"Generated admin password
(one-time)
: %s
\n
"
,
cfg
.
Admin
.
Password
)
fmt
.
Println
(
"IMPORTANT: Save this password! It will not be shown again."
)
}
// Test database connection
...
...
backend/internal/util/logredact/redact.go
0 → 100644
View file @
195e227c
package
logredact
import
(
"encoding/json"
"strings"
)
// maxRedactDepth 限制递归深度以防止栈溢出
const
maxRedactDepth
=
32
var
defaultSensitiveKeys
=
map
[
string
]
struct
{}{
"authorization_code"
:
{},
"code"
:
{},
"code_verifier"
:
{},
"access_token"
:
{},
"refresh_token"
:
{},
"id_token"
:
{},
"client_secret"
:
{},
"password"
:
{},
}
func
RedactMap
(
input
map
[
string
]
any
,
extraKeys
...
string
)
map
[
string
]
any
{
if
input
==
nil
{
return
map
[
string
]
any
{}
}
keys
:=
buildKeySet
(
extraKeys
)
redacted
,
ok
:=
redactValueWithDepth
(
input
,
keys
,
0
)
.
(
map
[
string
]
any
)
if
!
ok
{
return
map
[
string
]
any
{}
}
return
redacted
}
func
RedactJSON
(
raw
[]
byte
,
extraKeys
...
string
)
string
{
if
len
(
raw
)
==
0
{
return
""
}
var
value
any
if
err
:=
json
.
Unmarshal
(
raw
,
&
value
);
err
!=
nil
{
return
"<non-json payload redacted>"
}
keys
:=
buildKeySet
(
extraKeys
)
redacted
:=
redactValueWithDepth
(
value
,
keys
,
0
)
encoded
,
err
:=
json
.
Marshal
(
redacted
)
if
err
!=
nil
{
return
"<redacted>"
}
return
string
(
encoded
)
}
func
buildKeySet
(
extraKeys
[]
string
)
map
[
string
]
struct
{}
{
keys
:=
make
(
map
[
string
]
struct
{},
len
(
defaultSensitiveKeys
)
+
len
(
extraKeys
))
for
k
:=
range
defaultSensitiveKeys
{
keys
[
k
]
=
struct
{}{}
}
for
_
,
key
:=
range
extraKeys
{
normalized
:=
normalizeKey
(
key
)
if
normalized
==
""
{
continue
}
keys
[
normalized
]
=
struct
{}{}
}
return
keys
}
func
redactValueWithDepth
(
value
any
,
keys
map
[
string
]
struct
{},
depth
int
)
any
{
if
depth
>
maxRedactDepth
{
return
"<depth limit exceeded>"
}
switch
v
:=
value
.
(
type
)
{
case
map
[
string
]
any
:
out
:=
make
(
map
[
string
]
any
,
len
(
v
))
for
k
,
val
:=
range
v
{
if
isSensitiveKey
(
k
,
keys
)
{
out
[
k
]
=
"***"
continue
}
out
[
k
]
=
redactValueWithDepth
(
val
,
keys
,
depth
+
1
)
}
return
out
case
[]
any
:
out
:=
make
([]
any
,
len
(
v
))
for
i
,
item
:=
range
v
{
out
[
i
]
=
redactValueWithDepth
(
item
,
keys
,
depth
+
1
)
}
return
out
default
:
return
value
}
}
func
isSensitiveKey
(
key
string
,
keys
map
[
string
]
struct
{})
bool
{
_
,
ok
:=
keys
[
normalizeKey
(
key
)]
return
ok
}
// normalizeKey canonicalizes a key for case-insensitive matching:
// surrounding whitespace is stripped, then the result is lowercased.
func normalizeKey(key string) string {
	trimmed := strings.TrimSpace(key)
	return strings.ToLower(trimmed)
}
backend/internal/util/responseheaders/responseheaders.go
0 → 100644
View file @
195e227c
package
responseheaders
import
(
"net/http"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// defaultAllowed 定义允许透传的响应头白名单
// 注意:以下头部由 Go HTTP 包自动处理,不应手动设置:
// - content-length: 由 ResponseWriter 根据实际写入数据自动设置
// - transfer-encoding: 由 HTTP 库根据需要自动添加/移除
// - connection: 由 HTTP 库管理连接复用
var
defaultAllowed
=
map
[
string
]
struct
{}{
"content-type"
:
{},
"content-encoding"
:
{},
"content-language"
:
{},
"cache-control"
:
{},
"etag"
:
{},
"last-modified"
:
{},
"expires"
:
{},
"vary"
:
{},
"date"
:
{},
"x-request-id"
:
{},
"x-ratelimit-limit-requests"
:
{},
"x-ratelimit-limit-tokens"
:
{},
"x-ratelimit-remaining-requests"
:
{},
"x-ratelimit-remaining-tokens"
:
{},
"x-ratelimit-reset-requests"
:
{},
"x-ratelimit-reset-tokens"
:
{},
"retry-after"
:
{},
"location"
:
{},
"www-authenticate"
:
{},
}
// hopByHopHeaders 是跳过的 hop-by-hop 头部,这些头部由 HTTP 库自动处理
var
hopByHopHeaders
=
map
[
string
]
struct
{}{
"content-length"
:
{},
"transfer-encoding"
:
{},
"connection"
:
{},
}
func
FilterHeaders
(
src
http
.
Header
,
cfg
config
.
ResponseHeaderConfig
)
http
.
Header
{
allowed
:=
make
(
map
[
string
]
struct
{},
len
(
defaultAllowed
)
+
len
(
cfg
.
AdditionalAllowed
))
for
key
:=
range
defaultAllowed
{
allowed
[
key
]
=
struct
{}{}
}
// 关闭时只使用默认白名单,additional/force_remove 不生效
if
cfg
.
Enabled
{
for
_
,
key
:=
range
cfg
.
AdditionalAllowed
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
key
))
if
normalized
==
""
{
continue
}
allowed
[
normalized
]
=
struct
{}{}
}
}
forceRemove
:=
map
[
string
]
struct
{}{}
if
cfg
.
Enabled
{
forceRemove
=
make
(
map
[
string
]
struct
{},
len
(
cfg
.
ForceRemove
))
for
_
,
key
:=
range
cfg
.
ForceRemove
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
key
))
if
normalized
==
""
{
continue
}
forceRemove
[
normalized
]
=
struct
{}{}
}
}
filtered
:=
make
(
http
.
Header
,
len
(
src
))
for
key
,
values
:=
range
src
{
lower
:=
strings
.
ToLower
(
key
)
if
_
,
blocked
:=
forceRemove
[
lower
];
blocked
{
continue
}
if
_
,
ok
:=
allowed
[
lower
];
!
ok
{
continue
}
// 跳过 hop-by-hop 头部,这些由 HTTP 库自动处理
if
_
,
isHopByHop
:=
hopByHopHeaders
[
lower
];
isHopByHop
{
continue
}
for
_
,
value
:=
range
values
{
filtered
.
Add
(
key
,
value
)
}
}
return
filtered
}
func
WriteFilteredHeaders
(
dst
http
.
Header
,
src
http
.
Header
,
cfg
config
.
ResponseHeaderConfig
)
{
filtered
:=
FilterHeaders
(
src
,
cfg
)
for
key
,
values
:=
range
filtered
{
for
_
,
value
:=
range
values
{
dst
.
Add
(
key
,
value
)
}
}
}
Prev
1
2
3
4
5
6
7
8
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment