Commit a7386882 authored by 陈曦's avatar 陈曦
Browse files

Merge the capture-requests branch into upstream-follow

parents 110702d4 55891dff
Pipeline #82303 passed with stage
in 3 minutes and 44 seconds
......@@ -25,6 +25,8 @@ type APIKeyAuthSnapshot struct {
RateLimit5h float64 `json:"rate_limit_5h"`
RateLimit1d float64 `json:"rate_limit_1d"`
RateLimit7d float64 `json:"rate_limit_7d"`
CaptureRequests bool `json:"capture_requests"`
}
// APIKeyAuthUserSnapshot 用户快照
......
......@@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
const apiKeyAuthSnapshotVersion = 8 // v8: added CaptureRequests on api key snapshot
type apiKeyAuthCacheConfig struct {
l1Size int
......@@ -216,9 +216,10 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey)
Quota: apiKey.Quota,
QuotaUsed: apiKey.QuotaUsed,
ExpiresAt: apiKey.ExpiresAt,
RateLimit5h: apiKey.RateLimit5h,
RateLimit1d: apiKey.RateLimit1d,
RateLimit7d: apiKey.RateLimit7d,
RateLimit5h: apiKey.RateLimit5h,
RateLimit1d: apiKey.RateLimit1d,
RateLimit7d: apiKey.RateLimit7d,
CaptureRequests: apiKey.CaptureRequests,
User: APIKeyAuthUserSnapshot{
ID: apiKey.User.ID,
Status: apiKey.User.Status,
......@@ -289,9 +290,10 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
Quota: snapshot.Quota,
QuotaUsed: snapshot.QuotaUsed,
ExpiresAt: snapshot.ExpiresAt,
RateLimit5h: snapshot.RateLimit5h,
RateLimit1d: snapshot.RateLimit1d,
RateLimit7d: snapshot.RateLimit7d,
RateLimit5h: snapshot.RateLimit5h,
RateLimit1d: snapshot.RateLimit1d,
RateLimit7d: snapshot.RateLimit7d,
CaptureRequests: snapshot.CaptureRequests,
User: &User{
ID: snapshot.User.ID,
Status: snapshot.User.Status,
......
......@@ -184,6 +184,9 @@ type UpdateAPIKeyRequest struct {
RateLimit1d *float64 `json:"rate_limit_1d"`
RateLimit7d *float64 `json:"rate_limit_7d"`
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // Reset all usage counters to 0
// Request capture
CaptureRequests *bool `json:"capture_requests"` // nil = no change
}
// APIKeyService API Key服务
......@@ -601,6 +604,10 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
apiKey.IPWhitelist = req.IPWhitelist
apiKey.IPBlacklist = req.IPBlacklist
if req.CaptureRequests != nil {
apiKey.CaptureRequests = *req.CaptureRequests
}
// Update rate limit configuration
if req.RateLimit5h != nil {
apiKey.RateLimit5h = *req.RateLimit5h
......
......@@ -9,6 +9,7 @@ import (
"hash/crc32"
"io"
"net/http"
"strings"
"sync/atomic"
"time"
......@@ -16,6 +17,7 @@ import (
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
......@@ -48,6 +50,9 @@ func (s *GatewayService) handleBedrockStreamingResponse(
var firstTokenMs *int
clientDisconnected := false
// 响应体捕获:若 context 中注入了 ResponseCaptureBuffer,则收集 text_delta 文本
captureBuilder, _ := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder)
// Bedrock EventStream 使用 application/vnd.amazon.eventstream 二进制格式。
// 每个帧结构:total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4)
// 但更实用的方式是使用行扫描找 JSON chunks,因为 Bedrock 的响应在二进制帧中。
......@@ -141,6 +146,13 @@ func (s *GatewayService) handleBedrockStreamingResponse(
// 解析 SSE 事件数据提取 usage
s.parseSSEUsagePassthrough(string(sseData), usage)
// 收集 assistant text(仅 content_block_delta + text_delta)
if captureBuilder != nil && gjson.GetBytes(sseData, "type").String() == "content_block_delta" {
if gjson.GetBytes(sseData, "delta.type").String() == "text_delta" {
captureBuilder.WriteString(gjson.GetBytes(sseData, "delta.text").String())
}
}
// 确定 SSE event type
eventType := gjson.GetBytes(sseData, "type").String()
......
......@@ -315,8 +315,10 @@ func (s *GatewayService) handleCCBufferedFromAnthropic(
}
// Marshal then bytes-replace so tool name mapping is reversed at byte level
// (parity with Parrot non-stream flow that marshals → restore → emit).
var responseBody string
if respBytes, err := json.Marshal(ccResp); err == nil {
respBytes = reverseToolNamesIfPresent(c, respBytes)
responseBody = string(respBytes)
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
} else {
c.JSON(http.StatusOK, ccResp)
......@@ -330,6 +332,7 @@ func (s *GatewayService) handleCCBufferedFromAnthropic(
ReasoningEffort: reasoningEffort,
Stream: false,
Duration: time.Since(startTime),
ResponseBody: responseBody,
}, nil
}
......@@ -365,6 +368,7 @@ func (s *GatewayService) handleCCStreamingFromAnthropic(
var usage ClaudeUsage
var firstTokenMs *int
firstChunk := true
var textBuilder strings.Builder // 收集 assistant 文本用于响应捕获
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
......@@ -383,10 +387,15 @@ func (s *GatewayService) handleCCStreamingFromAnthropic(
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ResponseBody: textBuilder.String(),
}
}
writeChunk := func(chunk apicompat.ChatCompletionsChunk) bool {
// 收集 assistant text 用于响应捕获
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != nil {
textBuilder.WriteString(*chunk.Choices[0].Delta.Content)
}
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
return false
......
......@@ -493,6 +493,10 @@ type ForwardResult struct {
ClientDisconnect bool // 客户端是否在流式传输过程中断开
ReasoningEffort *string
// ResponseBody 响应内容:非 streaming 为完整 JSON,streaming 为拼接的 assistant text。
// 仅当 API Key 开启了 capture_requests 时才会被填充(通过 context 标记控制)。
ResponseBody string
// 图片生成计费字段(图片生成模型使用)
ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K"
......@@ -4662,6 +4666,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
var nonStreamingResponseBody string
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, shouldMimicClaudeCode)
if err != nil {
......@@ -4675,11 +4680,17 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
// 若注入了 ResponseCaptureBuffer,从 context 中读取已收集的 assistant 文本
if captureBuilder, ok := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder); ok && captureBuilder != nil {
nonStreamingResponseBody = captureBuilder.String()
}
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
var nonStreamRespBody []byte
nonStreamRespBody, usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
if err != nil {
return nil, err
}
nonStreamingResponseBody = string(nonStreamRespBody)
}
return &ForwardResult{
......@@ -4691,6 +4702,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
ResponseBody: nonStreamingResponseBody,
}, nil
}
......@@ -4916,6 +4928,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
var responseBody string
if input.RequestStream {
streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, input.StartTime, input.RequestModel)
if err != nil {
......@@ -4924,8 +4937,12 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
// 从 context buffer 读取已收集的 assistant 文本
if captureBuilder, ok := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder); ok && captureBuilder != nil {
responseBody = captureBuilder.String()
}
} else {
usage, err = s.handleNonStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account)
responseBody, usage, err = s.handleNonStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account)
if err != nil {
return nil, err
}
......@@ -4943,6 +4960,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
Duration: time.Since(input.StartTime),
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
ResponseBody: responseBody,
}, nil
}
......@@ -5039,6 +5057,9 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
clientDisconnected := false
sawTerminalEvent := false
// 响应体捕获:若 context 中注入了 ResponseCaptureBuffer,则收集 text_delta 文本
captureBuilder, _ := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder)
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
......@@ -5133,6 +5154,12 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
firstTokenMs = &ms
}
s.parseSSEUsagePassthrough(data, usage)
// 收集 assistant text(仅 content_block_delta + text_delta)
if captureBuilder != nil && gjson.Get(data, "type").String() == "content_block_delta" {
if gjson.Get(data, "delta.type").String() == "text_delta" {
captureBuilder.WriteString(gjson.Get(data, "delta.text").String())
}
}
} else {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "event:") && anthropicStreamEventIsTerminal(strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")), "") {
......@@ -5295,14 +5322,14 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough(
resp *http.Response,
c *gin.Context,
account *Account,
) (*ClaudeUsage, error) {
) (string, *ClaudeUsage, error) {
if s.rateLimitService != nil {
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
}
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
if err != nil {
return nil, err
return "", nil, err
}
usage := parseClaudeUsageFromResponseBody(body)
......@@ -5314,7 +5341,7 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough(
}
body = reverseToolNamesIfPresent(c, body)
c.Data(resp.StatusCode, contentType, body)
return usage, nil
return string(body), usage, nil
}
func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) {
......@@ -5415,6 +5442,7 @@ func (s *GatewayService) forwardBedrock(
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
var responseBody string
if reqStream {
streamResult, err := s.handleBedrockStreamingResponse(ctx, resp, c, account, startTime, reqModel)
if err != nil {
......@@ -5423,8 +5451,12 @@ func (s *GatewayService) forwardBedrock(
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
// 从 context buffer 读取已收集的 assistant 文本
if captureBuilder, ok := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder); ok && captureBuilder != nil {
responseBody = captureBuilder.String()
}
} else {
usage, err = s.handleBedrockNonStreamingResponse(ctx, resp, c, account)
responseBody, usage, err = s.handleBedrockNonStreamingResponse(ctx, resp, c, account)
if err != nil {
return nil, err
}
......@@ -5442,6 +5474,7 @@ func (s *GatewayService) forwardBedrock(
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
ResponseBody: responseBody,
}, nil
}
......@@ -5667,10 +5700,10 @@ func (s *GatewayService) handleBedrockNonStreamingResponse(
resp *http.Response,
c *gin.Context,
account *Account,
) (*ClaudeUsage, error) {
) (string, *ClaudeUsage, error) {
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
if err != nil {
return nil, err
return "", nil, err
}
// 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式
......@@ -5684,7 +5717,7 @@ func (s *GatewayService) handleBedrockNonStreamingResponse(
c.Header("x-request-id", v)
}
c.Data(resp.StatusCode, "application/json", body)
return usage, nil
return string(body), usage, nil
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
......@@ -6886,6 +6919,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
sawTerminalEvent := false
// 响应体捕获:若 context 中注入了 ResponseCaptureBuffer,则收集 text_delta 文本
captureBuilder, _ := ctx.Value(ctxkey.ResponseCaptureBuffer).(*strings.Builder)
pendingEventLines := make([]string, 0, 4)
processSSEEvent := func(lines []string) ([]string, string, *sseUsagePatch, error) {
......@@ -6941,6 +6977,17 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
eventChanged := false
// 收集 assistant text(仅 content_block_delta + text_delta)
if captureBuilder != nil && eventType == "content_block_delta" {
if delta, ok := event["delta"].(map[string]any); ok {
if dt, _ := delta["type"].(string); dt == "text_delta" {
if text, _ := delta["text"].(string); text != "" {
captureBuilder.WriteString(text)
}
}
}
}
// 兼容 Kimi cached_tokens → cache_read_input_tokens
if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
......@@ -7342,13 +7389,13 @@ func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool {
return true
}
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) ([]byte, *ClaudeUsage, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
if err != nil {
return nil, err
return nil, nil, err
}
// 解析usage
......@@ -7356,7 +7403,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
Usage ClaudeUsage `json:"usage"`
}
if err := json.Unmarshal(body, &response); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
return nil, nil, fmt.Errorf("parse response: %w", err)
}
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
......@@ -7411,7 +7458,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
// 写入响应
c.Data(resp.StatusCode, contentType, body)
return &response.Usage, nil
return body, &response.Usage, nil
}
// replaceModelInResponseBody 替换响应体中的model字段
......
......@@ -1008,6 +1008,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
var usage *ClaudeUsage
var firstTokenMs *int
var responseBody string
if req.Stream {
streamRes, err := s.handleStreamingResponse(c, resp, startTime, originalModel)
if err != nil {
......@@ -1015,6 +1016,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
responseBody = streamRes.responseBody
} else {
if useUpstreamStream {
collected, usageObj, err := collectGeminiSSE(resp.Body, true)
......@@ -1023,16 +1025,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
collectedBytes, _ := json.Marshal(collected)
claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel, collectedBytes)
c.JSON(http.StatusOK, claudeResp)
respBytes, _ := json.Marshal(claudeResp)
c.Data(http.StatusOK, "application/json", respBytes)
responseBody = string(respBytes)
usage = usageObj2
if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) {
usage = usageObj
}
} else {
usage, err = s.handleNonStreamingResponse(c, resp, originalModel)
var nonStreamBody string
nonStreamBody, usage, err = s.handleNonStreamingResponse(c, resp, originalModel)
if err != nil {
return nil, err
}
responseBody = nonStreamBody
}
}
......@@ -1044,15 +1050,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel,
RequestID: requestID,
Usage: *usage,
Model: originalModel,
UpstreamModel: mappedModel,
Stream: req.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: imageSize,
ResponseBody: responseBody,
}, nil
}
......@@ -1872,28 +1879,30 @@ func mapGeminiStatusToClaudeErrorType(status string) string {
type geminiStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
responseBody string // 累积的文本内容,用于响应捕获
}
func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) {
func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (string, *ClaudeUsage, error) {
body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
return "", nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
}
unwrappedBody, err := unwrapGeminiResponse(body)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
return "", nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
var geminiResp map[string]any
if err := json.Unmarshal(unwrappedBody, &geminiResp); err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
return "", nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel, unwrappedBody)
c.JSON(http.StatusOK, claudeResp)
respBytes, _ := json.Marshal(claudeResp)
c.Data(http.StatusOK, "application/json", respBytes)
return usage, nil
return string(respBytes), usage, nil
}
func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*geminiStreamResult, error) {
......@@ -2146,7 +2155,7 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re
})
flusher.Flush()
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs, responseBody: seenText}, nil
}
func writeSSE(w io.Writer, event string, data any) {
......
......@@ -406,7 +406,14 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, chatResp)
var responseBody string
if respBytes, err := json.Marshal(chatResp); err == nil {
responseBody = string(respBytes)
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
} else {
c.JSON(http.StatusOK, chatResp)
}
return &OpenAIForwardResult{
RequestID: requestID,
......@@ -416,6 +423,7 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
UpstreamModel: upstreamModel,
Stream: false,
Duration: time.Since(startTime),
ResponseBody: responseBody,
}, nil
}
......@@ -448,6 +456,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
var usage OpenAIUsage
var firstTokenMs *int
firstChunk := true
var textBuilder strings.Builder // 收集 assistant 文本用于响应捕获
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
......@@ -466,6 +475,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ResponseBody: textBuilder.String(),
}
}
......@@ -499,6 +509,10 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
for _, chunk := range chunks {
// 收集 assistant text 用于响应捕获
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != nil {
textBuilder.WriteString(*chunk.Choices[0].Delta.Content)
}
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
......
......@@ -354,7 +354,13 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, anthropicResp)
var responseBody string
if respBytes, err := json.Marshal(anthropicResp); err == nil {
responseBody = string(respBytes)
c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
} else {
c.JSON(http.StatusOK, anthropicResp)
}
return &OpenAIForwardResult{
RequestID: requestID,
......@@ -364,6 +370,7 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
UpstreamModel: upstreamModel,
Stream: false,
Duration: time.Since(startTime),
ResponseBody: responseBody,
}, nil
}
......@@ -396,6 +403,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
var usage OpenAIUsage
var firstTokenMs *int
firstChunk := true
var textBuilder strings.Builder // 收集 assistant text 用于响应捕获
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
......@@ -415,6 +423,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ResponseBody: textBuilder.String(),
}
}
......@@ -451,6 +460,10 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
// Convert to Anthropic events
events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
for _, evt := range events {
// 采集 text_delta 用于响应捕获
if evt.Type == "content_block_delta" && evt.Delta != nil && evt.Delta.Type == "text_delta" {
textBuilder.WriteString(evt.Delta.Text)
}
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
if err != nil {
logger.L().Warn("openai messages stream: failed to marshal event",
......
......@@ -235,6 +235,9 @@ type OpenAIForwardResult struct {
FirstTokenMs *int
ImageCount int
ImageSize string
// ResponseBody 响应内容:非 streaming 为完整 JSON,streaming 为拼接的 assistant text。
// 仅当 API Key 开启了 capture_requests 时才会被使用。
ResponseBody string
}
type OpenAIWSRetryMetricsSnapshot struct {
......
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"go.uber.org/zap"
)
// RequestCaptureLogRepository defines the persistence interface for request
// capture logs.
type RequestCaptureLogRepository interface {
	// Create inserts a new capture row and returns its database ID.
	Create(ctx context.Context, params CreateRequestCaptureLogParams) (int64, error)
	// UpdateResponseBody attaches a response body to an existing capture row.
	UpdateResponseBody(ctx context.Context, id int64, responseBody string) error
}
// CreateRequestCaptureLogParams carries the fields needed to create one
// request capture log row.
type CreateRequestCaptureLogParams struct {
	APIKeyID    int64  // API key whose requests are being captured
	UserID      int64  // owner of the API key
	RequestID   string // gateway request ID
	Path        string // request path
	Method      string // HTTP method
	IPAddress   string // client IP address
	RequestBody string // raw request body
	NFSFilePath string // empty when NFS storage is disabled
}
// RequestCaptureService asynchronously captures request bodies for API keys
// that have capture enabled, writing them to the database and to NFS.
type RequestCaptureService struct {
	repo       RequestCaptureLogRepository
	nfsPath    string        // NFS root directory; empty disables NFS writes (DB-only mode)
	timeout    time.Duration // per-operation timeout for DB writes
	nfsPathMap sync.Map      // captureID int64 → nfsFilePath string (short-lived; removed once CaptureResponse runs)
}
// nfsFileEnvelope is the JSON structure written to the NFS request file.
type nfsFileEnvelope struct {
	APIKeyID  int64           `json:"api_key_id"`
	UserID    int64           `json:"user_id"`
	RequestID string          `json:"request_id"`
	CreatedAt time.Time       `json:"created_at"`
	Path      string          `json:"path"`
	Method    string          `json:"method"`
	IPAddress string          `json:"ip_address"`
	Body      json.RawMessage `json:"body"`
}
// NewRequestCaptureService builds a RequestCaptureService from configuration.
// A nil config (or unset values) falls back to a 5-second worker timeout and
// DB-only mode (no NFS path).
func NewRequestCaptureService(repo RequestCaptureLogRepository, cfg *config.Config) *RequestCaptureService {
	svc := &RequestCaptureService{
		repo:    repo,
		timeout: 5 * time.Second,
	}
	if cfg != nil {
		if secs := cfg.RequestCapture.WorkerTimeoutSeconds; secs > 0 {
			svc.timeout = time.Duration(secs) * time.Second
		}
		svc.nfsPath = cfg.RequestCapture.NFSPath
	}
	if svc.nfsPath == "" {
		logger.L().Info("request_capture: NFS storage disabled (nfs_path not configured), DB-only mode")
	} else {
		logger.L().Info("request_capture: NFS storage enabled", zap.String("nfs_path", svc.nfsPath))
	}
	return svc
}
// Capture records a request body for an enabled API key and returns the
// captureID (the DB row ID). A return of 0 means the DB write failed.
//
// The NFS write (when configured) runs in its own goroutine, but the DB
// insert is performed synchronously — bounded by s.timeout — because the
// caller needs the row ID. DB and NFS writes are independent of each other:
// failure of one does not affect the other.
func (s *RequestCaptureService) Capture(
	apiKeyID, userID int64,
	requestID, path, method, ipAddr string,
	body []byte,
) int64 {
	now := time.Now()

	// NFS write (separate goroutine).
	nfsFilePath := ""
	if s.nfsPath != "" {
		nfsFilePath = s.buildNFSFilePath(apiKeyID, requestID, now)
		// Copy body so the goroutine never aliases the caller's buffer.
		bodyCopy := make([]byte, len(body))
		copy(bodyCopy, body)
		logger.L().Debug("request_capture: launching nfs request write",
			zap.Int64("api_key_id", apiKeyID),
			zap.String("nfs_file", nfsFilePath),
		)
		go s.writeToNFS(nfsFilePath, apiKeyID, userID, requestID, path, method, ipAddr, bodyCopy, now)
	}

	// DB write (synchronous: the row ID is required by the caller).
	ctx, cancel := context.WithTimeout(context.Background(), s.timeout)
	defer cancel()
	id, err := s.repo.Create(ctx, CreateRequestCaptureLogParams{
		APIKeyID:    apiKeyID,
		UserID:      userID,
		RequestID:   requestID,
		Path:        path,
		Method:      method,
		IPAddress:   ipAddr,
		RequestBody: string(body),
		NFSFilePath: nfsFilePath,
	})
	if err != nil {
		logger.L().Error("request_capture: db write failed",
			zap.Int64("api_key_id", apiKeyID),
			zap.String("request_id", requestID),
			zap.Error(err),
		)
		return 0
	}
	// Remember captureID → nfsFilePath so CaptureResponse can later write the
	// matching response file next to the request file.
	if nfsFilePath != "" {
		s.nfsPathMap.Store(id, nfsFilePath)
	}
	return id
}
// CaptureResponse asynchronously attaches a response body to an existing
// capture record (database + NFS) without blocking the caller.
// captureID is the value returned by Capture; 0 is ignored.
func (s *RequestCaptureService) CaptureResponse(captureID int64, responseBody string) {
	if captureID == 0 {
		return
	}
	// Always consume the NFS path mapping (one-shot), even when there is no
	// response body to write — otherwise the entry stored by Capture would
	// leak in nfsPathMap for every request whose captured response is empty.
	var nfsFilePath string
	if v, ok := s.nfsPathMap.LoadAndDelete(captureID); ok {
		nfsFilePath, _ = v.(string)
	}
	if responseBody == "" {
		return
	}
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), s.timeout)
		defer cancel()
		if err := s.repo.UpdateResponseBody(ctx, captureID, responseBody); err != nil {
			logger.L().Error("request_capture: db update response failed",
				zap.Int64("capture_id", captureID),
				zap.Error(err),
			)
		}
		// NFS response file: same directory as the request file, with a
		// "_response" suffix before the extension.
		if nfsFilePath != "" {
			respPath := nfsResponseFilePath(nfsFilePath)
			logger.L().Debug("request_capture: launching nfs response write",
				zap.Int64("capture_id", captureID),
				zap.String("nfs_file", respPath),
			)
			s.writeResponseToNFS(respPath, captureID, responseBody)
		}
	}()
}
// nfsResponseFilePath derives the response-file path from a request-file path
// by inserting "_response" before the extension.
// Example: /nfs/2024-01-01/42/123_reqid.json → /nfs/2024-01-01/42/123_reqid_response.json
func nfsResponseFilePath(requestPath string) string {
	ext := filepath.Ext(requestPath)
	stem := requestPath[:len(requestPath)-len(ext)]
	return fmt.Sprintf("%s_response%s", stem, ext)
}
// buildNFSFilePath returns the NFS destination for one captured request:
// <nfsPath>/<UTC date>/<apiKeyID>/<unixNano>_<requestID>.json
func (s *RequestCaptureService) buildNFSFilePath(apiKeyID int64, requestID string, t time.Time) string {
	return filepath.Join(
		s.nfsPath,
		t.UTC().Format("2006-01-02"),
		fmt.Sprintf("%d", apiKeyID),
		fmt.Sprintf("%d_%s.json", t.UnixNano(), requestID),
	)
}
// writeToNFS persists one captured request as a JSON envelope at filePath.
// Best-effort: every error is logged and swallowed so NFS capture can never
// affect request handling.
func (s *RequestCaptureService) writeToNFS(
	filePath string,
	apiKeyID, userID int64,
	requestID, path, method, ipAddr string,
	body []byte,
	now time.Time,
) {
	dir := filepath.Dir(filePath)
	if err := os.MkdirAll(dir, 0o755); err != nil {
		logger.L().Error("request_capture: mkdir failed",
			zap.String("dir", dir),
			zap.Error(err),
		)
		return
	}
	// json.RawMessage must hold valid JSON or the encoder fails. Empty or
	// non-JSON request bodies (e.g. truncated uploads) are therefore stored
	// as a JSON string instead, so the envelope always encodes — mirroring
	// how writeResponseToNFS handles non-JSON bodies.
	raw := json.RawMessage(body)
	if !json.Valid(body) {
		quoted, err := json.Marshal(string(body))
		if err != nil { // marshaling a string cannot fail; keep a guard anyway
			quoted = []byte(`""`)
		}
		raw = quoted
	}
	envelope := nfsFileEnvelope{
		APIKeyID:  apiKeyID,
		UserID:    userID,
		RequestID: requestID,
		CreatedAt: now.UTC(),
		Path:      path,
		Method:    method,
		IPAddress: ipAddr,
		Body:      raw,
	}
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	enc.SetEscapeHTML(false) // keep URLs and HTML inside bodies readable
	if err := enc.Encode(envelope); err != nil {
		logger.L().Error("request_capture: json marshal failed",
			zap.String("request_id", requestID),
			zap.Error(err),
		)
		return
	}
	if err := os.WriteFile(filePath, buf.Bytes(), 0o644); err != nil {
		logger.L().Error("request_capture: nfs write failed",
			zap.String("file", filePath),
			zap.Error(err),
		)
	} else {
		logger.L().Debug("request_capture: nfs request file written", zap.String("file", filePath))
	}
}
// nfsResponseEnvelope is the JSON structure written to the NFS response file.
// Body is deliberately `any`: non-streaming responses are stored as
// json.RawMessage (preserving the original JSON structure), while streaming
// responses are plain concatenated text and stored as a string — embedding
// invalid JSON in a RawMessage would make encoding fail.
type nfsResponseEnvelope struct {
	CaptureID int64     `json:"capture_id"`
	CreatedAt time.Time `json:"created_at"`
	Body      any       `json:"body"`
}
// writeResponseToNFS persists a captured response body as a JSON envelope at
// filePath. Best-effort: all errors are logged and swallowed.
func (s *RequestCaptureService) writeResponseToNFS(filePath string, captureID int64, responseBody string) {
	dir := filepath.Dir(filePath)
	if err := os.MkdirAll(dir, 0o755); err != nil {
		logger.L().Error("request_capture: mkdir failed (response)",
			zap.String("dir", dir),
			zap.Error(err),
		)
		return
	}

	// A valid-JSON body (non-streaming response) is embedded as-is to keep
	// its structure; anything else (streamed plain text) is stored as a
	// string so encoding cannot fail.
	var payload any = responseBody
	if json.Valid([]byte(responseBody)) {
		payload = json.RawMessage(responseBody)
	}

	var out bytes.Buffer
	enc := json.NewEncoder(&out)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(nfsResponseEnvelope{
		CaptureID: captureID,
		CreatedAt: time.Now().UTC(),
		Body:      payload,
	}); err != nil {
		logger.L().Error("request_capture: json marshal failed (response)",
			zap.Int64("capture_id", captureID),
			zap.Error(err),
		)
		return
	}

	if err := os.WriteFile(filePath, out.Bytes(), 0o644); err != nil {
		logger.L().Error("request_capture: nfs write failed (response)",
			zap.String("file", filePath),
			zap.Error(err),
		)
		return
	}
	logger.L().Debug("request_capture: nfs response file written", zap.String("file", filePath))
}
......@@ -485,6 +485,7 @@ var ProviderSet = wire.NewSet(
ProvideScheduledTestRunnerService,
NewGroupCapacityService,
NewChannelService,
NewRequestCaptureService,
NewModelPricingResolver,
NewAffiliateService,
ProvidePaymentConfigService,
......
-- Add capture_requests flag to api_keys.
-- Idempotent: IF NOT EXISTS guards allow this migration to be re-run safely.
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS capture_requests boolean NOT NULL DEFAULT false;

-- Create request_capture_logs table (monthly range-partitioned by created_at)
-- PRIMARY KEY must include the partition key, so we use (id, created_at).
CREATE TABLE IF NOT EXISTS request_capture_logs (
    id bigserial NOT NULL,
    api_key_id bigint NOT NULL,
    user_id bigint NOT NULL,
    request_id varchar(64),
    path varchar(100),
    method varchar(10),
    ip_address varchar(45),
    request_body text,
    response_body text,
    nfs_file_path varchar(500),
    created_at timestamptz NOT NULL DEFAULT now(),
    PRIMARY KEY (id, created_at)
) PARTITION BY RANGE (created_at);

-- (api_key_id, created_at DESC) supports per-key, newest-first listing queries.
CREATE INDEX IF NOT EXISTS idx_rcl_api_key_created ON request_capture_logs (api_key_id, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_rcl_user_id ON request_capture_logs (user_id);

-- Pre-create partitions for previous, current, and next month.
-- NOTE(review): only three partitions are created here; presumably a scheduled
-- job creates future months — confirm one exists, otherwise inserts will start
-- failing once now() passes the last pre-created range.
DO $$
DECLARE
    month_start DATE;
    prev_month DATE;
    next_month DATE;
BEGIN
    -- Month boundaries are computed in UTC to match the timestamptz column.
    month_start := date_trunc('month', now() AT TIME ZONE 'UTC')::date;
    prev_month := (month_start - INTERVAL '1 month')::date;
    next_month := (month_start + INTERVAL '1 month')::date;
    -- %L quotes the date literals; partition names use YYYYMM suffixes.
    EXECUTE format(
        'CREATE TABLE IF NOT EXISTS request_capture_logs_%s PARTITION OF request_capture_logs FOR VALUES FROM (%L) TO (%L)',
        to_char(prev_month, 'YYYYMM'), prev_month, month_start
    );
    EXECUTE format(
        'CREATE TABLE IF NOT EXISTS request_capture_logs_%s PARTITION OF request_capture_logs FOR VALUES FROM (%L) TO (%L)',
        to_char(month_start, 'YYYYMM'), month_start, next_month
    );
    EXECUTE format(
        'CREATE TABLE IF NOT EXISTS request_capture_logs_%s PARTITION OF request_capture_logs FOR VALUES FROM (%L) TO (%L)',
        to_char(next_month, 'YYYYMM'), next_month, (next_month + INTERVAL '1 month')::date
    );
END $$;
-- Add capture_requests flag to api_keys.
-- Idempotent: IF NOT EXISTS guards allow this migration to be re-run safely.
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS capture_requests boolean NOT NULL DEFAULT false;

-- Create request_capture_logs table (monthly range-partitioned by created_at)
-- PRIMARY KEY must include the partition key, so we use (id, created_at).
CREATE TABLE IF NOT EXISTS request_capture_logs (
    id bigserial NOT NULL,
    api_key_id bigint NOT NULL,
    user_id bigint NOT NULL,
    request_id varchar(64),
    path varchar(100),
    method varchar(10),
    ip_address varchar(45),
    request_body text,
    response_body text,
    nfs_file_path varchar(500),
    created_at timestamptz NOT NULL DEFAULT now(),
    PRIMARY KEY (id, created_at)
) PARTITION BY RANGE (created_at);

-- (api_key_id, created_at DESC) supports per-key, newest-first listing queries.
CREATE INDEX IF NOT EXISTS idx_rcl_api_key_created ON request_capture_logs (api_key_id, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_rcl_user_id ON request_capture_logs (user_id);

-- Pre-create partitions for previous, current, and next month.
-- NOTE(review): only three partitions are created here; presumably a scheduled
-- job creates future months — confirm one exists, otherwise inserts will start
-- failing once now() passes the last pre-created range.
DO $$
DECLARE
    month_start DATE;
    prev_month DATE;
    next_month DATE;
BEGIN
    -- Month boundaries are computed in UTC to match the timestamptz column.
    month_start := date_trunc('month', now() AT TIME ZONE 'UTC')::date;
    prev_month := (month_start - INTERVAL '1 month')::date;
    next_month := (month_start + INTERVAL '1 month')::date;
    -- %L quotes the date literals; partition names use YYYYMM suffixes.
    EXECUTE format(
        'CREATE TABLE IF NOT EXISTS request_capture_logs_%s PARTITION OF request_capture_logs FOR VALUES FROM (%L) TO (%L)',
        to_char(prev_month, 'YYYYMM'), prev_month, month_start
    );
    EXECUTE format(
        'CREATE TABLE IF NOT EXISTS request_capture_logs_%s PARTITION OF request_capture_logs FOR VALUES FROM (%L) TO (%L)',
        to_char(month_start, 'YYYYMM'), month_start, next_month
    );
    EXECUTE format(
        'CREATE TABLE IF NOT EXISTS request_capture_logs_%s PARTITION OF request_capture_logs FOR VALUES FROM (%L) TO (%L)',
        to_char(next_month, 'YYYYMM'), next_month, (next_month + INTERVAL '1 month')::date
    );
END $$;
#!/bin/bash
# =============================================================================
# capture_requests.sh — 控制指定 API Key 的请求体捕获开关
# =============================================================================
# 用法:
#   ./capture_requests.sh <key_id> <on|off>
#
# 环境变量(优先级高于脚本内默认值):
#   BASE_URL   API 服务地址,例如 https://s2a-st.appbym.com
#   ADMIN_KEY  Admin API Key
#
# 示例:
#   BASE_URL=https://example.com ADMIN_KEY=sk-xxx ./capture_requests.sh 123 on
#   ./capture_requests.sh 456 off
# =============================================================================

set -euo pipefail

# ---------- 默认配置(可通过环境变量覆盖)----------
DEFAULT_BASE_URL="https://s2a-st.appbym.com"
# SECURITY: never hard-code an admin credential here. An empty default both
# keeps the secret out of version control and makes the "-z ADMIN_KEY" check
# below actually reachable (a non-empty default would render it dead code).
DEFAULT_ADMIN_KEY=""

BASE_URL="${BASE_URL:-$DEFAULT_BASE_URL}"
ADMIN_KEY="${ADMIN_KEY:-$DEFAULT_ADMIN_KEY}"

# ---------- 颜色输出 ----------
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Print usage and exit non-zero.
usage() {
    echo "用法: $0 <key_id> <on|off>"
    echo ""
    echo "  key_id   API Key 的数字 ID"
    echo "  on       开启请求体捕获"
    echo "  off      关闭请求体捕获(同时立即清除认证缓存)"
    echo ""
    echo "环境变量:"
    echo "  BASE_URL    API 服务地址(默认: $DEFAULT_BASE_URL)"
    echo "  ADMIN_KEY   Admin API Key(必填)"
    exit 1
}

# ---------- 参数检查 ----------
if [[ $# -ne 2 ]]; then
    echo -e "${RED}错误: 需要 2 个参数${NC}"
    usage
fi

KEY_ID="$1"
ACTION="$2"

# 验证 key_id 是正整数
if ! [[ "$KEY_ID" =~ ^[1-9][0-9]*$ ]]; then
    echo -e "${RED}错误: key_id 必须是正整数,收到: '$KEY_ID'${NC}"
    usage
fi

# 解析 on/off → true/false
case "$ACTION" in
    on|true|1|yes)
        ENABLED="true"
        ACTION_LABEL="开启"
        ;;
    off|false|0|no)
        ENABLED="false"
        ACTION_LABEL="关闭"
        ;;
    *)
        echo -e "${RED}错误: 第二个参数必须是 on 或 off,收到: '$ACTION'${NC}"
        usage
        ;;
esac

# 检查 ADMIN_KEY
if [[ -z "$ADMIN_KEY" ]]; then
    echo -e "${RED}错误: 未设置 ADMIN_KEY${NC}"
    echo "请通过环境变量传入: ADMIN_KEY=sk-xxx $0 $KEY_ID $ACTION"
    exit 1
fi

# ---------- 发送请求 ----------
ENDPOINT="${BASE_URL}/api/v1/admin/api-keys/${KEY_ID}/capture-requests"

echo -e "${YELLOW}${ACTION_LABEL} API Key #${KEY_ID} 的请求体捕获...${NC}"
echo "  接口: PUT $ENDPOINT"
echo "  参数: enabled=$ENABLED"
echo ""

# -w "\n%{http_code}" appends the status code on its own line so body and
# code can be separated below.
HTTP_RESPONSE=$(curl -s -w "\n%{http_code}" \
    -X PUT "$ENDPOINT" \
    -H "x-api-key: $ADMIN_KEY" \
    -H "Content-Type: application/json" \
    -d "{\"enabled\": $ENABLED}")

# 分离响应体和状态码 (sed '$d' drops the trailing status-code line)
HTTP_BODY=$(echo "$HTTP_RESPONSE" | sed '$d' | tr -d '\r')
HTTP_CODE=$(echo "$HTTP_RESPONSE" | tail -n 1 | tr -d '\r')

# ---------- 结果输出 ----------
echo "HTTP 状态码: $HTTP_CODE"
echo "响应内容:"
# 尝试格式化 JSON(有 jq 就用,没有就原样输出)
if command -v jq &>/dev/null; then
    echo "$HTTP_BODY" | jq .
else
    echo "$HTTP_BODY"
fi
echo ""

if [[ "$HTTP_CODE" == "200" ]]; then
    echo -e "${GREEN}✓ 操作成功:API Key #${KEY_ID} 请求体捕获已${ACTION_LABEL}${NC}"
    if [[ "$ENABLED" == "false" ]]; then
        echo -e "${GREEN}  认证缓存已同步清除,下一条请求立即生效${NC}"
    fi
else
    # Fixed: the parenthesis around the status code was previously unclosed.
    echo -e "${RED}✗ 操作失败(HTTP $HTTP_CODE)${NC}"
    exit 1
fi
......@@ -376,6 +376,17 @@ GEMINI_QUOTA_POLICY=
# 设置为 false 可在左侧栏隐藏运维监控菜单并禁用所有运维监控功能
OPS_ENABLED=true
# -----------------------------------------------------------------------------
# Request Capture Configuration (Optional)
# 请求捕获配置(可选,按 API Key 开启,用于审计/调试)
# -----------------------------------------------------------------------------
# Local NFS mount path for writing request/response files (leave empty for DB-only mode)
# 本地挂载的 NFS 根目录,留空则跳过文件写入(仅写数据库)
REQUEST_CAPTURE_NFS_PATH=/app/logs/nfs/
# Async write timeout in seconds (default: 5)
# 单次异步写入超时时间(秒,默认 5)
REQUEST_CAPTURE_WORKER_TIMEOUT_SECONDS=5
# -----------------------------------------------------------------------------
# Update Configuration (在线更新配置)
# -----------------------------------------------------------------------------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment