Commit 7efa8b54 authored by yangjianbo
Browse files

perf(后端): 完成性能优化与连接池配置

新增 DB/Redis 连接池配置与校验,并补充单测

网关请求体大小限制与 413 处理

HTTP/req 客户端池化并调整上游连接池默认值

并发槽位改为 ZSET+Lua 与指数退避

用量统计改 SQL 聚合并新增索引迁移

计费缓存写入改工作池并补测试/基准

测试: 在 backend/ 下运行 go test ./...
parent 53767866
...@@ -9,22 +9,22 @@ import ( ...@@ -9,22 +9,22 @@ import (
"time" "time"
) )
// ConcurrencyCache defines cache operations for concurrency service // ConcurrencyCache 定义并发控制的缓存接口
// Uses independent keys per request slot with native Redis TTL for automatic cleanup // 使用有序集合存储槽位,按时间戳清理过期条目
type ConcurrencyCache interface { type ConcurrencyCache interface {
// Account slot management - each slot is a separate key with independent TTL // 账号槽位管理
// Key format: concurrency:account:{accountID}:{requestID} // 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID)
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
// User slot management - each slot is a separate key with independent TTL // 用户槽位管理
// Key format: concurrency:user:{userID}:{requestID} // 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
GetUserConcurrency(ctx context.Context, userID int64) (int, error) GetUserConcurrency(ctx context.Context, userID int64) (int, error)
// Wait queue - uses counter with TTL set only on creation // 等待队列计数(只在首次创建时设置 TTL)
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error DecrementWaitCount(ctx context.Context, userID int64) error
} }
......
...@@ -12,6 +12,8 @@ import ( ...@@ -12,6 +12,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
) )
type CRSSyncService struct { type CRSSyncService struct {
...@@ -193,7 +195,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -193,7 +195,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
return nil, errors.New("username and password are required") return nil, errors.New("username and password are required")
} }
client := &http.Client{Timeout: 20 * time.Second} client, err := httpclient.GetClient(httpclient.Options{
Timeout: 20 * time.Second,
})
if err != nil {
client = &http.Client{Timeout: 20 * time.Second}
}
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password) adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
if err != nil { if err != nil {
......
package service
import (
"encoding/json"
"fmt"
)
// ParsedRequest holds the pre-parsed result of a gateway request.
//
// Performance note:
// the previous implementation parsed the request body repeatedly in several
// places (once in the handler and again in the service):
//  1. gateway_handler.go parsed it to obtain model and stream
//  2. gateway_service.go parsed it again to obtain system, messages and metadata
//  3. GenerateSessionHash parsed it yet again for the fields the session hash needs
//
// The current implementation parses once and reuses the result everywhere:
//  1. the handler layer calls ParseGatewayRequest a single time
//  2. the resulting ParsedRequest is handed to the service layer
//  3. repeated json.Unmarshal calls are avoided, reducing CPU and memory overhead
type ParsedRequest struct {
	Body           []byte // raw request body (kept so it can be forwarded)
	Model          string // requested model name
	Stream         bool   // whether this is a streaming request
	MetadataUserID string // metadata.user_id (used for session affinity)
	System         any    // content of the system field
	Messages       []any  // the messages array
	HasSystem      bool   // whether a system field is present; ParseGatewayRequest sets it only for a non-null value
}
// ParseGatewayRequest decodes a gateway request body once and returns the
// fields the gateway needs, so that callers do not have to unmarshal it again.
//
// The body must be valid JSON; otherwise the json.Unmarshal error is returned
// unchanged. When present, "model" must be a JSON string and "stream" a JSON
// boolean — any other type (including null) yields an error. "metadata.user_id",
// "system" and "messages" are optional and are silently skipped when missing or
// of an unexpected type; a null "system" is treated the same as an absent one.
//
// The returned ParsedRequest keeps the caller's slice in Body without copying.
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
	var fields map[string]any
	if err := json.Unmarshal(body, &fields); err != nil {
		return nil, err
	}

	// Strictly typed fields: present-but-wrong-type is an error, absent is fine.
	var model string
	if raw, present := fields["model"]; present {
		switch v := raw.(type) {
		case string:
			model = v
		default:
			return nil, fmt.Errorf("invalid model field type")
		}
	}

	var stream bool
	if raw, present := fields["stream"]; present {
		switch v := raw.(type) {
		case bool:
			stream = v
		default:
			return nil, fmt.Errorf("invalid stream field type")
		}
	}

	// Leniently typed fields: a failed assertion just leaves the zero value.
	// Indexing a nil map is safe, so a missing metadata object needs no guard.
	metadata, _ := fields["metadata"].(map[string]any)
	userID, _ := metadata["user_id"].(string)
	messages, _ := fields["messages"].([]any)

	// A missing key and an explicit JSON null both read back as nil here.
	system := fields["system"]

	return &ParsedRequest{
		Body:           body,
		Model:          model,
		Stream:         stream,
		MetadataUserID: userID,
		System:         system,
		Messages:       messages,
		HasSystem:      system != nil,
	}, nil
}
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestParseGatewayRequest verifies that a fully populated request body is
// decoded into every field of ParsedRequest.
func TestParseGatewayRequest(t *testing.T) {
	const payload = `{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`

	got, err := ParseGatewayRequest([]byte(payload))
	require.NoError(t, err)

	require.Equal(t, "claude-3-7-sonnet", got.Model)
	require.True(t, got.Stream)
	require.Equal(t, "session_123e4567-e89b-12d3-a456-426614174000", got.MetadataUserID)
	require.True(t, got.HasSystem)
	require.NotNil(t, got.System)
	require.Len(t, got.Messages, 1)
}
// TestParseGatewayRequest_SystemNull verifies that an explicit JSON null in
// "system" is reported as "no system field".
func TestParseGatewayRequest_SystemNull(t *testing.T) {
	const payload = `{"model":"claude-3","system":null}`

	got, err := ParseGatewayRequest([]byte(payload))

	require.NoError(t, err)
	require.False(t, got.HasSystem)
}
// TestParseGatewayRequest_InvalidModelType verifies that a non-string "model"
// value makes parsing fail.
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
	const payload = `{"model":123}`

	_, parseErr := ParseGatewayRequest([]byte(payload))

	require.Error(t, parseErr)
}
// TestParseGatewayRequest_InvalidStreamType verifies that a non-boolean
// "stream" value (here the string "true") makes parsing fail.
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
	const payload = `{"stream":"true"}`

	_, parseErr := ParseGatewayRequest([]byte(payload))

	require.Error(t, parseErr)
}
...@@ -19,7 +19,6 @@ import ( ...@@ -19,7 +19,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -33,7 +32,10 @@ const ( ...@@ -33,7 +32,10 @@ const (
// sseDataRe matches SSE data lines with optional whitespace after colon. // sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: "). // Some upstream APIs return non-standard "data:" without space (should be "data: ").
var sseDataRe = regexp.MustCompile(`^data:\s*`) var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
)
// allowedHeaders 白名单headers(参考CRS项目) // allowedHeaders 白名单headers(参考CRS项目)
var allowedHeaders = map[string]bool{ var allowedHeaders = map[string]bool{
...@@ -141,40 +143,36 @@ func NewGatewayService( ...@@ -141,40 +143,36 @@ func NewGatewayService(
} }
} }
// GenerateSessionHash 从请求体计算粘性会话hash // GenerateSessionHash 从预解析请求计算粘性会话 hash
func (s *GatewayService) GenerateSessionHash(body []byte) string { func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
var req map[string]any if parsed == nil {
if err := json.Unmarshal(body, &req); err != nil {
return "" return ""
} }
// 1. 最高优先级:从metadata.user_id提取session_xxx // 1. 最高优先级:从 metadata.user_id 提取 session_xxx
if metadata, ok := req["metadata"].(map[string]any); ok { if parsed.MetadataUserID != "" {
if userID, ok := metadata["user_id"].(string); ok { if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 {
re := regexp.MustCompile(`session_([a-f0-9-]{36})`) return match[1]
if match := re.FindStringSubmatch(userID); len(match) > 1 {
return match[1]
}
} }
} }
// 2. 提取带cache_control: {type: "ephemeral"}的内容 // 2. 提取带 cache_control: {type: "ephemeral"} 的内容
cacheableContent := s.extractCacheableContent(req) cacheableContent := s.extractCacheableContent(parsed)
if cacheableContent != "" { if cacheableContent != "" {
return s.hashContent(cacheableContent) return s.hashContent(cacheableContent)
} }
// 3. Fallback: 使用system内容 // 3. Fallback: 使用 system 内容
if system := req["system"]; system != nil { if parsed.System != nil {
systemText := s.extractTextFromSystem(system) systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" { if systemText != "" {
return s.hashContent(systemText) return s.hashContent(systemText)
} }
} }
// 4. 最后fallback: 使用第一条消息 // 4. 最后 fallback: 使用第一条消息
if messages, ok := req["messages"].([]any); ok && len(messages) > 0 { if len(parsed.Messages) > 0 {
if firstMsg, ok := messages[0].(map[string]any); ok { if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
msgText := s.extractTextFromContent(firstMsg["content"]) msgText := s.extractTextFromContent(firstMsg["content"])
if msgText != "" { if msgText != "" {
return s.hashContent(msgText) return s.hashContent(msgText)
...@@ -185,36 +183,38 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string { ...@@ -185,36 +183,38 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string {
return "" return ""
} }
func (s *GatewayService) extractCacheableContent(req map[string]any) string { func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
var content string if parsed == nil {
return ""
}
// 检查system中的cacheable内容 var builder strings.Builder
if system, ok := req["system"].([]any); ok {
// 检查 system 中的 cacheable 内容
if system, ok := parsed.System.([]any); ok {
for _, part := range system { for _, part := range system {
if partMap, ok := part.(map[string]any); ok { if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok { if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" { if cc["type"] == "ephemeral" {
if text, ok := partMap["text"].(string); ok { if text, ok := partMap["text"].(string); ok {
content += text builder.WriteString(text)
} }
} }
} }
} }
} }
} }
systemText := builder.String()
// 检查messages中的cacheable内容
if messages, ok := req["messages"].([]any); ok { // 检查 messages 中的 cacheable 内容
for _, msg := range messages { for _, msg := range parsed.Messages {
if msgMap, ok := msg.(map[string]any); ok { if msgMap, ok := msg.(map[string]any); ok {
if msgContent, ok := msgMap["content"].([]any); ok { if msgContent, ok := msgMap["content"].([]any); ok {
for _, part := range msgContent { for _, part := range msgContent {
if partMap, ok := part.(map[string]any); ok { if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok { if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" { if cc["type"] == "ephemeral" {
// 找到cacheable内容,提取第一条消息的文本 return s.extractTextFromContent(msgMap["content"])
return s.extractTextFromContent(msgMap["content"])
}
} }
} }
} }
...@@ -223,7 +223,7 @@ func (s *GatewayService) extractCacheableContent(req map[string]any) string { ...@@ -223,7 +223,7 @@ func (s *GatewayService) extractCacheableContent(req map[string]any) string {
} }
} }
return content return systemText
} }
func (s *GatewayService) extractTextFromSystem(system any) string { func (s *GatewayService) extractTextFromSystem(system any) string {
...@@ -588,19 +588,17 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool { ...@@ -588,19 +588,17 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
} }
// Forward 转发请求到Claude API // Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
if parsed == nil {
// 解析请求获取model和stream return nil, fmt.Errorf("parse request: empty request")
var req struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
} }
if !gjson.GetBytes(body, "system").Exists() { body := parsed.Body
reqModel := parsed.Model
reqStream := parsed.Stream
if !parsed.HasSystem {
body, _ = sjson.SetBytes(body, "system", []any{ body, _ = sjson.SetBytes(body, "system", []any{
map[string]any{ map[string]any{
"type": "text", "type": "text",
...@@ -613,13 +611,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -613,13 +611,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
} }
// 应用模型映射(仅对apikey类型账号) // 应用模型映射(仅对apikey类型账号)
originalModel := req.Model originalModel := reqModel
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeApiKey {
mappedModel := account.GetMappedModel(req.Model) mappedModel := account.GetMappedModel(reqModel)
if mappedModel != req.Model { if mappedModel != reqModel {
// 替换请求体中的模型名 // 替换请求体中的模型名
body = s.replaceModelInBody(body, mappedModel) body = s.replaceModelInBody(body, mappedModel)
req.Model = mappedModel reqModel = mappedModel
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name) log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
} }
} }
...@@ -640,7 +638,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -640,7 +638,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var resp *http.Response var resp *http.Response
for attempt := 1; attempt <= maxRetries; attempt++ { for attempt := 1; attempt <= maxRetries; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType) upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -692,8 +690,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -692,8 +690,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理正常响应 // 处理正常响应
var usage *ClaudeUsage var usage *ClaudeUsage
var firstTokenMs *int var firstTokenMs *int
if req.Stream { if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, req.Model) streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
if err != nil { if err != nil {
if err.Error() == "have error in stream" { if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
...@@ -705,7 +703,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -705,7 +703,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
usage = streamResult.usage usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs firstTokenMs = streamResult.firstTokenMs
} else { } else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, req.Model) usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -715,13 +713,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -715,13 +713,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志 Model: originalModel, // 使用原始模型用于计费和日志
Stream: req.Stream, Stream: reqStream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
}, nil }, nil
} }
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) { func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL // 确定目标URL
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeApiKey {
...@@ -787,7 +785,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -787,7 +785,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta header(OAuth账号需要特殊处理) // 处理anthropic-beta header(OAuth账号需要特殊处理)
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} }
return req, nil return req, nil
...@@ -795,7 +793,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -795,7 +793,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// getBetaHeader 处理anthropic-beta header // getBetaHeader 处理anthropic-beta header
// 对于OAuth账号,需要确保包含oauth-2025-04-20 // 对于OAuth账号,需要确保包含oauth-2025-04-20
func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string { func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
// 如果客户端传了anthropic-beta // 如果客户端传了anthropic-beta
if clientBetaHeader != "" { if clientBetaHeader != "" {
// 已包含oauth beta则直接返回 // 已包含oauth beta则直接返回
...@@ -832,15 +830,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str ...@@ -832,15 +830,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
} }
// 客户端没传,根据模型生成 // 客户端没传,根据模型生成
var modelID string // haiku 模型不需要 claude-code beta
var reqMap map[string]any
if json.Unmarshal(body, &reqMap) == nil {
if m, ok := reqMap["model"].(string); ok {
modelID = m
}
}
// haiku模型不需要claude-code beta
if strings.Contains(strings.ToLower(modelID), "haiku") { if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.HaikuBetaHeader return claude.HaikuBetaHeader
} }
...@@ -1248,13 +1238,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -1248,13 +1238,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
log.Printf("Increment subscription usage failed: %v", err) log.Printf("Increment subscription usage failed: %v", err)
} }
// 异步更新订阅缓存 // 异步更新订阅缓存
go func() { s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost); err != nil {
log.Printf("Update subscription cache failed: %v", err)
}
}()
} }
} else { } else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
...@@ -1263,13 +1247,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -1263,13 +1247,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
log.Printf("Deduct balance failed: %v", err) log.Printf("Deduct balance failed: %v", err)
} }
// 异步更新余额缓存 // 异步更新余额缓存
go func() { s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost); err != nil {
log.Printf("Update balance cache failed: %v", err)
}
}()
} }
} }
...@@ -1281,7 +1259,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -1281,7 +1259,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
// ForwardCountTokens 转发 count_tokens 请求到上游 API // ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应 // 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error { func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
if parsed == nil {
s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return fmt.Errorf("parse request: empty request")
}
body := parsed.Body
reqModel := parsed.Model
// Antigravity 账户不支持 count_tokens 转发,返回估算值 // Antigravity 账户不支持 count_tokens 转发,返回估算值
// 参考 Antigravity-Manager 和 proxycast 实现 // 参考 Antigravity-Manager 和 proxycast 实现
if account.Platform == PlatformAntigravity { if account.Platform == PlatformAntigravity {
...@@ -1291,14 +1277,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -1291,14 +1277,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// 应用模型映射(仅对 apikey 类型账号) // 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeApiKey {
var req struct { if reqModel != "" {
Model string `json:"model"` mappedModel := account.GetMappedModel(reqModel)
} if mappedModel != reqModel {
if err := json.Unmarshal(body, &req); err == nil && req.Model != "" {
mappedModel := account.GetMappedModel(req.Model)
if mappedModel != req.Model {
body = s.replaceModelInBody(body, mappedModel) body = s.replaceModelInBody(body, mappedModel)
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name) reqModel = mappedModel
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
} }
} }
} }
...@@ -1311,7 +1295,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -1311,7 +1295,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// 构建上游请求 // 构建上游请求
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType) upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil { if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err return err
...@@ -1363,7 +1347,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -1363,7 +1347,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// buildCountTokensRequest 构建 count_tokens 上游请求 // buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) { func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL // 确定目标 URL
targetURL := claudeAPICountTokensURL targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeApiKey {
...@@ -1424,7 +1408,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -1424,7 +1408,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header // OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} }
return req, nil return req, nil
......
package service
import (
"strconv"
"testing"
)
// benchmarkStringSink is a package-level variable that the benchmarks in this
// file assign their results to — presumably so the compiler cannot discard the
// benchmarked calls as dead code.
var benchmarkStringSink string
// BenchmarkGenerateSessionHash_Metadata focuses on the cost of JSON parsing
// plus the regular-expression match on metadata.user_id.
func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
	const payload = `{"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"messages":[{"content":"hello"}]}`

	raw := []byte(payload)
	service := new(GatewayService)

	b.ReportAllocs()
	for iter := 0; iter < b.N; iter++ {
		// Parsing is deliberately inside the loop: it is part of what is measured.
		req, parseErr := ParseGatewayRequest(raw)
		if parseErr != nil {
			b.Fatalf("解析请求失败: %v", parseErr)
		}
		benchmarkStringSink = service.GenerateSessionHash(req)
	}
}
// BenchmarkExtractCacheableContent_System 关注字符串拼接路径的性能。
func BenchmarkExtractCacheableContent_System(b *testing.B) {
svc := &GatewayService{}
req := buildSystemCacheableRequest(12)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
benchmarkStringSink = svc.extractCacheableContent(req)
}
}
// buildSystemCacheableRequest returns a ParsedRequest whose System field is a
// list of `parts` blocks, each carrying a "text" of the form "system_part_<i>"
// and a cache_control of type "ephemeral". HasSystem is always set to true.
func buildSystemCacheableRequest(parts int) *ParsedRequest {
	blocks := make([]any, 0, parts)
	for idx := 0; idx < parts; idx++ {
		block := map[string]any{
			"cache_control": map[string]any{"type": "ephemeral"},
			"text":          "system_part_" + strconv.Itoa(idx),
		}
		blocks = append(blocks, block)
	}

	req := new(ParsedRequest)
	req.System = blocks
	req.HasSystem = true
	return req
}
...@@ -921,7 +921,10 @@ func sleepGeminiBackoff(attempt int) { ...@@ -921,7 +921,10 @@ func sleepGeminiBackoff(attempt int) {
time.Sleep(sleepFor) time.Sleep(sleepFor)
} }
var sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`) var (
sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`)
retryInRegex = regexp.MustCompile(`Please retry in ([0-9.]+)s`)
)
func sanitizeUpstreamErrorMessage(msg string) string { func sanitizeUpstreamErrorMessage(msg string) string {
if msg == "" { if msg == "" {
...@@ -1925,7 +1928,6 @@ func ParseGeminiRateLimitResetTime(body []byte) *int64 { ...@@ -1925,7 +1928,6 @@ func ParseGeminiRateLimitResetTime(body []byte) *int64 {
} }
// Match "Please retry in Xs" // Match "Please retry in Xs"
retryInRegex := regexp.MustCompile(`Please retry in ([0-9.]+)s`)
matches := retryInRegex.FindStringSubmatch(string(body)) matches := retryInRegex.FindStringSubmatch(string(body))
if len(matches) == 2 { if len(matches) == 2 {
if dur, err := time.ParseDuration(matches[1] + "s"); err == nil { if dur, err := time.ParseDuration(matches[1] + "s"); err == nil {
......
...@@ -7,13 +7,13 @@ import ( ...@@ -7,13 +7,13 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
) )
type GeminiOAuthService struct { type GeminiOAuthService struct {
...@@ -497,11 +497,12 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR ...@@ -497,11 +497,12 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
client := &http.Client{Timeout: 30 * time.Second} client, err := httpclient.GetClient(httpclient.Options{
if strings.TrimSpace(proxyURL) != "" { ProxyURL: strings.TrimSpace(proxyURL),
if proxyURLParsed, err := url.Parse(strings.TrimSpace(proxyURL)); err == nil { Timeout: 30 * time.Second,
client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURLParsed)} })
} if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
} }
resp, err := client.Do(req) resp, err := client.Do(req)
......
...@@ -768,20 +768,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -768,20 +768,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
if isSubscriptionBilling { if isSubscriptionBilling {
if cost.TotalCost > 0 { if cost.TotalCost > 0 {
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost) _ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
go func() { s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost)
}()
} }
} else { } else {
if cost.ActualCost > 0 { if cost.ActualCost > 0 {
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost) _ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
go func() { s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost)
}()
} }
} }
......
...@@ -18,6 +18,11 @@ import ( ...@@ -18,6 +18,11 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
) )
var (
openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`)
openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`)
)
// LiteLLMModelPricing LiteLLM价格数据结构 // LiteLLMModelPricing LiteLLM价格数据结构
// 只保留我们需要的字段,使用指针来处理可能缺失的值 // 只保留我们需要的字段,使用指针来处理可能缺失的值
type LiteLLMModelPricing struct { type LiteLLMModelPricing struct {
...@@ -595,11 +600,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { ...@@ -595,11 +600,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
// 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号) // 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号)
// 3. 最终回退到 DefaultTestModel (gpt-5.1-codex) // 3. 最终回退到 DefaultTestModel (gpt-5.1-codex)
func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
// 正则匹配日期后缀 (如 -20251222)
datePattern := regexp.MustCompile(`-\d{8}$`)
// 尝试的回退变体 // 尝试的回退变体
variants := s.generateOpenAIModelVariants(model, datePattern) variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern)
for _, variant := range variants { for _, variant := range variants {
if pricing, ok := s.pricingData[variant]; ok { if pricing, ok := s.pricingData[variant]; ok {
...@@ -638,14 +640,13 @@ func (s *PricingService) generateOpenAIModelVariants(model string, datePattern * ...@@ -638,14 +640,13 @@ func (s *PricingService) generateOpenAIModelVariants(model string, datePattern *
// 2. 提取基础版本号: gpt-5.2-codex -> gpt-5.2 // 2. 提取基础版本号: gpt-5.2-codex -> gpt-5.2
// 只匹配纯数字版本号格式 gpt-X 或 gpt-X.Y,不匹配 gpt-4o 这种带字母后缀的 // 只匹配纯数字版本号格式 gpt-X 或 gpt-X.Y,不匹配 gpt-4o 这种带字母后缀的
basePattern := regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`) if matches := openAIModelBasePattern.FindStringSubmatch(model); len(matches) > 1 {
if matches := basePattern.FindStringSubmatch(model); len(matches) > 1 {
addVariant(matches[1]) addVariant(matches[1])
} }
// 3. 同时去掉日期后再提取基础版本号 // 3. 同时去掉日期后再提取基础版本号
if withoutDate != model { if withoutDate != model {
if matches := basePattern.FindStringSubmatch(withoutDate); len(matches) > 1 { if matches := openAIModelBasePattern.FindStringSubmatch(withoutDate); len(matches) > 1 {
addVariant(matches[1]) addVariant(matches[1])
} }
} }
......
...@@ -186,22 +186,40 @@ func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, sta ...@@ -186,22 +186,40 @@ func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, sta
// GetStatsByAccount 获取账号的使用统计 // GetStatsByAccount 获取账号的使用统计
func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) { func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByAccountAndTimeRange(ctx, accountID, startTime, endTime) stats, err := s.usageRepo.GetAccountStatsAggregated(ctx, accountID, startTime, endTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err) return nil, fmt.Errorf("get account stats: %w", err)
} }
return s.calculateStats(logs), nil return &UsageStats{
TotalRequests: stats.TotalRequests,
TotalInputTokens: stats.TotalInputTokens,
TotalOutputTokens: stats.TotalOutputTokens,
TotalCacheTokens: stats.TotalCacheTokens,
TotalTokens: stats.TotalTokens,
TotalCost: stats.TotalCost,
TotalActualCost: stats.TotalActualCost,
AverageDurationMs: stats.AverageDurationMs,
}, nil
} }
// GetStatsByModel 获取模型的使用统计 // GetStatsByModel 获取模型的使用统计
func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) { func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByModelAndTimeRange(ctx, modelName, startTime, endTime) stats, err := s.usageRepo.GetModelStatsAggregated(ctx, modelName, startTime, endTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err) return nil, fmt.Errorf("get model stats: %w", err)
} }
return s.calculateStats(logs), nil return &UsageStats{
TotalRequests: stats.TotalRequests,
TotalInputTokens: stats.TotalInputTokens,
TotalOutputTokens: stats.TotalOutputTokens,
TotalCacheTokens: stats.TotalCacheTokens,
TotalTokens: stats.TotalTokens,
TotalCost: stats.TotalCost,
TotalActualCost: stats.TotalActualCost,
AverageDurationMs: stats.AverageDurationMs,
}, nil
} }
// GetDailyStats 获取每日使用统计(最近N天) // GetDailyStats 获取每日使用统计(最近N天)
...@@ -209,54 +227,12 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int ...@@ -209,54 +227,12 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
endTime := time.Now() endTime := time.Now()
startTime := endTime.AddDate(0, 0, -days) startTime := endTime.AddDate(0, 0, -days)
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime) stats, err := s.usageRepo.GetDailyStatsAggregated(ctx, userID, startTime, endTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err) return nil, fmt.Errorf("get daily stats: %w", err)
} }
// 按日期分组统计 return stats, nil
dailyStats := make(map[string]*UsageStats)
for _, log := range logs {
dateKey := log.CreatedAt.Format("2006-01-02")
if _, exists := dailyStats[dateKey]; !exists {
dailyStats[dateKey] = &UsageStats{}
}
stats := dailyStats[dateKey]
stats.TotalRequests++
stats.TotalInputTokens += int64(log.InputTokens)
stats.TotalOutputTokens += int64(log.OutputTokens)
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
stats.TotalTokens += int64(log.TotalTokens())
stats.TotalCost += log.TotalCost
stats.TotalActualCost += log.ActualCost
if log.DurationMs != nil {
stats.AverageDurationMs += float64(*log.DurationMs)
}
}
// 计算平均值并转换为数组
result := make([]map[string]any, 0, len(dailyStats))
for date, stats := range dailyStats {
if stats.TotalRequests > 0 {
stats.AverageDurationMs /= float64(stats.TotalRequests)
}
result = append(result, map[string]any{
"date": date,
"total_requests": stats.TotalRequests,
"total_input_tokens": stats.TotalInputTokens,
"total_output_tokens": stats.TotalOutputTokens,
"total_cache_tokens": stats.TotalCacheTokens,
"total_tokens": stats.TotalTokens,
"total_cost": stats.TotalCost,
"total_actual_cost": stats.TotalActualCost,
"average_duration_ms": stats.AverageDurationMs,
})
}
return result, nil
} }
// calculateStats 计算统计数据 // calculateStats 计算统计数据
......
-- Add composite indexes on usage_logs to support the aggregated statistics queries.
-- Each index pairs one filter column (account_id, api_key_id, model) with created_at.
-- IF NOT EXISTS lets the migration be re-run without failing on an existing index.
CREATE INDEX IF NOT EXISTS idx_usage_logs_account_created_at ON usage_logs(account_id, created_at);
CREATE INDEX IF NOT EXISTS idx_usage_logs_api_key_created_at ON usage_logs(api_key_id, created_at);
CREATE INDEX IF NOT EXISTS idx_usage_logs_model_created_at ON usage_logs(model, created_at);
...@@ -21,6 +21,22 @@ server: ...@@ -21,6 +21,22 @@ server:
# - simple: Hides SaaS features and skips billing/balance checks # - simple: Hides SaaS features and skips billing/balance checks
run_mode: "standard" run_mode: "standard"
# =============================================================================
# Gateway Configuration
# =============================================================================
gateway:
# Timeout while waiting for the upstream response headers (seconds)
response_header_timeout: 300
# Maximum request body size in bytes (default 100MB = 104857600)
max_body_size: 104857600
# Upstream HTTP connection pool settings (defaults for HTTP/2 with multiple proxies)
max_idle_conns: 240
max_idle_conns_per_host: 120
max_conns_per_host: 240
idle_conn_timeout_seconds: 300
# Concurrency slot expiry time (minutes)
concurrency_slot_ttl_minutes: 15
# ============================================================================= # =============================================================================
# Database Configuration (PostgreSQL) # Database Configuration (PostgreSQL)
# ============================================================================= # =============================================================================
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment