Commit 195e227c authored by song
Browse files

merge: 合并 upstream/main 并保留本地图片计费功能

parents 6fa704d6 752882a0
...@@ -123,6 +123,7 @@ type UpdateGroupInput struct { ...@@ -123,6 +123,7 @@ type UpdateGroupInput struct {
type CreateAccountInput struct { type CreateAccountInput struct {
Name string Name string
Notes *string
Platform string Platform string
Type string Type string
Credentials map[string]any Credentials map[string]any
...@@ -138,6 +139,7 @@ type CreateAccountInput struct { ...@@ -138,6 +139,7 @@ type CreateAccountInput struct {
type UpdateAccountInput struct { type UpdateAccountInput struct {
Name string Name string
Notes *string
Type string // Account type: oauth, setup-token, apikey Type string // Account type: oauth, setup-token, apikey
Credentials map[string]any Credentials map[string]any
Extra map[string]any Extra map[string]any
...@@ -687,6 +689,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou ...@@ -687,6 +689,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
account := &Account{ account := &Account{
Name: input.Name, Name: input.Name,
Notes: normalizeAccountNotes(input.Notes),
Platform: input.Platform, Platform: input.Platform,
Type: input.Type, Type: input.Type,
Credentials: input.Credentials, Credentials: input.Credentials,
...@@ -723,6 +726,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -723,6 +726,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Type != "" { if input.Type != "" {
account.Type = input.Type account.Type = input.Type
} }
if input.Notes != nil {
account.Notes = normalizeAccountNotes(input.Notes)
}
if len(input.Credentials) > 0 { if len(input.Credentials) > 0 {
account.Credentials = input.Credentials account.Credentials = input.Credentials
} }
...@@ -730,7 +736,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -730,7 +736,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Extra = input.Extra account.Extra = input.Extra
} }
if input.ProxyID != nil { if input.ProxyID != nil {
account.ProxyID = input.ProxyID // 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
if *input.ProxyID == 0 {
account.ProxyID = nil
} else {
account.ProxyID = input.ProxyID
}
account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID
} }
// 只在指针非 nil 时更新 Concurrency(支持设置为 0) // 只在指针非 nil 时更新 Concurrency(支持设置为 0)
......
...@@ -9,8 +9,10 @@ import ( ...@@ -9,8 +9,10 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
mathrand "math/rand"
"net/http" "net/http"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
...@@ -255,6 +257,16 @@ func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedMode ...@@ -255,6 +257,16 @@ func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedMode
return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel) return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel)
} }
// getClaudeTransformOptions builds the options used when converting a Claude
// request to Gemini format. It starts from antigravity.DefaultTransformOptions
// and, when a setting service is configured, overrides the identity-patch
// toggle and prompt with the values that service reports for ctx.
func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Context) antigravity.TransformOptions {
	transformOpts := antigravity.DefaultTransformOptions()
	// Without a setting service the defaults are returned untouched.
	if settings := s.settingService; settings != nil {
		transformOpts.EnableIdentityPatch = settings.IsIdentityPatchEnabled(ctx)
		transformOpts.IdentityPatch = settings.GetIdentityPatchPrompt(ctx)
	}
	return transformOpts
}
// extractGeminiResponseText 从 Gemini 响应中提取文本 // extractGeminiResponseText 从 Gemini 响应中提取文本
func extractGeminiResponseText(respBody []byte) string { func extractGeminiResponseText(respBody []byte) string {
var resp map[string]any var resp map[string]any
...@@ -380,7 +392,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -380,7 +392,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
} }
// 转换 Claude 请求为 Gemini 格式 // 转换 Claude 请求为 Gemini 格式
geminiBody, err := antigravity.TransformClaudeToGemini(&claudeReq, projectID, mappedModel) geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
if err != nil { if err != nil {
return nil, fmt.Errorf("transform request: %w", err) return nil, fmt.Errorf("transform request: %w", err)
} }
...@@ -394,6 +406,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -394,6 +406,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 重试循环 // 重试循环
var resp *http.Response var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
// 检查 context 是否已取消(客户端断开连接)
select {
case <-ctx.Done():
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
return nil, ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody) upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -403,7 +423,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -403,7 +423,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if err != nil { if err != nil {
if attempt < antigravityMaxRetries { if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
sleepAntigravityBackoff(attempt) if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue continue
} }
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
...@@ -416,7 +439,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -416,7 +439,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if attempt < antigravityMaxRetries { if attempt < antigravityMaxRetries {
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
sleepAntigravityBackoff(attempt) if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue continue
} }
// 所有重试都失败,标记限流状态 // 所有重试都失败,标记限流状态
...@@ -443,35 +469,70 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -443,35 +469,70 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验, // Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。 // 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) { if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
retryClaudeReq := claudeReq // Conservative two-stage fallback:
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...) // 1) Disable top-level thinking + thinking->text
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq)
if stripErr == nil && stripped { retryStages := []struct {
log.Printf("Antigravity account %d: detected signature-related 400, retrying once without thinking blocks", account.ID) name string
strip func(*antigravity.ClaudeRequest) (bool, error)
retryGeminiBody, txErr := antigravity.TransformClaudeToGemini(&retryClaudeReq, projectID, mappedModel) }{
if txErr == nil { {name: "thinking-only", strip: stripThinkingFromClaudeRequest},
retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody) {name: "thinking+tools", strip: stripSignatureSensitiveBlocksFromClaudeRequest},
if buildErr == nil { }
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil { for _, stage := range retryStages {
// Retry success: continue normal success flow with the new response. retryClaudeReq := claudeReq
if retryResp.StatusCode < 400 { retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
_ = resp.Body.Close()
resp = retryResp stripped, stripErr := stage.strip(&retryClaudeReq)
respBody = nil if stripErr != nil || !stripped {
} else { continue
// Retry still errored: replace error context with retry response. }
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close() log.Printf("Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name)
respBody = retryBody
resp = retryResp retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
} if txErr != nil {
} else { continue
log.Printf("Antigravity account %d: signature retry request failed: %v", account.ID, retryErr) }
} retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody)
if buildErr != nil {
continue
}
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr != nil {
log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
continue
}
if retryResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = retryResp
respBody = nil
break
}
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) {
respBody = retryBody
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryBody)),
} }
break
}
// Still signature-related; capture context and allow next stage.
respBody = retryBody
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryBody)),
} }
} }
} }
...@@ -528,7 +589,17 @@ func isSignatureRelatedError(respBody []byte) bool { ...@@ -528,7 +589,17 @@ func isSignatureRelatedError(respBody []byte) bool {
} }
// Keep this intentionally broad: different upstreams may use "signature" or "thought_signature". // Keep this intentionally broad: different upstreams may use "signature" or "thought_signature".
return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") if strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
return true
}
// Also detect thinking block structural errors:
// "Expected `thinking` or `redacted_thinking`, but found `text`"
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
return true
}
return false
} }
func extractAntigravityErrorMessage(body []byte) string { func extractAntigravityErrorMessage(body []byte) string {
...@@ -555,7 +626,7 @@ func extractAntigravityErrorMessage(body []byte) string { ...@@ -555,7 +626,7 @@ func extractAntigravityErrorMessage(body []byte) string {
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request. // stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
// This preserves the thinking content while avoiding signature validation errors. // This preserves the thinking content while avoiding signature validation errors.
// Note: redacted_thinking blocks are removed because they cannot be converted to text. // Note: redacted_thinking blocks are removed because they cannot be converted to text.
// It also disables top-level `thinking` to prevent dummy-thought injection during retry. // It also disables top-level `thinking` to avoid upstream structural constraints for thinking mode.
func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) { func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
if req == nil { if req == nil {
return false, nil return false, nil
...@@ -585,6 +656,92 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error ...@@ -585,6 +656,92 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
continue continue
} }
filtered := make([]map[string]any, 0, len(blocks))
modifiedAny := false
for _, block := range blocks {
t, _ := block["type"].(string)
switch t {
case "thinking":
thinkingText, _ := block["thinking"].(string)
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
case "redacted_thinking":
modifiedAny = true
case "":
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
} else {
filtered = append(filtered, block)
}
default:
filtered = append(filtered, block)
}
}
if !modifiedAny {
continue
}
if len(filtered) == 0 {
filtered = append(filtered, map[string]any{
"type": "text",
"text": "(content removed)",
})
}
newRaw, err := json.Marshal(filtered)
if err != nil {
return changed, err
}
req.Messages[i].Content = newRaw
changed = true
}
return changed, nil
}
// stripSignatureSensitiveBlocksFromClaudeRequest is a stronger retry degradation that additionally converts
// tool blocks to plain text. Use this only after a thinking-only retry still fails with signature errors.
func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
if req == nil {
return false, nil
}
changed := false
if req.Thinking != nil {
req.Thinking = nil
changed = true
}
for i := range req.Messages {
raw := req.Messages[i].Content
if len(raw) == 0 {
continue
}
// If content is a string, nothing to strip.
var str string
if json.Unmarshal(raw, &str) == nil {
continue
}
// Otherwise treat as an array of blocks and convert signature-sensitive blocks to text.
var blocks []map[string]any
if err := json.Unmarshal(raw, &blocks); err != nil {
continue
}
filtered := make([]map[string]any, 0, len(blocks)) filtered := make([]map[string]any, 0, len(blocks))
modifiedAny := false modifiedAny := false
for _, block := range blocks { for _, block := range blocks {
...@@ -603,6 +760,49 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error ...@@ -603,6 +760,49 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
case "redacted_thinking": case "redacted_thinking":
// Remove redacted_thinking (cannot convert encrypted content) // Remove redacted_thinking (cannot convert encrypted content)
modifiedAny = true modifiedAny = true
case "tool_use":
// Convert tool_use to text to avoid upstream signature/thought_signature validation errors.
// This is a retry-only degradation path, so we prioritise request validity over tool semantics.
name, _ := block["name"].(string)
id, _ := block["id"].(string)
input := block["input"]
inputJSON, _ := json.Marshal(input)
text := "(tool_use)"
if name != "" {
text += " name=" + name
}
if id != "" {
text += " id=" + id
}
if len(inputJSON) > 0 && string(inputJSON) != "null" {
text += " input=" + string(inputJSON)
}
filtered = append(filtered, map[string]any{
"type": "text",
"text": text,
})
modifiedAny = true
case "tool_result":
// Convert tool_result to text so it stays consistent when tool_use is downgraded.
toolUseID, _ := block["tool_use_id"].(string)
isError, _ := block["is_error"].(bool)
content := block["content"]
contentJSON, _ := json.Marshal(content)
text := "(tool_result)"
if toolUseID != "" {
text += " tool_use_id=" + toolUseID
}
if isError {
text += " is_error=true"
}
if len(contentJSON) > 0 && string(contentJSON) != "null" {
text += "\n" + string(contentJSON)
}
filtered = append(filtered, map[string]any{
"type": "text",
"text": text,
})
modifiedAny = true
case "": case "":
// Handle untyped block with "thinking" field // Handle untyped block with "thinking" field
if thinkingText, hasThinking := block["thinking"].(string); hasThinking { if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
...@@ -625,6 +825,14 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error ...@@ -625,6 +825,14 @@ func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error
continue continue
} }
if len(filtered) == 0 {
// Keep request valid: upstream rejects empty content arrays.
filtered = append(filtered, map[string]any{
"type": "text",
"text": "(content removed)",
})
}
newRaw, err := json.Marshal(filtered) newRaw, err := json.Marshal(filtered)
if err != nil { if err != nil {
return changed, err return changed, err
...@@ -711,6 +919,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -711,6 +919,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 重试循环 // 重试循环
var resp *http.Response var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
// 检查 context 是否已取消(客户端断开连接)
select {
case <-ctx.Done():
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
return nil, ctx.Err()
default:
}
upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody) upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -720,7 +936,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -720,7 +936,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if err != nil { if err != nil {
if attempt < antigravityMaxRetries { if attempt < antigravityMaxRetries {
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
sleepAntigravityBackoff(attempt) if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue continue
} }
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
...@@ -733,7 +952,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -733,7 +952,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if attempt < antigravityMaxRetries { if attempt < antigravityMaxRetries {
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
sleepAntigravityBackoff(attempt) if !sleepAntigravityBackoffWithContext(ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", prefix)
return nil, ctx.Err()
}
continue continue
} }
// 所有重试都失败,标记限流状态 // 所有重试都失败,标记限流状态
...@@ -750,11 +972,18 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -750,11 +972,18 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
break break
} }
defer func() { _ = resp.Body.Close() }() defer func() {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
}()
// 处理错误响应 // 处理错误响应
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
// 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次 // 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) && if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
...@@ -763,15 +992,13 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -763,15 +992,13 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if fallbackModel != "" && fallbackModel != mappedModel { if fallbackModel != "" && fallbackModel != mappedModel {
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
// 关闭原始响应,释放连接(respBody 已读取到内存)
_ = resp.Body.Close()
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, body) fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, body)
if err == nil { if err == nil {
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped) fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
if err == nil { if err == nil {
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency) fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
if err == nil && fallbackResp.StatusCode < 400 { if err == nil && fallbackResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = fallbackResp resp = fallbackResp
} else if fallbackResp != nil { } else if fallbackResp != nil {
_ = fallbackResp.Body.Close() _ = fallbackResp.Body.Close()
...@@ -872,8 +1099,28 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) ...@@ -872,8 +1099,28 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
} }
} }
func sleepAntigravityBackoff(attempt int) { // sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑 // 返回 true 表示正常完成等待,false 表示 context 已取消
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
if delay > geminiRetryMaxDelay {
delay = geminiRetryMaxDelay
}
// +/- 20% jitter
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1))
sleepFor := delay + jitter
if sleepFor < 0 {
sleepFor = 0
}
select {
case <-ctx.Done():
return false
case <-time.After(sleepFor):
return true
}
} }
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) { func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) {
...@@ -928,57 +1175,145 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context ...@@ -928,57 +1175,145 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
return nil, errors.New("streaming not supported") return nil, errors.New("streaming not supported")
} }
reader := bufio.NewReader(resp.Body) // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
usage := &ClaudeUsage{} usage := &ClaudeUsage{}
var firstTokenMs *int var firstTokenMs *int
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
for { for {
line, err := reader.ReadString('\n') select {
if len(line) > 0 { case ev, ok := <-events:
if !ok {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return nil, ev.err
}
line := ev.line
trimmed := strings.TrimRight(line, "\r\n") trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") { if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" { if payload == "" || payload == "[DONE]" {
_, _ = io.WriteString(c.Writer, line) if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
flusher.Flush() sendErrorEvent("write_failed")
} else { return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr == nil && inner != nil {
payload = string(inner)
} }
flusher.Flush()
continue
}
// 解析 usage // 解包 v1internal 响应
var parsed map[string]any inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if json.Unmarshal(inner, &parsed) == nil { if parseErr == nil && inner != nil {
if u := extractGeminiUsage(parsed); u != nil { payload = string(inner)
usage = u }
}
}
if firstTokenMs == nil { // 解析 usage
ms := int(time.Since(startTime).Milliseconds()) var parsed map[string]any
firstTokenMs = &ms if json.Unmarshal(inner, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
} }
}
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload) if firstTokenMs == nil {
flusher.Flush() ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
} }
} else {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush() flusher.Flush()
continue
} }
}
if errors.Is(err, io.EOF) { if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
break sendErrorEvent("write_failed")
} return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
if err != nil { }
return nil, err flusher.Flush()
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
} }
} }
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
} }
func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) { func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) {
...@@ -1117,7 +1452,13 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context ...@@ -1117,7 +1452,13 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
processor := antigravity.NewStreamingProcessor(originalModel) processor := antigravity.NewStreamingProcessor(originalModel)
var firstTokenMs *int var firstTokenMs *int
reader := bufio.NewReader(resp.Body) // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage // 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage { convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
...@@ -1132,13 +1473,85 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context ...@@ -1132,13 +1473,85 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
} }
} }
for { type scanEvent struct {
line, err := reader.ReadString('\n') line string
if err != nil && !errors.Is(err, io.EOF) { err error
return nil, fmt.Errorf("stream read error: %w", err) }
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
} }
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
for {
select {
case ev, ok := <-events:
if !ok {
// 发送结束事件
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
flusher.Flush()
}
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return nil, fmt.Errorf("stream read error: %w", ev.err)
}
if len(line) > 0 { line := ev.line
// 处理 SSE 行,转换为 Claude 格式 // 处理 SSE 行,转换为 Claude 格式
claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n")) claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n"))
...@@ -1153,25 +1566,23 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context ...@@ -1153,25 +1566,23 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if len(finalEvents) > 0 { if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents) _, _ = c.Writer.Write(finalEvents)
} }
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr
} }
flusher.Flush() flusher.Flush()
} }
}
if errors.Is(err, io.EOF) { case <-intervalCh:
break lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
} }
} }
// 发送结束事件
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
flusher.Flush()
}
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
} }
// extractImageSize 从 Gemini 请求中提取 image_size 参数 // extractImageSize 从 Gemini 请求中提取 image_size 参数
......
package service
import (
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
req := &antigravity.ClaudeRequest{
Model: "claude-sonnet-4-5",
Thinking: &antigravity.ThinkingConfig{
Type: "enabled",
BudgetTokens: 1024,
},
Messages: []antigravity.ClaudeMessage{
{
Role: "assistant",
Content: json.RawMessage(`[
{"type":"thinking","thinking":"secret plan","signature":""},
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}
]`),
},
{
Role: "user",
Content: json.RawMessage(`[
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false},
{"type":"redacted_thinking","data":"..."}
]`),
},
},
}
changed, err := stripSignatureSensitiveBlocksFromClaudeRequest(req)
require.NoError(t, err)
require.True(t, changed)
require.Nil(t, req.Thinking)
require.Len(t, req.Messages, 2)
var blocks0 []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks0))
require.Len(t, blocks0, 2)
require.Equal(t, "text", blocks0[0]["type"])
require.Equal(t, "secret plan", blocks0[0]["text"])
require.Equal(t, "text", blocks0[1]["type"])
var blocks1 []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[1].Content, &blocks1))
require.Len(t, blocks1, 1)
require.Equal(t, "text", blocks1[0]["type"])
require.NotEmpty(t, blocks1[0]["text"])
}
func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
req := &antigravity.ClaudeRequest{
Model: "claude-sonnet-4-5",
Thinking: &antigravity.ThinkingConfig{
Type: "enabled",
BudgetTokens: 1024,
},
Messages: []antigravity.ClaudeMessage{
{
Role: "assistant",
Content: json.RawMessage(`[{"type":"thinking","thinking":"secret plan"},{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}]`),
},
},
}
changed, err := stripThinkingFromClaudeRequest(req)
require.NoError(t, err)
require.True(t, changed)
require.Nil(t, req.Thinking)
var blocks []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks))
require.Len(t, blocks, 2)
require.Equal(t, "text", blocks[0]["type"])
require.Equal(t, "secret plan", blocks[0]["text"])
require.Equal(t, "tool_use", blocks[1]["type"])
}
...@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S ...@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// VerifyTurnstile 验证Turnstile token // VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error { func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
if required {
if s.settingService == nil {
log.Println("[Auth] Turnstile required but settings service is not configured")
return ErrTurnstileNotConfigured
}
enabled := s.settingService.IsTurnstileEnabled(ctx)
secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
if !enabled || !secretConfigured {
log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
return ErrTurnstileNotConfigured
}
}
if s.turnstileService == nil { if s.turnstileService == nil {
if required {
log.Println("[Auth] Turnstile required but service not configured")
return ErrTurnstileNotConfigured
}
return nil // 服务未配置则跳过验证 return nil // 服务未配置则跳过验证
} }
if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
log.Println("[Auth] Turnstile enabled but secret key not configured")
}
return s.turnstileService.VerifyToken(ctx, token, remoteIP) return s.turnstileService.VerifyToken(ctx, token, remoteIP)
} }
......
...@@ -16,7 +16,8 @@ import ( ...@@ -16,7 +16,8 @@ import (
// 注:ErrInsufficientBalance在redeem_service.go中定义 // 注:ErrInsufficientBalance在redeem_service.go中定义
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义 // 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var ( var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired") ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
) )
// subscriptionCacheData 订阅缓存数据结构(内部使用) // subscriptionCacheData 订阅缓存数据结构(内部使用)
...@@ -72,10 +73,11 @@ type cacheWriteTask struct { ...@@ -72,10 +73,11 @@ type cacheWriteTask struct {
// BillingCacheService 计费缓存服务 // BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct { type BillingCacheService struct {
cache BillingCache cache BillingCache
userRepo UserRepository userRepo UserRepository
subRepo UserSubscriptionRepository subRepo UserSubscriptionRepository
cfg *config.Config cfg *config.Config
circuitBreaker *billingCircuitBreaker
cacheWriteChan chan cacheWriteTask cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup cacheWriteWg sync.WaitGroup
...@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo ...@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
subRepo: subRepo, subRepo: subRepo,
cfg: cfg, cfg: cfg,
} }
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
svc.startCacheWriteWorkers() svc.startCacheWriteWorkers()
return svc return svc
} }
...@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user ...@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
if s.cfg.RunMode == config.RunModeSimple { if s.cfg.RunMode == config.RunModeSimple {
return nil return nil
} }
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
return ErrBillingServiceUnavailable
}
// 判断计费模式 // 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
...@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user ...@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error { func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
balance, err := s.GetUserBalance(ctx, userID) balance, err := s.GetUserBalance(ctx, userID)
if err != nil { if err != nil {
// 缓存/数据库错误,允许通过(降级处理) if s.circuitBreaker != nil {
log.Printf("Warning: get user balance failed, allowing request: %v", err) s.circuitBreaker.OnFailure(err)
return nil }
log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
} }
if balance <= 0 { if balance <= 0 {
...@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, ...@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
// 获取订阅缓存数据 // 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID) subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil { if err != nil {
// 缓存/数据库错误,降级使用传入的subscription进行检查 if s.circuitBreaker != nil {
log.Printf("Warning: get subscription cache failed, using fallback: %v", err) s.circuitBreaker.OnFailure(err)
return s.checkSubscriptionLimitsFallback(subscription, group) }
log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
} }
// 检查订阅状态 // 检查订阅状态
...@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, ...@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
return nil return nil
} }
// checkSubscriptionLimitsFallback 降级检查订阅限额 type billingCircuitBreakerState int
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
if subscription == nil { const (
return ErrSubscriptionInvalid billingCircuitClosed billingCircuitBreakerState = iota
billingCircuitOpen
billingCircuitHalfOpen
)
type billingCircuitBreaker struct {
mu sync.Mutex
state billingCircuitBreakerState
failures int
openedAt time.Time
failureThreshold int
resetTimeout time.Duration
halfOpenRequests int
halfOpenRemaining int
}
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
if !cfg.Enabled {
return nil
}
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
if resetTimeout <= 0 {
resetTimeout = 30 * time.Second
}
halfOpen := cfg.HalfOpenRequests
if halfOpen <= 0 {
halfOpen = 1
}
threshold := cfg.FailureThreshold
if threshold <= 0 {
threshold = 5
}
return &billingCircuitBreaker{
state: billingCircuitClosed,
failureThreshold: threshold,
resetTimeout: resetTimeout,
halfOpenRequests: halfOpen,
} }
}
if !subscription.IsActive() { func (b *billingCircuitBreaker) Allow() bool {
return ErrSubscriptionInvalid b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitClosed:
return true
case billingCircuitOpen:
if time.Since(b.openedAt) < b.resetTimeout {
return false
}
b.state = billingCircuitHalfOpen
b.halfOpenRemaining = b.halfOpenRequests
log.Printf("ALERT: billing circuit breaker entering half-open state")
fallthrough
case billingCircuitHalfOpen:
if b.halfOpenRemaining <= 0 {
return false
}
b.halfOpenRemaining--
return true
default:
return false
} }
}
if !subscription.CheckDailyLimit(group, 0) { func (b *billingCircuitBreaker) OnFailure(err error) {
return ErrDailyLimitExceeded if b == nil {
return
} }
b.mu.Lock()
defer b.mu.Unlock()
if !subscription.CheckWeeklyLimit(group, 0) { switch b.state {
return ErrWeeklyLimitExceeded case billingCircuitOpen:
return
case billingCircuitHalfOpen:
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
return
default:
b.failures++
if b.failures >= b.failureThreshold {
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
}
} }
}
if !subscription.CheckMonthlyLimit(group, 0) { func (b *billingCircuitBreaker) OnSuccess() {
return ErrMonthlyLimitExceeded if b == nil {
return
} }
b.mu.Lock()
defer b.mu.Unlock()
return nil previousState := b.state
previousFailures := b.failures
b.state = billingCircuitClosed
b.failures = 0
b.halfOpenRemaining = 0
// 只有状态真正发生变化时才记录日志
if previousState != billingCircuitClosed {
log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
} else if previousFailures > 0 {
log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
}
}
func circuitStateString(state billingCircuitBreakerState) string {
switch state {
case billingCircuitClosed:
return "closed"
case billingCircuitOpen:
return "open"
case billingCircuitHalfOpen:
return "half-open"
default:
return "unknown"
}
} }
...@@ -8,12 +8,13 @@ import ( ...@@ -8,12 +8,13 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
) )
type CRSSyncService struct { type CRSSyncService struct {
...@@ -22,6 +23,7 @@ type CRSSyncService struct { ...@@ -22,6 +23,7 @@ type CRSSyncService struct {
oauthService *OAuthService oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService openaiOAuthService *OpenAIOAuthService
geminiOAuthService *GeminiOAuthService geminiOAuthService *GeminiOAuthService
cfg *config.Config
} }
func NewCRSSyncService( func NewCRSSyncService(
...@@ -30,6 +32,7 @@ func NewCRSSyncService( ...@@ -30,6 +32,7 @@ func NewCRSSyncService(
oauthService *OAuthService, oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService, openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService, geminiOAuthService *GeminiOAuthService,
cfg *config.Config,
) *CRSSyncService { ) *CRSSyncService {
return &CRSSyncService{ return &CRSSyncService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -37,6 +40,7 @@ func NewCRSSyncService( ...@@ -37,6 +40,7 @@ func NewCRSSyncService(
oauthService: oauthService, oauthService: oauthService,
openaiOAuthService: openaiOAuthService, openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService, geminiOAuthService: geminiOAuthService,
cfg: cfg,
} }
} }
...@@ -187,16 +191,31 @@ type crsGeminiAPIKeyAccount struct { ...@@ -187,16 +191,31 @@ type crsGeminiAPIKeyAccount struct {
} }
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) { func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
baseURL, err := normalizeBaseURL(input.BaseURL) if s.cfg == nil {
if err != nil { return nil, errors.New("config is not available")
return nil, err }
baseURL := strings.TrimSpace(input.BaseURL)
if s.cfg.Security.URLAllowlist.Enabled {
normalized, err := normalizeBaseURL(baseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil {
return nil, err
}
baseURL = normalized
} else {
normalized, err := urlvalidator.ValidateURLFormat(baseURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return nil, fmt.Errorf("invalid base_url: %w", err)
}
baseURL = normalized
} }
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" { if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
return nil, errors.New("username and password are required") return nil, errors.New("username and password are required")
} }
client, err := httpclient.GetClient(httpclient.Options{ client, err := httpclient.GetClient(httpclient.Options{
Timeout: 20 * time.Second, Timeout: 20 * time.Second,
ValidateResolvedIP: s.cfg.Security.URLAllowlist.Enabled,
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
}) })
if err != nil { if err != nil {
client = &http.Client{Timeout: 20 * time.Second} client = &http.Client{Timeout: 20 * time.Second}
...@@ -1055,17 +1074,18 @@ func mapCRSStatus(isActive bool, status string) string { ...@@ -1055,17 +1074,18 @@ func mapCRSStatus(isActive bool, status string) string {
return "active" return "active"
} }
func normalizeBaseURL(raw string) (string, error) { func normalizeBaseURL(raw string, allowlist []string, allowPrivate bool) (string, error) {
trimmed := strings.TrimSpace(raw) // 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
if trimmed == "" { requireAllowlist := len(allowlist) > 0
return "", errors.New("base_url is required") normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
} AllowedHosts: allowlist,
u, err := url.Parse(trimmed) RequireAllowlist: requireAllowlist,
if err != nil || u.Scheme == "" || u.Host == "" { AllowPrivate: allowPrivate,
return "", fmt.Errorf("invalid base_url: %s", trimmed) })
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
} }
u.Path = strings.TrimRight(u.Path, "/") return normalized, nil
return strings.TrimRight(u.String(), "/"), nil
} }
// cleanBaseURL removes trailing suffix from base_url in credentials // cleanBaseURL removes trailing suffix from base_url in credentials
......
...@@ -101,6 +101,10 @@ const ( ...@@ -101,6 +101,10 @@ const (
SettingKeyFallbackModelOpenAI = "fallback_model_openai" SettingKeyFallbackModelOpenAI = "fallback_model_openai"
SettingKeyFallbackModelGemini = "fallback_model_gemini" SettingKeyFallbackModelGemini = "fallback_model_gemini"
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity" SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
// Request identity patch (Claude -> Gemini systemInstruction injection)
SettingKeyEnableIdentityPatch = "enable_identity_patch"
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
) )
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
......
...@@ -84,25 +84,180 @@ func FilterThinkingBlocks(body []byte) []byte { ...@@ -84,25 +84,180 @@ func FilterThinkingBlocks(body []byte) []byte {
return filterThinkingBlocksInternal(body, false) return filterThinkingBlocksInternal(body, false)
} }
// FilterThinkingBlocksForRetry strips thinking-related constructs from a request body
// so the request can be retried.
//
// Why:
//   - Upstreams may reject historical `thinking`/`redacted_thinking` blocks due to
//     invalid/missing signatures.
//   - Anthropic extended thinking has a structural constraint: when top-level `thinking`
//     is enabled and the final message is an assistant prefill, the assistant content
//     must start with a thinking block. Removing thinking blocks while keeping top-level
//     `thinking` enabled can trigger:
//     "Expected `thinking` or `redacted_thinking`, but found `text`"
//
// Strategy (preserve content as text):
//   - Remove the top-level `thinking` field.
//   - Convert `thinking` blocks to `text` blocks, keeping the thinking content; blocks
//     whose thinking text is empty are dropped.
//   - Drop `redacted_thinking` blocks (they cannot be converted to text).
//   - Convert blocks that have no `type` but carry a `thinking` field to text; a
//     non-string value is JSON-encoded, and a null value is dropped.
//   - Replace content arrays that are (or become) empty with a one-block text placeholder.
//
// The input slice is returned as-is when no thinking-related or empty-content marker is
// present, when the body is not a JSON object, when `messages` is not an array, when
// nothing had to change, or when re-encoding fails. Otherwise a freshly marshalled body
// is returned.
func FilterThinkingBlocksForRetry(body []byte) []byte {
	// Fast path: cheap byte scans for thinking-related keys before paying for a JSON decode.
	hasThinkingContent := bytes.Contains(body, []byte(`"type":"thinking"`)) ||
		bytes.Contains(body, []byte(`"type": "thinking"`)) ||
		bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) ||
		bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) ||
		bytes.Contains(body, []byte(`"thinking":`)) ||
		bytes.Contains(body, []byte(`"thinking" :`))

	// Also look for empty content arrays that need a placeholder.
	// This is only a heuristic; the authoritative check happens per message below.
	hasEmptyContent := bytes.Contains(body, []byte(`"content":[]`)) ||
		bytes.Contains(body, []byte(`"content": []`)) ||
		bytes.Contains(body, []byte(`"content" : []`)) ||
		bytes.Contains(body, []byte(`"content" :[]`))

	if !hasThinkingContent && !hasEmptyContent {
		return body
	}

	var req map[string]any
	if err := json.Unmarshal(body, &req); err != nil {
		return body
	}

	messages, ok := req["messages"].([]any)
	if !ok {
		return body
	}

	modified := false

	// Disable top-level thinking mode for retry to avoid structural/signature constraints upstream.
	if _, exists := req["thinking"]; exists {
		delete(req, "thinking")
		modified = true
	}

	newMessages := make([]any, 0, len(messages))
	for _, msg := range messages {
		msgMap, ok := msg.(map[string]any)
		if !ok {
			newMessages = append(newMessages, msg)
			continue
		}

		role, _ := msgMap["role"].(string)
		content, ok := msgMap["content"].([]any)
		if !ok {
			// String content or any other shape: keep as is.
			newMessages = append(newMessages, msg)
			continue
		}

		newContent := make([]any, 0, len(content))
		modifiedThisMsg := false

		for _, block := range content {
			blockMap, ok := block.(map[string]any)
			if !ok {
				newContent = append(newContent, block)
				continue
			}

			blockType, _ := blockMap["type"].(string)

			// Convert thinking blocks to text (preserve content) and drop redacted_thinking.
			switch blockType {
			case "thinking":
				modifiedThisMsg = true
				thinkingText, _ := blockMap["thinking"].(string)
				if thinkingText == "" {
					continue
				}
				newContent = append(newContent, map[string]any{
					"type": "text",
					"text": thinkingText,
				})
				continue
			case "redacted_thinking":
				modifiedThisMsg = true
				continue
			}

			// Blocks without a type discriminator but with a "thinking" field.
			if blockType == "" {
				if rawThinking, hasThinking := blockMap["thinking"]; hasThinking {
					modifiedThisMsg = true
					switch v := rawThinking.(type) {
					case string:
						if v != "" {
							newContent = append(newContent, map[string]any{"type": "text", "text": v})
						}
					default:
						// A JSON null carries no thinking content; emitting the literal
						// text "null" would only add noise to the conversation.
						if b, err := json.Marshal(v); err == nil && len(b) > 0 && string(b) != "null" {
							newContent = append(newContent, map[string]any{"type": "text", "text": string(b)})
						}
					}
					continue
				}
			}

			newContent = append(newContent, block)
		}

		// Empty content, whether produced by filtering or empty to begin with,
		// gets a placeholder so the message stays valid.
		if len(newContent) == 0 {
			modified = true
			placeholder := "(content removed)"
			if role == "assistant" {
				placeholder = "(assistant content removed)"
			}
			newContent = append(newContent, map[string]any{
				"type": "text",
				"text": placeholder,
			})
			msgMap["content"] = newContent
		} else if modifiedThisMsg {
			modified = true
			msgMap["content"] = newContent
		}
		newMessages = append(newMessages, msgMap)
	}

	if !modified {
		// Avoid rewriting JSON when no changes are needed.
		return body
	}
	req["messages"] = newMessages

	newBody, err := json.Marshal(req)
	if err != nil {
		return body
	}
	return newBody
}
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate
// signature/thought_signature validation issues involving tool blocks.
//
// This performs everything in FilterThinkingBlocksForRetry, plus:
// - Convert `tool_use` blocks to text (name/id/input) so we stop sending structured tool calls.
// - Convert `tool_result` blocks to text so we keep tool results visible without tool semantics.
//
// Use this only when needed: converting tool blocks to text changes model behaviour and can increase the
// risk of prompt injection (tool output becomes plain conversation text).
func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte {
// Fast path: only run when we see likely relevant constructs.
if !bytes.Contains(body, []byte(`"type":"thinking"`)) && if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "thinking"`)) && !bytes.Contains(body, []byte(`"type": "thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) && !bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) { !bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"tool_use"`)) &&
!bytes.Contains(body, []byte(`"type": "tool_use"`)) &&
!bytes.Contains(body, []byte(`"type":"tool_result"`)) &&
!bytes.Contains(body, []byte(`"type": "tool_result"`)) &&
!bytes.Contains(body, []byte(`"thinking":`)) &&
!bytes.Contains(body, []byte(`"thinking" :`)) {
return body return body
} }
...@@ -111,15 +266,19 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { ...@@ -111,15 +266,19 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
return body return body
} }
// DO NOT modify thinking.type - preserve user's intent to use thinking mode modified := false
// The issue is with historical message signatures, not the thinking mode itself
// Disable top-level thinking for retry to avoid structural/signature constraints upstream.
if _, exists := req["thinking"]; exists {
delete(req, "thinking")
modified = true
}
messages, ok := req["messages"].([]any) messages, ok := req["messages"].([]any)
if !ok { if !ok {
return body return body
} }
modified := false
newMessages := make([]any, 0, len(messages)) newMessages := make([]any, 0, len(messages))
for _, msg := range messages { for _, msg := range messages {
...@@ -132,7 +291,6 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { ...@@ -132,7 +291,6 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
role, _ := msgMap["role"].(string) role, _ := msgMap["role"].(string)
content, ok := msgMap["content"].([]any) content, ok := msgMap["content"].([]any)
if !ok { if !ok {
// String content or other format - keep as is
newMessages = append(newMessages, msg) newMessages = append(newMessages, msg)
continue continue
} }
...@@ -148,43 +306,96 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { ...@@ -148,43 +306,96 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
} }
blockType, _ := blockMap["type"].(string) blockType, _ := blockMap["type"].(string)
switch blockType {
// Remove thinking/redacted_thinking blocks from historical messages case "thinking":
// These have signatures that may be invalid across different accounts modifiedThisMsg = true
if blockType == "thinking" || blockType == "redacted_thinking" { thinkingText, _ := blockMap["thinking"].(string)
if thinkingText == "" {
continue
}
newContent = append(newContent, map[string]any{"type": "text", "text": thinkingText})
continue
case "redacted_thinking":
modifiedThisMsg = true
continue
case "tool_use":
modifiedThisMsg = true
name, _ := blockMap["name"].(string)
id, _ := blockMap["id"].(string)
input := blockMap["input"]
inputJSON, _ := json.Marshal(input)
text := "(tool_use)"
if name != "" {
text += " name=" + name
}
if id != "" {
text += " id=" + id
}
if len(inputJSON) > 0 && string(inputJSON) != "null" {
text += " input=" + string(inputJSON)
}
newContent = append(newContent, map[string]any{"type": "text", "text": text})
continue
case "tool_result":
modifiedThisMsg = true modifiedThisMsg = true
toolUseID, _ := blockMap["tool_use_id"].(string)
isError, _ := blockMap["is_error"].(bool)
content := blockMap["content"]
contentJSON, _ := json.Marshal(content)
text := "(tool_result)"
if toolUseID != "" {
text += " tool_use_id=" + toolUseID
}
if isError {
text += " is_error=true"
}
if len(contentJSON) > 0 && string(contentJSON) != "null" {
text += "\n" + string(contentJSON)
}
newContent = append(newContent, map[string]any{"type": "text", "text": text})
continue continue
} }
if blockType == "" {
if rawThinking, hasThinking := blockMap["thinking"]; hasThinking {
modifiedThisMsg = true
switch v := rawThinking.(type) {
case string:
if v != "" {
newContent = append(newContent, map[string]any{"type": "text", "text": v})
}
default:
if b, err := json.Marshal(v); err == nil && len(b) > 0 {
newContent = append(newContent, map[string]any{"type": "text", "text": string(b)})
}
}
continue
}
}
newContent = append(newContent, block) newContent = append(newContent, block)
} }
if modifiedThisMsg { if modifiedThisMsg {
modified = true modified = true
// Handle empty content after filtering
if len(newContent) == 0 { if len(newContent) == 0 {
// For assistant messages, skip entirely (remove from conversation) placeholder := "(content removed)"
// For user messages, add placeholder to avoid empty content error if role == "assistant" {
if role == "user" { placeholder = "(assistant content removed)"
newContent = append(newContent, map[string]any{
"type": "text",
"text": "(content removed)",
})
msgMap["content"] = newContent
newMessages = append(newMessages, msgMap)
} }
// Skip assistant messages with empty content (don't append) newContent = append(newContent, map[string]any{"type": "text", "text": placeholder})
continue
} }
msgMap["content"] = newContent msgMap["content"] = newContent
} }
newMessages = append(newMessages, msgMap) newMessages = append(newMessages, msgMap)
} }
if modified { if !modified {
req["messages"] = newMessages return body
} }
req["messages"] = newMessages
newBody, err := json.Marshal(req) newBody, err := json.Marshal(req)
if err != nil { if err != nil {
return body return body
......
...@@ -151,3 +151,148 @@ func TestFilterThinkingBlocks(t *testing.T) { ...@@ -151,3 +151,148 @@ func TestFilterThinkingBlocks(t *testing.T) {
}) })
} }
} }
func TestFilterThinkingBlocksForRetry_DisablesThinkingAndPreservesAsText(t *testing.T) {
input := []byte(`{
"model":"claude-3-5-sonnet-20241022",
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"user","content":[{"type":"text","text":"Hi"}]},
{"role":"assistant","content":[
{"type":"thinking","thinking":"Let me think...","signature":"bad_sig"},
{"type":"text","text":"Answer"}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
msgs, ok := req["messages"].([]any)
require.True(t, ok)
require.Len(t, msgs, 2)
assistant, ok := msgs[1].(map[string]any)
require.True(t, ok)
content, ok := assistant["content"].([]any)
require.True(t, ok)
require.Len(t, content, 2)
first, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", first["type"])
require.Equal(t, "Let me think...", first["text"])
}
func TestFilterThinkingBlocksForRetry_DisablesThinkingEvenWithoutThinkingBlocks(t *testing.T) {
input := []byte(`{
"model":"claude-3-5-sonnet-20241022",
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"user","content":[{"type":"text","text":"Hi"}]},
{"role":"assistant","content":[{"type":"text","text":"Prefill"}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
}
func TestFilterThinkingBlocksForRetry_RemovesRedactedThinkingAndKeepsValidContent(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"assistant","content":[
{"type":"redacted_thinking","data":"..."},
{"type":"text","text":"Visible"}
]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
msgs, ok := req["messages"].([]any)
require.True(t, ok)
msg0, ok := msgs[0].(map[string]any)
require.True(t, ok)
content, ok := msg0["content"].([]any)
require.True(t, ok)
require.Len(t, content, 1)
content0, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", content0["type"])
require.Equal(t, "Visible", content0["text"])
}
func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled"},
"messages":[
{"role":"assistant","content":[{"type":"redacted_thinking","data":"..."}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs, ok := req["messages"].([]any)
require.True(t, ok)
msg0, ok := msgs[0].(map[string]any)
require.True(t, ok)
content, ok := msg0["content"].([]any)
require.True(t, ok)
require.Len(t, content, 1)
content0, ok := content[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", content0["type"])
require.NotEmpty(t, content0["text"])
}
func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},
"messages":[
{"role":"assistant","content":[
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}},
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false}
]}
]
}`)
out := FilterSignatureSensitiveBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
_, hasThinking := req["thinking"]
require.False(t, hasThinking)
msgs, ok := req["messages"].([]any)
require.True(t, ok)
msg0, ok := msgs[0].(map[string]any)
require.True(t, ok)
content, ok := msg0["content"].([]any)
require.True(t, ok)
require.Len(t, content, 2)
content0, ok := content[0].(map[string]any)
require.True(t, ok)
content1, ok := content[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", content0["type"])
require.Equal(t, "text", content1["type"])
require.Contains(t, content0["text"], "tool_use")
require.Contains(t, content1["text"], "tool_result")
}
...@@ -15,11 +15,14 @@ import ( ...@@ -15,11 +15,14 @@ import (
"regexp" "regexp"
"sort" "sort"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
...@@ -30,6 +33,7 @@ const ( ...@@ -30,6 +33,7 @@ const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 10 * 1024 * 1024
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
) )
...@@ -933,8 +937,16 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (s ...@@ -933,8 +937,16 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (s
// 重试相关常量 // 重试相关常量
const ( const (
maxRetries = 10 // 最大重试次数 // 最大尝试次数(包含首次请求)。过多重试会导致请求堆积与资源耗尽。
retryDelay = 3 * time.Second // 重试等待时间 maxRetryAttempts = 5
// 指数退避:第 N 次失败后的等待 = retryBaseDelay * 2^(N-1),并且上限为 retryMaxDelay。
retryBaseDelay = 300 * time.Millisecond
retryMaxDelay = 3 * time.Second
// 最大重试耗时(包含请求本身耗时 + 退避等待时间)。
// 用于防止极端情况下 goroutine 长时间堆积导致资源耗尽。
maxRetryElapsed = 10 * time.Second
) )
func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool { func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
...@@ -957,6 +969,40 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool { ...@@ -957,6 +969,40 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
} }
} }
// retryBackoffDelay returns how long to wait after the attempt-th request has failed,
// before issuing request attempt+1.
//
// attempt is 1-based. The delay starts at retryBaseDelay and doubles with each further
// attempt, capped at retryMaxDelay. A non-positive attempt yields retryBaseDelay.
//
// The doubling is done step by step and stops as soon as the cap is reached, so very large
// attempt values cannot overflow the shift or the multiplication and produce a zero or
// negative delay.
func retryBackoffDelay(attempt int) time.Duration {
	if attempt <= 0 {
		return retryBaseDelay
	}
	delay := retryBaseDelay
	for step := 1; step < attempt; step++ {
		// Once the cap is reached further doubling is pointless (and could overflow).
		if delay >= retryMaxDelay {
			return retryMaxDelay
		}
		delay *= 2
	}
	if delay > retryMaxDelay {
		return retryMaxDelay
	}
	return delay
}
// sleepWithContext blocks for d, or until ctx is done, whichever happens first.
//
// It returns nil when the full duration elapsed (or when d is not positive, in which case
// it returns immediately without looking at ctx), and ctx.Err() when the context finished
// first.
func sleepWithContext(ctx context.Context, d time.Duration) error {
	if d <= 0 {
		return nil
	}

	wait := time.NewTimer(d)
	defer func() {
		if wait.Stop() {
			return
		}
		// Timer already fired: drain the channel without blocking if a value is pending.
		select {
		case <-wait.C:
		default:
		}
	}()

	select {
	case <-wait.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端 // isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
// 简化判断:User-Agent 匹配 + metadata.user_id 存在 // 简化判断:User-Agent 匹配 + metadata.user_id 存在
func isClaudeCodeClient(userAgent string, metadataUserID string) bool { func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
...@@ -1073,7 +1119,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1073,7 +1119,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 重试循环 // 重试循环
var resp *http.Response var resp *http.Response
for attempt := 1; attempt <= maxRetries; attempt++ { retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel) upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil { if err != nil {
...@@ -1083,6 +1130,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1083,6 +1130,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 发送请求 // 发送请求
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil { if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
return nil, fmt.Errorf("upstream request failed: %w", err) return nil, fmt.Errorf("upstream request failed: %w", err)
} }
...@@ -1093,28 +1143,80 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1093,28 +1143,80 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_ = resp.Body.Close() _ = resp.Body.Close()
if s.isThinkingBlockSignatureError(respBody) { if s.isThinkingBlockSignatureError(respBody) {
looksLikeToolSignatureError := func(msg string) bool {
m := strings.ToLower(msg)
return strings.Contains(m, "tool_use") ||
strings.Contains(m, "tool_result") ||
strings.Contains(m, "functioncall") ||
strings.Contains(m, "function_call") ||
strings.Contains(m, "functionresponse") ||
strings.Contains(m, "function_response")
}
// 避免在重试预算已耗尽时再发起额外请求
if time.Since(retryStart) >= maxRetryElapsed {
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID) log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
// 过滤thinking blocks并重试(使用更激进的过滤) // Conservative two-stage fallback:
// 1) Disable thinking + thinking->text (preserve content)
// 2) Only if upstream still errors AND error message points to tool/function signature issues:
// also downgrade tool_use/tool_result blocks to text.
filteredBody := FilterThinkingBlocksForRetry(body) filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil { if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil { if retryErr == nil {
// 使用重试后的响应,继续后续处理
if retryResp.StatusCode < 400 { if retryResp.StatusCode < 400 {
log.Printf("Account %d: signature error retry succeeded", account.ID) log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
} else { resp = retryResp
log.Printf("Account %d: signature error retry returned status %d", account.ID, retryResp.StatusCode) break
}
retryRespBody, retryReadErr := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
if retryReadErr == nil && retryResp.StatusCode == 400 && s.isThinkingBlockSignatureError(retryRespBody) {
msg2 := extractUpstreamErrorMessage(retryRespBody)
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
if retryErr2 == nil {
resp = retryResp2
break
}
if retryResp2 != nil && retryResp2.Body != nil {
_ = retryResp2.Body.Close()
}
log.Printf("Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2)
} else {
log.Printf("Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2)
}
}
}
// Fall back to the original retry response context.
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryRespBody)),
} }
resp = retryResp
break break
} }
if retryResp != nil && retryResp.Body != nil {
_ = retryResp.Body.Close()
}
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr) log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
} else { } else {
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr) log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
} }
// 重试失败,恢复原始响应体继续处理
// Retry failed: restore original response body and continue handling.
resp.Body = io.NopCloser(bytes.NewReader(respBody)) resp.Body = io.NopCloser(bytes.NewReader(respBody))
break break
} }
...@@ -1125,11 +1227,27 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1125,11 +1227,27 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了) // 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了)
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetries { if attempt < maxRetryAttempts {
log.Printf("Account %d: upstream error %d, retry %d/%d after %v", elapsed := time.Since(retryStart)
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay) if elapsed >= maxRetryElapsed {
break
}
delay := retryBackoffDelay(attempt)
remaining := maxRetryElapsed - elapsed
if delay > remaining {
delay = remaining
}
if delay <= 0 {
break
}
log.Printf("Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)",
account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed)
_ = resp.Body.Close() _ = resp.Body.Close()
time.Sleep(retryDelay) if err := sleepWithContext(ctx, delay); err != nil {
return nil, err
}
continue continue
} }
// 最后一次尝试也失败,跳出循环处理重试耗尽 // 最后一次尝试也失败,跳出循环处理重试耗尽
...@@ -1146,6 +1264,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1146,6 +1264,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
} }
break break
} }
if resp == nil || resp.Body == nil {
return nil, errors.New("upstream request failed: empty response")
}
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
// 处理重试耗尽的情况 // 处理重试耗尽的情况
...@@ -1229,7 +1350,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -1229,7 +1350,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages" if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages"
}
} }
// OAuth账号:应用统一指纹 // OAuth账号:应用统一指纹
...@@ -1537,10 +1664,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re ...@@ -1537,10 +1664,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
// OAuth/Setup Token 账号的 403:标记账号异常 // OAuth/Setup Token 账号的 403:标记账号异常
if account.IsOAuth() && statusCode == 403 { if account.IsOAuth() && statusCode == 403 {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body) s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode) log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode)
} else { } else {
// API Key 未配置错误码:不标记账号状态 // API Key 未配置错误码:不标记账号状态
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries) log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts)
} }
} }
...@@ -1577,6 +1704,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -1577,6 +1704,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// 更新5h窗口状态 // 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
// 设置SSE响应头 // 设置SSE响应头
c.Header("Content-Type", "text/event-stream") c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache") c.Header("Cache-Control", "no-cache")
...@@ -1598,51 +1729,133 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -1598,51 +1729,133 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
var firstTokenMs *int var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
// 设置更大的buffer以处理长行 // 设置更大的buffer以处理长行
scanner.Buffer(make([]byte, 64*1024), 1024*1024) maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
needModelReplace := originalModel != mappedModel streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
// 仅监控上游数据间隔超时,避免下游写入阻塞导致误判
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
for scanner.Scan() { // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
line := scanner.Text() errorEventSent := false
if line == "event: error" { sendErrorEvent := func(reason string) {
return nil, errors.New("have error in stream") if errorEventSent {
return
} }
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
// Extract data from SSE line (supports both "data: " and "data:" formats) needModelReplace := originalModel != mappedModel
if sseDataRe.MatchString(line) {
data := sseDataRe.ReplaceAllString(line, "")
// 如果有模型映射,替换响应中的model字段 for {
if needModelReplace { select {
line = s.replaceModelInSSELine(line, mappedModel, originalModel) case ev, ok := <-events:
if !ok {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
} }
if ev.err != nil {
// 转发行 if errors.Is(ev.err, bufio.ErrTooLong) {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err sendErrorEvent("response_too_large")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
if line == "event: error" {
return nil, errors.New("have error in stream")
} }
flusher.Flush()
// 记录首字时间:第一个有效的 content_block_delta 或 message_start // Extract data from SSE line (supports both "data: " and "data:" formats)
if firstTokenMs == nil && data != "" && data != "[DONE]" { if sseDataRe.MatchString(line) {
ms := int(time.Since(startTime).Milliseconds()) data := sseDataRe.ReplaceAllString(line, "")
firstTokenMs = &ms
// 如果有模型映射,替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// 转发行
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
} else {
// 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
} }
s.parseSSEUsage(data, usage)
} else { case <-intervalCh:
// 非 data 行直接转发 lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { if time.Since(lastRead) < streamInterval {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err continue
} }
flusher.Flush() log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
} }
} }
if err := scanner.Err(); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
} }
// replaceModelInSSELine 替换SSE数据行中的model字段 // replaceModelInSSELine 替换SSE数据行中的model字段
...@@ -1747,15 +1960,17 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h ...@@ -1747,15 +1960,17 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
body = s.replaceModelInResponseBody(body, mappedModel, originalModel) body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
} }
// 透传响应头 responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
for key, values := range resp.Header {
for _, value := range values { contentType := "application/json"
c.Header(key, value) if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
contentType = upstreamType
} }
} }
// 写入响应 // 写入响应
c.Data(resp.StatusCode, "application/json", body) c.Data(resp.StatusCode, contentType, body)
return &response.Usage, nil return &response.Usage, nil
} }
...@@ -1989,7 +2204,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -1989,7 +2204,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) { if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocks(body) filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil { if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
...@@ -2045,7 +2260,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -2045,7 +2260,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
targetURL := claudeAPICountTokensURL targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens" if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages/count_tokens"
}
} }
// OAuth 账号:应用统一指纹和重写 userID // OAuth 账号:应用统一指纹和重写 userID
...@@ -2125,6 +2346,25 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m ...@@ -2125,6 +2346,25 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
}) })
} }
// validateUpstreamBaseURL validates an account-supplied upstream base URL against the
// security configuration and returns the normalized URL produced by the validator.
//
// When the URL allowlist is disabled, only the URL format is checked
// (urlvalidator.ValidateURLFormat, honouring AllowInsecureHTTP). Otherwise the URL goes
// through urlvalidator.ValidateHTTPSURL with the configured upstream hosts, the allowlist
// required, and the AllowPrivateHosts setting.
//
// Every failure is returned as an error prefixed with "invalid base_url". A nil service
// config is reported as an error instead of being dereferenced.
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
	if s.cfg == nil {
		return "", errors.New("invalid base_url: gateway config is not initialized")
	}
	allowlist := s.cfg.Security.URLAllowlist

	if !allowlist.Enabled {
		normalized, err := urlvalidator.ValidateURLFormat(raw, allowlist.AllowInsecureHTTP)
		if err != nil {
			return "", fmt.Errorf("invalid base_url: %w", err)
		}
		return normalized, nil
	}

	normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
		AllowedHosts:     allowlist.UpstreamHosts,
		RequireAllowlist: true,
		AllowPrivate:     allowlist.AllowPrivateHosts,
	})
	if err != nil {
		return "", fmt.Errorf("invalid base_url: %w", err)
	}
	return normalized, nil
}
// GetAvailableModels returns the list of models available for a group // GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group // It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
......
...@@ -18,9 +18,12 @@ import ( ...@@ -18,9 +18,12 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct { ...@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
rateLimitService *RateLimitService rateLimitService *RateLimitService
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService antigravityGatewayService *AntigravityGatewayService
cfg *config.Config
} }
func NewGeminiMessagesCompatService( func NewGeminiMessagesCompatService(
...@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService( ...@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
rateLimitService *RateLimitService, rateLimitService *RateLimitService,
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
antigravityGatewayService *AntigravityGatewayService, antigravityGatewayService *AntigravityGatewayService,
cfg *config.Config,
) *GeminiMessagesCompatService { ) *GeminiMessagesCompatService {
return &GeminiMessagesCompatService{ return &GeminiMessagesCompatService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService( ...@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
rateLimitService: rateLimitService, rateLimitService: rateLimitService,
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService, antigravityGatewayService: antigravityGatewayService,
cfg: cfg,
} }
} }
...@@ -230,6 +236,25 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit ...@@ -230,6 +236,25 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
return s.antigravityGatewayService return s.antigravityGatewayService
} }
// validateUpstreamBaseURL validates an upstream base URL against the security
// configuration and returns the normalized URL produced by the validator.
//
// When the URL allowlist is disabled, only the URL format is checked
// (urlvalidator.ValidateURLFormat, honouring AllowInsecureHTTP). Otherwise the URL goes
// through urlvalidator.ValidateHTTPSURL with the configured upstream hosts, the allowlist
// required, and the AllowPrivateHosts setting.
//
// Every failure is returned as an error prefixed with "invalid base_url". A nil service
// config is reported as an error instead of being dereferenced.
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
	if s.cfg == nil {
		return "", errors.New("invalid base_url: gateway config is not initialized")
	}
	allowlist := s.cfg.Security.URLAllowlist

	if !allowlist.Enabled {
		normalized, err := urlvalidator.ValidateURLFormat(raw, allowlist.AllowInsecureHTTP)
		if err != nil {
			return "", fmt.Errorf("invalid base_url: %w", err)
		}
		return normalized, nil
	}

	normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
		AllowedHosts:     allowlist.UpstreamHosts,
		RequireAllowlist: true,
		AllowPrivate:     allowlist.AllowPrivateHosts,
	})
	if err != nil {
		return "", fmt.Errorf("invalid base_url: %w", err)
	}
	return normalized, nil
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户 // HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) { func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
var accounts []Account var accounts []Account
...@@ -359,6 +384,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -359,6 +384,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if err != nil { if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
} }
originalClaudeBody := body
proxyURL := "" proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
...@@ -381,16 +407,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -381,16 +407,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, "", errors.New("gemini api_key not configured") return nil, "", errors.New("gemini api_key not configured")
} }
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
action := "generateContent" action := "generateContent"
if req.Stream { if req.Stream {
action = "streamGenerateContent" action = "streamGenerateContent"
} }
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action) fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if req.Stream { if req.Stream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -427,7 +457,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -427,7 +457,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token) // 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" { if projectID != "" {
// Mode 1: Code Assist API // Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action) baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -453,12 +487,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -453,12 +487,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return upstreamReq, "x-request-id", nil return upstreamReq, "x-request-id", nil
} else { } else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action) fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -479,6 +517,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -479,6 +517,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
} }
var resp *http.Response var resp *http.Response
signatureRetryStage := 0
for attempt := 1; attempt <= geminiMaxRetries; attempt++ { for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
upstreamReq, idHeader, err := buildReq(ctx) upstreamReq, idHeader, err := buildReq(ctx)
if err != nil { if err != nil {
...@@ -503,6 +542,46 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -503,6 +542,46 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error())) return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
} }
// Special-case: signature/thought_signature validation errors are not transient, but may be fixed by
// downgrading Claude thinking/tool history to plain text (conservative two-stage retry).
if resp.StatusCode == http.StatusBadRequest && signatureRetryStage < 2 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if isGeminiSignatureRelatedError(respBody) {
var strippedClaudeBody []byte
stageName := ""
switch signatureRetryStage {
case 0:
// Stage 1: disable thinking + thinking->text
strippedClaudeBody = FilterThinkingBlocksForRetry(originalClaudeBody)
stageName = "thinking-only"
signatureRetryStage = 1
default:
// Stage 2: additionally downgrade tool_use/tool_result blocks to text
strippedClaudeBody = FilterSignatureSensitiveBlocksForRetry(originalClaudeBody)
stageName = "thinking+tools"
signatureRetryStage = 2
}
retryGeminiReq, txErr := convertClaudeMessagesToGeminiGenerateContent(strippedClaudeBody)
if txErr == nil {
log.Printf("Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName)
geminiReq = retryGeminiReq
// Consume one retry budget attempt and continue with the updated request payload.
sleepGeminiBackoff(1)
continue
}
}
// Restore body for downstream error handling.
resp = &http.Response{
StatusCode: http.StatusBadRequest,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close() _ = resp.Body.Close()
...@@ -600,6 +679,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -600,6 +679,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}, nil }, nil
} }
// isGeminiSignatureRelatedError reports whether an upstream error body mentions
// a signature problem (for example "thought_signature").
//
// The message returned by extractAntigravityErrorMessage is inspected first; if
// it is empty after trimming whitespace, the raw response body is searched
// instead. Matching is case-insensitive.
func isGeminiSignatureRelatedError(respBody []byte) bool {
	haystack := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
	if haystack == "" {
		haystack = string(respBody)
	}
	haystack = strings.ToLower(haystack)

	// "thought_signature" is listed for readability; any text containing it
	// also contains the broader "signature" marker.
	for _, marker := range []string{"thought_signature", "signature"} {
		if strings.Contains(haystack, marker) {
			return true
		}
	}
	return false
}
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) { func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
...@@ -650,12 +737,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -650,12 +737,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, "", errors.New("gemini api_key not configured") return nil, "", errors.New("gemini api_key not configured")
} }
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction) fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -687,7 +778,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -687,7 +778,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token) // 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" && !forceAIStudio { if projectID != "" && !forceAIStudio {
// Mode 1: Code Assist API // Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction) baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -713,12 +808,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -713,12 +808,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return upstreamReq, "x-request-id", nil return upstreamReq, "x-request-id", nil
} else { } else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction) fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -1652,6 +1751,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co ...@@ -1652,6 +1751,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
_ = json.Unmarshal(respBody, &parsed) _ = json.Unmarshal(respBody, &parsed)
} }
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
if contentType == "" { if contentType == "" {
contentType = "application/json" contentType = "application/json"
...@@ -1676,6 +1777,10 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte ...@@ -1676,6 +1777,10 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
} }
log.Printf("[GeminiAPI] ====================================================") log.Printf("[GeminiAPI] ====================================================")
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
c.Status(resp.StatusCode) c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache") c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
...@@ -1773,11 +1878,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac ...@@ -1773,11 +1878,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return nil, errors.New("invalid path") return nil, errors.New("invalid path")
} }
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
fullURL := strings.TrimRight(baseURL, "/") + path normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := strings.TrimRight(normalizedBaseURL, "/") + path
var proxyURL string var proxyURL string
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
...@@ -1816,9 +1925,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac ...@@ -1816,9 +1925,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
if wwwAuthenticate != "" {
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
}
return &UpstreamHTTPResult{ return &UpstreamHTTPResult{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
Headers: resp.Header.Clone(), Headers: filteredHeaders,
Body: body, Body: body,
}, nil }, nil
} }
......
...@@ -1000,8 +1000,9 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR ...@@ -1000,8 +1000,9 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
client, err := httpclient.GetClient(httpclient.Options{ client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: strings.TrimSpace(proxyURL), ProxyURL: strings.TrimSpace(proxyURL),
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
ValidateResolvedIP: true,
}) })
if err != nil { if err != nil {
client = &http.Client{Timeout: 30 * time.Second} client = &http.Client{Timeout: 30 * time.Second}
......
...@@ -16,9 +16,12 @@ import ( ...@@ -16,9 +16,12 @@ import (
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -630,10 +633,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. ...@@ -630,10 +633,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeAPIKey: case AccountTypeAPIKey:
// API Key accounts use Platform API or custom base URL // API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL() baseURL := account.GetOpenAIBaseURL()
if baseURL != "" { if baseURL == "" {
targetURL = baseURL + "/responses"
} else {
targetURL = openaiPlatformAPIURL targetURL = openaiPlatformAPIURL
} else {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/responses"
} }
default: default:
targetURL = openaiPlatformAPIURL targetURL = openaiPlatformAPIURL
...@@ -755,6 +762,10 @@ type openaiStreamingResult struct { ...@@ -755,6 +762,10 @@ type openaiStreamingResult struct {
} }
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
// Set SSE response headers // Set SSE response headers
c.Header("Content-Type", "text/event-stream") c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache") c.Header("Cache-Control", "no-cache")
...@@ -775,48 +786,158 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -775,48 +786,158 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
usage := &OpenAIUsage{} usage := &OpenAIUsage{}
var firstTokenMs *int var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 64*1024), 1024*1024) maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
// 仅监控上游数据间隔超时,不被下游写入阻塞影响
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
// 下游 keepalive 仅用于防止代理空闲断开
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
needModelReplace := originalModel != mappedModel needModelReplace := originalModel != mappedModel
for scanner.Scan() { for {
line := scanner.Text() select {
case ev, ok := <-events:
if !ok {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
// Extract data from SSE line (supports both "data: " and "data:" formats) line := ev.line
if openaiSSEDataRe.MatchString(line) { lastDataAt = time.Now()
data := openaiSSEDataRe.ReplaceAllString(line, "")
// Extract data from SSE line (supports both "data: " and "data:" formats)
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
// Replace model in response if needed
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// Replace model in response if needed // Forward line
if needModelReplace { if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
line = s.replaceModelInSSELine(line, mappedModel, originalModel) sendErrorEvent("write_failed")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
} else {
// Forward non-data lines as-is
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
} }
// Forward line case <-intervalCh:
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err if time.Since(lastRead) < streamInterval {
continue
} }
flusher.Flush() log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
sendErrorEvent("stream_timeout")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
// Record first token time case <-keepaliveCh:
if firstTokenMs == nil && data != "" && data != "[DONE]" { if time.Since(lastDataAt) < keepaliveInterval {
ms := int(time.Since(startTime).Milliseconds()) continue
firstTokenMs = &ms
} }
s.parseSSEUsage(data, usage) if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
} else {
// Forward non-data lines as-is
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
} }
flusher.Flush() flusher.Flush()
} }
} }
if err := scanner.Err(); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
} }
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
...@@ -911,18 +1032,39 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r ...@@ -911,18 +1032,39 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
body = s.replaceModelInResponseBody(body, mappedModel, originalModel) body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
} }
// Pass through headers responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
for key, values := range resp.Header {
for _, value := range values { contentType := "application/json"
c.Header(key, value) if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
contentType = upstreamType
} }
} }
c.Data(resp.StatusCode, "application/json", body) c.Data(resp.StatusCode, contentType, body)
return usage, nil return usage, nil
} }
// validateUpstreamBaseURL checks an account-supplied upstream base URL before it
// is used to build a request and returns the normalized URL reported by the
// validator.
//
// When the URL allowlist is disabled in the configuration, only the URL format
// is checked with urlvalidator.ValidateURLFormat, forwarding the
// AllowInsecureHTTP setting. Otherwise urlvalidator.ValidateHTTPSURL is called
// with the allowlist required, using the configured upstream hosts and the
// AllowPrivateHosts setting.
//
// If the service has no configuration at all (s.cfg == nil), validation fails
// closed: the allowlist is still required, with no allowed hosts and
// AllowPrivate left false, instead of dereferencing the nil config.
//
// Any validator error is wrapped with an "invalid base_url" prefix.
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
	if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
		normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
		if err != nil {
			return "", fmt.Errorf("invalid base_url: %w", err)
		}
		return normalized, nil
	}

	// Start from the strictest options and relax them only from real config,
	// so a missing config can never panic or widen what is accepted.
	opts := urlvalidator.ValidationOptions{RequireAllowlist: true}
	if s.cfg != nil {
		opts.AllowedHosts = s.cfg.Security.URLAllowlist.UpstreamHosts
		opts.AllowPrivate = s.cfg.Security.URLAllowlist.AllowPrivateHosts
	}

	normalized, err := urlvalidator.ValidateHTTPSURL(raw, opts)
	if err != nil {
		return "", fmt.Errorf("invalid base_url: %w", err)
	}
	return normalized, nil
}
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil { if err := json.Unmarshal(body, &resp); err != nil {
......
package service
import (
"bufio"
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
// TestOpenAIStreamingTimeout verifies that handleStreamingResponse gives up when
// the upstream body produces no data within StreamDataIntervalTimeout (set to 1
// here) and reports the timeout both as a returned error and as an SSE error
// event written to the client.
func TestOpenAIStreamingTimeout(t *testing.T) {
	gin.SetMode(gin.TestMode)

	service := &OpenAIGatewayService{cfg: &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 1,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	// Nothing is ever written to the pipe, so the upstream read stays blocked
	// until the handler returns and the pipe is closed below.
	bodyReader, bodyWriter := io.Pipe()
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Body:       bodyReader,
		Header:     http.Header{},
	}

	startedAt := time.Now()
	_, err := service.handleStreamingResponse(ginCtx.Request.Context(), upstream, ginCtx, &Account{ID: 1}, startedAt, "model", "model")
	_ = bodyWriter.Close()
	_ = bodyReader.Close()

	timedOut := err != nil && strings.Contains(err.Error(), "stream data interval timeout")
	if !timedOut {
		t.Fatalf("expected stream timeout error, got %v", err)
	}
	if written := recorder.Body.String(); !strings.Contains(written, "stream_timeout") {
		t.Fatalf("expected stream_timeout SSE error, got %q", written)
	}
}
// TestOpenAIStreamingTooLong verifies that a single upstream SSE line larger
// than the configured MaxLineSize makes handleStreamingResponse return
// bufio.ErrTooLong and emit a "response_too_large" SSE error event.
func TestOpenAIStreamingTooLong(t *testing.T) {
	gin.SetMode(gin.TestMode)

	service := &OpenAIGatewayService{cfg: &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               64 * 1024,
		},
	}}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	bodyReader, bodyWriter := io.Pipe()
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Body:       bodyReader,
		Header:     http.Header{},
	}

	go func() {
		defer func() { _ = bodyWriter.Close() }()
		// Write one line that exceeds MaxLineSize to trigger ErrTooLong.
		oversized := []byte("data: " + strings.Repeat("a", 128*1024) + "\n")
		_, _ = bodyWriter.Write(oversized)
	}()

	_, err := service.handleStreamingResponse(ginCtx.Request.Context(), upstream, ginCtx, &Account{ID: 2}, time.Now(), "model", "model")
	_ = bodyReader.Close()

	if !errors.Is(err, bufio.ErrTooLong) {
		t.Fatalf("expected ErrTooLong, got %v", err)
	}
	if written := recorder.Body.String(); !strings.Contains(written, "response_too_large") {
		t.Fatalf("expected response_too_large SSE error, got %q", written)
	}
}
// TestOpenAINonStreamingContentTypePassThrough verifies that, with response
// header filtering disabled in the config, handleNonStreamingResponse forwards
// the upstream Content-Type to the client.
func TestOpenAINonStreamingContentTypePassThrough(t *testing.T) {
	gin.SetMode(gin.TestMode)

	service := &OpenAIGatewayService{cfg: &config.Config{
		Security: config.SecurityConfig{
			ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
		},
	}}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	payload := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewReader(payload)),
		Header:     http.Header{"Content-Type": []string{"application/vnd.test+json"}},
	}

	if _, err := service.handleNonStreamingResponse(ginCtx.Request.Context(), upstream, ginCtx, &Account{}, "model", "model"); err != nil {
		t.Fatalf("handleNonStreamingResponse error: %v", err)
	}
	if got := recorder.Header().Get("Content-Type"); !strings.Contains(got, "application/vnd.test+json") {
		t.Fatalf("expected Content-Type passthrough, got %q", got)
	}
}
// TestOpenAINonStreamingContentTypeDefault verifies that
// handleNonStreamingResponse answers with "application/json" when the upstream
// response carries no Content-Type header.
func TestOpenAINonStreamingContentTypeDefault(t *testing.T) {
	gin.SetMode(gin.TestMode)

	service := &OpenAIGatewayService{cfg: &config.Config{
		Security: config.SecurityConfig{
			ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
		},
	}}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	payload := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewReader(payload)),
		Header:     http.Header{},
	}

	if _, err := service.handleNonStreamingResponse(ginCtx.Request.Context(), upstream, ginCtx, &Account{}, "model", "model"); err != nil {
		t.Fatalf("handleNonStreamingResponse error: %v", err)
	}
	if got := recorder.Header().Get("Content-Type"); !strings.Contains(got, "application/json") {
		t.Fatalf("expected default Content-Type, got %q", got)
	}
}
// TestOpenAIStreamingHeadersOverride verifies the header precedence of
// handleStreamingResponse: the SSE headers it sets itself (Cache-Control,
// Content-Type) win over the upstream values, while an unrelated upstream
// header (X-Request-Id) is passed through to the client.
func TestOpenAIStreamingHeadersOverride(t *testing.T) {
	gin.SetMode(gin.TestMode)

	service := &OpenAIGatewayService{cfg: &config.Config{
		Security: config.SecurityConfig{
			ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
		},
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	bodyReader, bodyWriter := io.Pipe()
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Body:       bodyReader,
		Header: http.Header{
			"Cache-Control": []string{"upstream"},
			"X-Request-Id":  []string{"req-123"},
			"Content-Type":  []string{"application/custom"},
		},
	}

	// Feed a single empty data event, then close so the handler sees EOF.
	go func() {
		defer func() { _ = bodyWriter.Close() }()
		_, _ = bodyWriter.Write([]byte("data: {}\n\n"))
	}()

	_, err := service.handleStreamingResponse(ginCtx.Request.Context(), upstream, ginCtx, &Account{ID: 1}, time.Now(), "model", "model")
	_ = bodyReader.Close()
	if err != nil {
		t.Fatalf("handleStreamingResponse error: %v", err)
	}

	sent := recorder.Header()
	if got := sent.Get("Cache-Control"); got != "no-cache" {
		t.Fatalf("expected Cache-Control override, got %q", got)
	}
	if got := sent.Get("Content-Type"); got != "text/event-stream" {
		t.Fatalf("expected Content-Type override, got %q", got)
	}
	if got := sent.Get("X-Request-Id"); got != "req-123" {
		t.Fatalf("expected X-Request-Id passthrough, got %q", got)
	}
}
// TestOpenAIInvalidBaseURLWhenAllowlistDisabled verifies that
// buildUpstreamRequest still rejects a malformed base_url credential on an API
// key account even though the URL allowlist is disabled.
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
	gin.SetMode(gin.TestMode)

	service := &OpenAIGatewayService{cfg: &config.Config{
		Security: config.SecurityConfig{
			URLAllowlist: config.URLAllowlistConfig{Enabled: false},
		},
	}}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	apiKeyAccount := &Account{
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Credentials: map[string]any{"base_url": "://invalid-url"},
	}

	if _, err := service.buildUpstreamRequest(ginCtx.Request.Context(), ginCtx, apiKeyAccount, []byte("{}"), "token", false); err == nil {
		t.Fatalf("expected error for invalid base_url when allowlist disabled")
	}
}
func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS(t *testing.T) {
cfg := &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
},
}
svc := &OpenAIGatewayService{cfg: cfg}
if _, err := svc.validateUpstreamBaseURL("http://not-https.example.com"); err == nil {
t.Fatalf("expected http to be rejected when allow_insecure_http is false")
}
normalized, err := svc.validateUpstreamBaseURL("https://example.com")
if err != nil {
t.Fatalf("expected https to be allowed when allowlist disabled, got %v", err)
}
if normalized != "https://example.com" {
t.Fatalf("expected raw url passthrough, got %q", normalized)
}
}
// TestOpenAIValidateUpstreamBaseURLDisabledAllowsHTTP verifies that, with the
// allowlist disabled and AllowInsecureHTTP enabled, an http URL is accepted and
// returned unchanged.
func TestOpenAIValidateUpstreamBaseURLDisabledAllowsHTTP(t *testing.T) {
	allowlist := config.URLAllowlistConfig{
		Enabled:           false,
		AllowInsecureHTTP: true,
	}
	service := &OpenAIGatewayService{cfg: &config.Config{
		Security: config.SecurityConfig{URLAllowlist: allowlist},
	}}

	got, err := service.validateUpstreamBaseURL("http://not-https.example.com")
	if err != nil {
		t.Fatalf("expected http allowed when allow_insecure_http is true, got %v", err)
	}
	if got != "http://not-https.example.com" {
		t.Fatalf("expected raw url passthrough, got %q", got)
	}
}
// TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist verifies that, with
// the allowlist enabled, a host listed in UpstreamHosts passes validation and a
// host that is not listed is rejected.
func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
	allowlist := config.URLAllowlistConfig{
		Enabled:       true,
		UpstreamHosts: []string{"example.com"},
	}
	service := &OpenAIGatewayService{cfg: &config.Config{
		Security: config.SecurityConfig{URLAllowlist: allowlist},
	}}

	_, listedErr := service.validateUpstreamBaseURL("https://example.com")
	if listedErr != nil {
		t.Fatalf("expected allowlisted host to pass, got %v", listedErr)
	}

	_, unlistedErr := service.validateUpstreamBaseURL("https://evil.com")
	if unlistedErr == nil {
		t.Fatalf("expected non-allowlisted host to fail")
	}
}
...@@ -16,6 +16,7 @@ import ( ...@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
) )
var ( var (
...@@ -213,16 +214,35 @@ func (s *PricingService) syncWithRemote() error { ...@@ -213,16 +214,35 @@ func (s *PricingService) syncWithRemote() error {
// downloadPricingData 从远程下载价格数据 // downloadPricingData 从远程下载价格数据
func (s *PricingService) downloadPricingData() error { func (s *PricingService) downloadPricingData() error {
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL) remoteURL, err := s.validatePricingURL(s.cfg.Pricing.RemoteURL)
if err != nil {
return err
}
log.Printf("[Pricing] Downloading from %s", remoteURL)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL) var expectedHash string
if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" {
expectedHash, err = s.fetchRemoteHash()
if err != nil {
return fmt.Errorf("fetch remote hash: %w", err)
}
}
body, err := s.remoteClient.FetchPricingJSON(ctx, remoteURL)
if err != nil { if err != nil {
return fmt.Errorf("download failed: %w", err) return fmt.Errorf("download failed: %w", err)
} }
if expectedHash != "" {
actualHash := sha256.Sum256(body)
if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) {
return fmt.Errorf("pricing hash mismatch")
}
}
// 解析JSON数据(使用灵活的解析方式) // 解析JSON数据(使用灵活的解析方式)
data, err := s.parsePricingData(body) data, err := s.parsePricingData(body)
if err != nil { if err != nil {
...@@ -378,10 +398,38 @@ func (s *PricingService) useFallbackPricing() error { ...@@ -378,10 +398,38 @@ func (s *PricingService) useFallbackPricing() error {
// fetchRemoteHash 从远程获取哈希值 // fetchRemoteHash 从远程获取哈希值
func (s *PricingService) fetchRemoteHash() (string, error) { func (s *PricingService) fetchRemoteHash() (string, error) {
hashURL, err := s.validatePricingURL(s.cfg.Pricing.HashURL)
if err != nil {
return "", err
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL) hash, err := s.remoteClient.FetchHashText(ctx, hashURL)
if err != nil {
return "", err
}
return strings.TrimSpace(hash), nil
}
// validatePricingURL checks a pricing-related URL before it is fetched and
// returns the normalized URL reported by the validator.
//
// When the URL allowlist is disabled in the configuration, only the URL format
// is checked with urlvalidator.ValidateURLFormat, forwarding the
// AllowInsecureHTTP setting. Otherwise urlvalidator.ValidateHTTPSURL is called
// with the allowlist required, using the configured PricingHosts and the
// AllowPrivateHosts setting. Validator errors are wrapped with an
// "invalid pricing url" prefix.
//
// NOTE(review): the first branch guards s.cfg != nil, but the fall-through path
// reads s.cfg unconditionally, so a nil config would panic here — confirm
// whether cfg can ever be nil for PricingService.
func (s *PricingService) validatePricingURL(raw string) (string, error) {
// Allowlist disabled: format-only validation.
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return "", fmt.Errorf("invalid pricing url: %w", err)
}
return normalized, nil
}
// Allowlist enabled: the allowlist is required, using the configured pricing hosts.
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid pricing url: %w", err)
}
return normalized, nil
} }
// computeFileHash 计算文件哈希 // computeFileHash 计算文件哈希
......
...@@ -130,6 +130,10 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -130,6 +130,10 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini
updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity
// Identity patch configuration (Claude -> Gemini)
updates[SettingKeyEnableIdentityPatch] = strconv.FormatBool(settings.EnableIdentityPatch)
updates[SettingKeyIdentityPatchPrompt] = settings.IdentityPatchPrompt
return s.settingRepo.SetMultiple(ctx, updates) return s.settingRepo.SetMultiple(ctx, updates)
} }
...@@ -213,6 +217,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -213,6 +217,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyFallbackModelOpenAI: "gpt-4o", SettingKeyFallbackModelOpenAI: "gpt-4o",
SettingKeyFallbackModelGemini: "gemini-2.5-pro", SettingKeyFallbackModelGemini: "gemini-2.5-pro",
SettingKeyFallbackModelAntigravity: "gemini-2.5-pro", SettingKeyFallbackModelAntigravity: "gemini-2.5-pro",
// Identity patch defaults
SettingKeyEnableIdentityPatch: "true",
SettingKeyIdentityPatchPrompt: "",
} }
return s.settingRepo.SetMultiple(ctx, defaults) return s.settingRepo.SetMultiple(ctx, defaults)
...@@ -221,21 +228,23 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -221,21 +228,23 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// parseSettings 解析设置到结构体 // parseSettings 解析设置到结构体
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings { func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
result := &SystemSettings{ result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SMTPHost: settings[SettingKeySMTPHost], SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername], SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom], SMTPFrom: settings[SettingKeySMTPFrom],
SMTPFromName: settings[SettingKeySMTPFromName], SMTPFromName: settings[SettingKeySMTPFromName],
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true", SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteLogo: settings[SettingKeySiteLogo], TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
APIBaseURL: settings[SettingKeyAPIBaseURL], SiteLogo: settings[SettingKeySiteLogo],
ContactInfo: settings[SettingKeyContactInfo], SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
DocURL: settings[SettingKeyDocURL], APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
} }
// 解析整数类型 // 解析整数类型
...@@ -269,6 +278,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -269,6 +278,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro") result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro")
result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro") result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro")
// Identity patch settings (default: enabled, to preserve existing behavior)
if v, ok := settings[SettingKeyEnableIdentityPatch]; ok && v != "" {
result.EnableIdentityPatch = v == "true"
} else {
result.EnableIdentityPatch = true
}
result.IdentityPatchPrompt = settings[SettingKeyIdentityPatchPrompt]
return result return result
} }
...@@ -298,6 +315,25 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string { ...@@ -298,6 +315,25 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return value return value
} }
// IsIdentityPatchEnabled reports whether the identity patch (Claude -> Gemini
// systemInstruction injection) is enabled.
//
// The feature is enabled by default to stay backward compatible: a lookup
// error or an empty stored value both yield true, which matches how
// parseSettings interprets a missing or empty value. Otherwise the result is
// true only when the stored value is exactly "true".
func (s *SettingService) IsIdentityPatchEnabled(ctx context.Context) bool {
	value, err := s.settingRepo.GetValue(ctx, SettingKeyEnableIdentityPatch)
	if err != nil || value == "" {
		// Unset or unreadable: fall back to the default (enabled).
		return true
	}
	return value == "true"
}
// GetIdentityPatchPrompt returns the custom identity patch prompt stored in
// settings. An empty result (including when the lookup fails) means the
// built-in default template should be used.
func (s *SettingService) GetIdentityPatchPrompt(ctx context.Context) string {
	if prompt, err := s.settingRepo.GetValue(ctx, SettingKeyIdentityPatchPrompt); err == nil {
		return prompt
	}
	return ""
}
// GenerateAdminAPIKey 生成新的管理员 API Key // GenerateAdminAPIKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) { func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
// 生成 32 字节随机数 = 64 位十六进制字符 // 生成 32 字节随机数 = 64 位十六进制字符
......
...@@ -4,17 +4,19 @@ type SystemSettings struct { ...@@ -4,17 +4,19 @@ type SystemSettings struct {
RegistrationEnabled bool RegistrationEnabled bool
EmailVerifyEnabled bool EmailVerifyEnabled bool
SMTPHost string SMTPHost string
SMTPPort int SMTPPort int
SMTPUsername string SMTPUsername string
SMTPPassword string SMTPPassword string
SMTPFrom string SMTPPasswordConfigured bool
SMTPFromName string SMTPFrom string
SMTPUseTLS bool SMTPFromName string
SMTPUseTLS bool
TurnstileEnabled bool
TurnstileSiteKey string TurnstileEnabled bool
TurnstileSecretKey string TurnstileSiteKey string
TurnstileSecretKey string
TurnstileSecretKeyConfigured bool
SiteName string SiteName string
SiteLogo string SiteLogo string
...@@ -32,6 +34,10 @@ type SystemSettings struct { ...@@ -32,6 +34,10 @@ type SystemSettings struct {
FallbackModelOpenAI string `json:"fallback_model_openai"` FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"` FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"` FallbackModelAntigravity string `json:"fallback_model_antigravity"`
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"`
} }
type PublicSettings struct { type PublicSettings struct {
......
...@@ -21,10 +21,44 @@ import ( ...@@ -21,10 +21,44 @@ import (
// Config paths // Config paths
const ( const (
ConfigFile = "config.yaml" ConfigFileName = "config.yaml"
EnvFile = ".env" InstallLockFile = ".installed"
) )
// GetDataDir returns the data directory for storing config and lock files.
//
// Priority:
//  1. the DATA_DIR environment variable, when non-empty;
//  2. /app/data, when it exists, is a directory and is writable (Docker);
//  3. the current directory (".").
func GetDataDir() string {
	if dir := os.Getenv("DATA_DIR"); dir != "" {
		return dir
	}

	const dockerDataDir = "/app/data"
	if info, err := os.Stat(dockerDataDir); err == nil && info.IsDir() {
		// Probe writability with a uniquely named temp file. A fixed name
		// would truncate and then delete a pre-existing file of that name,
		// and could collide with another process probing at the same time.
		if probe, err := os.CreateTemp(dockerDataDir, ".write_test_*"); err == nil {
			probePath := probe.Name()
			// Best-effort cleanup: the probe already proved writability.
			_ = probe.Close()
			_ = os.Remove(probePath)
			return dockerDataDir
		}
	}

	return "."
}
// GetConfigFilePath returns the location of the config file: the data
// directory and ConfigFileName joined with a forward slash.
func GetConfigFilePath() string {
	dataDir := GetDataDir()
	return dataDir + "/" + ConfigFileName
}
// GetInstallLockPath returns the location of the install lock file: the data
// directory and InstallLockFile joined with a forward slash.
func GetInstallLockPath() string {
	dataDir := GetDataDir()
	return dataDir + "/" + InstallLockFile
}
// SetupConfig holds the setup configuration // SetupConfig holds the setup configuration
type SetupConfig struct { type SetupConfig struct {
Database DatabaseConfig `json:"database" yaml:"database"` Database DatabaseConfig `json:"database" yaml:"database"`
...@@ -71,13 +105,12 @@ type JWTConfig struct { ...@@ -71,13 +105,12 @@ type JWTConfig struct {
// Uses multiple checks to prevent attackers from forcing re-setup by deleting config // Uses multiple checks to prevent attackers from forcing re-setup by deleting config
func NeedsSetup() bool { func NeedsSetup() bool {
// Check 1: Config file must not exist // Check 1: Config file must not exist
if _, err := os.Stat(ConfigFile); !os.IsNotExist(err) { if _, err := os.Stat(GetConfigFilePath()); !os.IsNotExist(err) {
return false // Config exists, no setup needed return false // Config exists, no setup needed
} }
// Check 2: Installation lock file (harder to bypass) // Check 2: Installation lock file (harder to bypass)
lockFile := ".installed" if _, err := os.Stat(GetInstallLockPath()); !os.IsNotExist(err) {
if _, err := os.Stat(lockFile); !os.IsNotExist(err) {
return false // Lock file exists, already installed return false // Lock file exists, already installed
} }
...@@ -201,6 +234,7 @@ func Install(cfg *SetupConfig) error { ...@@ -201,6 +234,7 @@ func Install(cfg *SetupConfig) error {
return fmt.Errorf("failed to generate jwt secret: %w", err) return fmt.Errorf("failed to generate jwt secret: %w", err)
} }
cfg.JWT.Secret = secret cfg.JWT.Secret = secret
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
} }
// Test connections // Test connections
...@@ -237,9 +271,8 @@ func Install(cfg *SetupConfig) error { ...@@ -237,9 +271,8 @@ func Install(cfg *SetupConfig) error {
// createInstallLock creates a lock file to prevent re-installation attacks // createInstallLock creates a lock file to prevent re-installation attacks
func createInstallLock() error { func createInstallLock() error {
lockFile := ".installed"
content := fmt.Sprintf("installed_at=%s\n", time.Now().UTC().Format(time.RFC3339)) content := fmt.Sprintf("installed_at=%s\n", time.Now().UTC().Format(time.RFC3339))
return os.WriteFile(lockFile, []byte(content), 0400) // Read-only for owner return os.WriteFile(GetInstallLockPath(), []byte(content), 0400) // Read-only for owner
} }
func initializeDatabase(cfg *SetupConfig) error { func initializeDatabase(cfg *SetupConfig) error {
...@@ -390,7 +423,7 @@ func writeConfigFile(cfg *SetupConfig) error { ...@@ -390,7 +423,7 @@ func writeConfigFile(cfg *SetupConfig) error {
return err return err
} }
return os.WriteFile(ConfigFile, data, 0600) return os.WriteFile(GetConfigFilePath(), data, 0600)
} }
func generateSecret(length int) (string, error) { func generateSecret(length int) (string, error) {
...@@ -433,6 +466,7 @@ func getEnvIntOrDefault(key string, defaultValue int) int { ...@@ -433,6 +466,7 @@ func getEnvIntOrDefault(key string, defaultValue int) int {
// This is designed for Docker deployment where all config is passed via env vars // This is designed for Docker deployment where all config is passed via env vars
func AutoSetupFromEnv() error { func AutoSetupFromEnv() error {
log.Println("Auto setup enabled, configuring from environment variables...") log.Println("Auto setup enabled, configuring from environment variables...")
log.Printf("Data directory: %s", GetDataDir())
// Get timezone from TZ or TIMEZONE env var (TZ is standard for Docker) // Get timezone from TZ or TIMEZONE env var (TZ is standard for Docker)
tz := getEnvOrDefault("TZ", "") tz := getEnvOrDefault("TZ", "")
...@@ -479,7 +513,7 @@ func AutoSetupFromEnv() error { ...@@ -479,7 +513,7 @@ func AutoSetupFromEnv() error {
return fmt.Errorf("failed to generate jwt secret: %w", err) return fmt.Errorf("failed to generate jwt secret: %w", err)
} }
cfg.JWT.Secret = secret cfg.JWT.Secret = secret
log.Println("Generated JWT secret automatically") log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
} }
// Generate admin password if not provided // Generate admin password if not provided
...@@ -489,8 +523,8 @@ func AutoSetupFromEnv() error { ...@@ -489,8 +523,8 @@ func AutoSetupFromEnv() error {
return fmt.Errorf("failed to generate admin password: %w", err) return fmt.Errorf("failed to generate admin password: %w", err)
} }
cfg.Admin.Password = password cfg.Admin.Password = password
log.Printf("Generated admin password: %s", cfg.Admin.Password) fmt.Printf("Generated admin password (one-time): %s\n", cfg.Admin.Password)
log.Println("IMPORTANT: Save this password! It will not be shown again.") fmt.Println("IMPORTANT: Save this password! It will not be shown again.")
} }
// Test database connection // Test database connection
......
package logredact
import (
"encoding/json"
"strings"
)
// maxRedactDepth limits the recursion depth of redaction to prevent stack
// overflow; values nested deeper than this are replaced by a placeholder.
const maxRedactDepth = 32
// defaultSensitiveKeys is the built-in set of keys whose values are always
// masked. Entries are lowercase because lookups compare against the
// normalized (trimmed, lowercased) form of each key.
var defaultSensitiveKeys = map[string]struct{}{
"authorization_code": {},
"code": {},
"code_verifier": {},
"access_token": {},
"refresh_token": {},
"id_token": {},
"client_secret": {},
"password": {},
}
// RedactMap returns a redacted copy of input in which the value of every
// sensitive key, at any nesting level handled by redactValueWithDepth, is
// replaced with "***". extraKeys adds caller-specific keys to the default
// sensitive set. A nil input yields an empty, non-nil map.
func RedactMap(input map[string]any, extraKeys ...string) map[string]any {
	if input == nil {
		return map[string]any{}
	}

	sensitive := buildKeySet(extraKeys)
	if result, isMap := redactValueWithDepth(input, sensitive, 0).(map[string]any); isMap {
		return result
	}
	return map[string]any{}
}
// RedactJSON decodes raw as JSON, masks the values of sensitive keys and
// returns the re-encoded document. extraKeys adds caller-specific keys to the
// default sensitive set.
//
// Special cases:
//   - empty input returns "";
//   - input that is not valid JSON returns "<non-json payload redacted>";
//   - a re-encoding failure returns "<redacted>".
func RedactJSON(raw []byte, extraKeys ...string) string {
	if len(raw) == 0 {
		return ""
	}

	var decoded any
	if json.Unmarshal(raw, &decoded) != nil {
		return "<non-json payload redacted>"
	}

	sensitive := buildKeySet(extraKeys)
	out, err := json.Marshal(redactValueWithDepth(decoded, sensitive, 0))
	if err != nil {
		return "<redacted>"
	}
	return string(out)
}
// buildKeySet returns a fresh set holding every default sensitive key plus
// the normalized form of each entry in extraKeys. Entries that normalize to
// the empty string are ignored.
func buildKeySet(extraKeys []string) map[string]struct{} {
	set := make(map[string]struct{}, len(defaultSensitiveKeys)+len(extraKeys))
	for name := range defaultSensitiveKeys {
		set[name] = struct{}{}
	}
	for _, raw := range extraKeys {
		if name := normalizeKey(raw); name != "" {
			set[name] = struct{}{}
		}
	}
	return set
}
// redactValueWithDepth walks value recursively and returns a copy in which
// the value of every sensitive map key is replaced with "***".
//
// Only map[string]any and []any are descended into; any other value is
// returned unchanged. Once depth exceeds maxRedactDepth the value is replaced
// with "<depth limit exceeded>" instead of being traversed.
func redactValueWithDepth(value any, keys map[string]struct{}, depth int) any {
	if depth > maxRedactDepth {
		return "<depth limit exceeded>"
	}

	if object, isObject := value.(map[string]any); isObject {
		masked := make(map[string]any, len(object))
		for name, child := range object {
			if isSensitiveKey(name, keys) {
				masked[name] = "***"
			} else {
				masked[name] = redactValueWithDepth(child, keys, depth+1)
			}
		}
		return masked
	}

	if list, isList := value.([]any); isList {
		masked := make([]any, len(list))
		for idx := range list {
			masked[idx] = redactValueWithDepth(list[idx], keys, depth+1)
		}
		return masked
	}

	return value
}
// isSensitiveKey reports whether the normalized form of key is a member of
// the keys set.
func isSensitiveKey(key string, keys map[string]struct{}) bool {
	if _, found := keys[normalizeKey(key)]; found {
		return true
	}
	return false
}
// normalizeKey returns key with surrounding whitespace removed and all
// letters lowercased, so that key comparisons ignore case and padding.
func normalizeKey(key string) string {
	trimmed := strings.TrimSpace(key)
	return strings.ToLower(trimmed)
}
package responseheaders
import (
"net/http"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// defaultAllowed is the allowlist of response headers that may be passed
// through. Names are lowercase; lookups use the lowercased header name.
// Note: the following headers are handled automatically by Go's HTTP package
// and should not be set manually:
// - content-length: set by the ResponseWriter from the data actually written
// - transfer-encoding: added/removed by the HTTP library as needed
// - connection: connection reuse is managed by the HTTP library
var defaultAllowed = map[string]struct{}{
"content-type": {},
"content-encoding": {},
"content-language": {},
"cache-control": {},
"etag": {},
"last-modified": {},
"expires": {},
"vary": {},
"date": {},
"x-request-id": {},
"x-ratelimit-limit-requests": {},
"x-ratelimit-limit-tokens": {},
"x-ratelimit-remaining-requests": {},
"x-ratelimit-remaining-tokens": {},
"x-ratelimit-reset-requests": {},
"x-ratelimit-reset-tokens": {},
"retry-after": {},
"location": {},
"www-authenticate": {},
}
// hopByHopHeaders lists the hop-by-hop headers that are always skipped when
// filtering, even if configured as allowed; they are handled automatically by
// the HTTP library.
var hopByHopHeaders = map[string]struct{}{
"content-length": {},
"transfer-encoding": {},
"connection": {},
}
// FilterHeaders returns a new header set containing only the headers of src
// that may be forwarded.
//
// A header is kept when its lowercased name is in the allowlist, is not in
// the force-remove set and is not a hop-by-hop header. The allowlist always
// contains defaultAllowed. cfg.AdditionalAllowed and cfg.ForceRemove are only
// honored when cfg.Enabled is true; their entries are trimmed and lowercased,
// and blank entries are ignored. src itself is not modified.
func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header {
	allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed))
	for name := range defaultAllowed {
		allowed[name] = struct{}{}
	}

	// When disabled, only the default allowlist applies: the additional and
	// force-remove lists have no effect.
	forceRemove := map[string]struct{}{}
	if cfg.Enabled {
		for _, raw := range cfg.AdditionalAllowed {
			if name := strings.ToLower(strings.TrimSpace(raw)); name != "" {
				allowed[name] = struct{}{}
			}
		}

		forceRemove = make(map[string]struct{}, len(cfg.ForceRemove))
		for _, raw := range cfg.ForceRemove {
			if name := strings.ToLower(strings.TrimSpace(raw)); name != "" {
				forceRemove[name] = struct{}{}
			}
		}
	}

	filtered := make(http.Header, len(src))
	for key, values := range src {
		lower := strings.ToLower(key)
		_, blocked := forceRemove[lower]
		_, permitted := allowed[lower]
		// Hop-by-hop headers are left to the HTTP library even when allowed.
		_, hopByHop := hopByHopHeaders[lower]
		if blocked || !permitted || hopByHop {
			continue
		}
		for idx := range values {
			filtered.Add(key, values[idx])
		}
	}
	return filtered
}
// WriteFilteredHeaders filters src with FilterHeaders and appends every
// surviving header value to dst. Values already present in dst are kept.
func WriteFilteredHeaders(dst http.Header, src http.Header, cfg config.ResponseHeaderConfig) {
	for name, vals := range FilterHeaders(src, cfg) {
		for idx := range vals {
			dst.Add(name, vals[idx])
		}
	}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment