Commit a14dfb76 authored by yangjianbo's avatar yangjianbo
Browse files

Merge branch 'dev-release'

parents f3605ddc 2588fa6a
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
...@@ -391,3 +392,37 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling( ...@@ -391,3 +392,37 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode) require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
} }
// TestAntigravityStreamUpstreamResponse_UsageAndFirstToken pipes two SSE usage
// events into streamUpstreamResponse and verifies the reported usage, that a
// first-token value is produced, and that the stream is forwarded to the client.
func TestAntigravityStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	bodyReader, bodyWriter := io.Pipe()
	upstream := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: bodyReader}

	// Events are written from a separate goroutine because io.Pipe is
	// synchronous: each Write blocks until the reader consumes it.
	sseLines := []string{
		"data: {\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"cache_read_input_tokens\":3,\"cache_creation_input_tokens\":4}}\n",
		"data: {\"usage\":{\"output_tokens\":5}}\n",
	}
	go func() {
		defer func() { _ = bodyWriter.Close() }()
		for _, line := range sseLines {
			_, _ = bodyWriter.Write([]byte(line))
		}
	}()

	gateway := &AntigravityGatewayService{}
	startedAt := time.Now().Add(-10 * time.Millisecond)
	usage, firstTokenMs := gateway.streamUpstreamResponse(ginCtx, upstream, startedAt)
	_ = bodyReader.Close()

	require.NotNil(t, usage)
	require.Equal(t, 1, usage.InputTokens)
	// The second event overrides output_tokens.
	require.Equal(t, 5, usage.OutputTokens)
	require.Equal(t, 3, usage.CacheReadInputTokens)
	require.Equal(t, 4, usage.CacheCreationInputTokens)

	if firstTokenMs == nil {
		t.Fatalf("expected firstTokenMs to be set")
	}

	// Make sure the upstream output was passed through to the client.
	forwarded := recorder.Body.String()
	require.True(t, strings.Contains(forwarded, "data:"))
}
...@@ -6,8 +6,7 @@ import ( ...@@ -6,8 +6,7 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"math/rand" "math/rand/v2"
"sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
...@@ -23,12 +22,6 @@ type apiKeyAuthCacheConfig struct { ...@@ -23,12 +22,6 @@ type apiKeyAuthCacheConfig struct {
singleflight bool singleflight bool
} }
// Package-level state backing the TTL jitter applied to the API-key auth cache.
var (
// jitterRandMu serializes access to jitterRand; a *rand.Rand created with
// rand.New is not safe for concurrent use.
jitterRandMu sync.Mutex
// Auth-cache jitter uses its own random source so the global Seed is left untouched.
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
)
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig { func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
if cfg == nil { if cfg == nil {
return apiKeyAuthCacheConfig{} return apiKeyAuthCacheConfig{}
...@@ -56,6 +49,8 @@ func (c apiKeyAuthCacheConfig) negativeEnabled() bool { ...@@ -56,6 +49,8 @@ func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
return c.negativeTTL > 0 return c.negativeTTL > 0
} }
// jitterTTL 为缓存 TTL 添加抖动,避免多个请求在同一时刻同时过期触发集中回源。
// 这里直接使用 rand/v2 的顶层函数:并发安全,无需全局互斥锁。
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
if ttl <= 0 { if ttl <= 0 {
return ttl return ttl
...@@ -68,9 +63,7 @@ func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { ...@@ -68,9 +63,7 @@ func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
percent = 100 percent = 100
} }
delta := float64(percent) / 100 delta := float64(percent) / 100
jitterRandMu.Lock() randVal := rand.Float64()
randVal := jitterRand.Float64()
jitterRandMu.Unlock()
factor := 1 - delta + randVal*(2*delta) factor := 1 - delta + randVal*(2*delta)
if factor <= 0 { if factor <= 0 {
return ttl return ttl
......
...@@ -319,16 +319,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end ...@@ -319,16 +319,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
return trend, nil return trend, nil
} }
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs) stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("get batch user usage stats: %w", err) return nil, fmt.Errorf("get batch user usage stats: %w", err)
} }
return stats, nil return stats, nil
} }
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs) stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err) return nil, fmt.Errorf("get batch api key usage stats: %w", err)
} }
......
...@@ -4145,7 +4145,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -4145,7 +4145,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize maxLineSize = s.cfg.Gateway.MaxLineSize
} }
scanner.Buffer(make([]byte, 64*1024), maxLineSize) scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
type scanEvent struct { type scanEvent struct {
line string line string
...@@ -4164,7 +4165,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -4164,7 +4165,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
} }
var lastReadAt int64 var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() { go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events) defer close(events)
for scanner.Scan() { for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
...@@ -4175,7 +4177,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -4175,7 +4177,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err}) _ = sendEvent(scanEvent{err: err})
} }
}() }(scanBuf)
defer close(done) defer close(done)
streamInterval := time.Duration(0) streamInterval := time.Duration(0)
...@@ -4481,24 +4483,16 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h ...@@ -4481,24 +4483,16 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
} }
// replaceModelInResponseBody 替换响应体中的model字段 // replaceModelInResponseBody 替换响应体中的model字段
// 使用 gjson/sjson 精确替换,避免全量 JSON 反序列化
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
if err := json.Unmarshal(body, &resp); err != nil { newBody, err := sjson.SetBytes(body, "model", toModel)
return body
}
model, ok := resp["model"].(string)
if !ok || model != fromModel {
return body
}
resp["model"] = toModel
newBody, err := json.Marshal(resp)
if err != nil { if err != nil {
return body return body
} }
return newBody return newBody
}
return body
} }
// RecordUsageInput 记录使用量的输入参数 // RecordUsageInput 记录使用量的输入参数
......
package service
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage streams a
// minimal message_start / message_delta / [DONE] sequence through
// handleStreamingResponse and checks that input and output token usage are
// still extracted.
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
	gin.SetMode(gin.TestMode)

	gateway := &GatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				StreamDataIntervalTimeout: 0,
				MaxLineSize:               defaultMaxLineSize,
			},
		},
		rateLimitService: &RateLimitService{},
	}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	bodyReader, bodyWriter := io.Pipe()
	upstream := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: bodyReader}

	// Minimal SSE events to trigger parseSSEUsage, written from a goroutine
	// because io.Pipe writes block until they are read.
	sseEvents := []string{
		"data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":3}}}\n\n",
		"data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n",
		"data: [DONE]\n\n",
	}
	go func() {
		defer func() { _ = bodyWriter.Close() }()
		for _, event := range sseEvents {
			_, _ = bodyWriter.Write([]byte(event))
		}
	}()

	result, err := gateway.handleStreamingResponse(context.Background(), upstream, ginCtx, &Account{ID: 1}, time.Now(), "model", "model", nil, false)
	_ = bodyReader.Close()

	require.NoError(t, err)
	require.NotNil(t, result)
	require.NotNil(t, result.usage)
	require.Equal(t, 3, result.usage.InputTokens)
	require.Equal(t, 7, result.usage.OutputTokens)
}
...@@ -2,19 +2,7 @@ package service ...@@ -2,19 +2,7 @@ package service
import ( import (
_ "embed" _ "embed"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings" "strings"
"time"
)
const (
opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
codexCacheTTL = 15 * time.Minute
) )
//go:embed prompts/codex_cli_instructions.md //go:embed prompts/codex_cli_instructions.md
...@@ -77,12 +65,6 @@ type codexTransformResult struct { ...@@ -77,12 +65,6 @@ type codexTransformResult struct {
PromptCacheKey string PromptCacheKey string
} }
// opencodeCacheMetadata is the JSON bookkeeping record persisted alongside a
// cached opencode prompt file.
type opencodeCacheMetadata struct {
// ETag is the entity tag recorded from the last successful download; it is
// passed back to fetchWithETag for the next conditional request.
ETag string `json:"etag"`
// LastFetch is the RFC 3339 UTC time of the last successful download.
LastFetch string `json:"lastFetch,omitempty"`
// LastChecked is a Unix timestamp in milliseconds, compared against
// codexCacheTTL to decide whether the cached copy may be served as-is.
LastChecked int64 `json:"lastChecked"`
}
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult { func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult {
result := codexTransformResult{} result := codexTransformResult{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。 // 工具续链需求会影响存储策略与 input 过滤逻辑。
...@@ -216,54 +198,9 @@ func getNormalizedCodexModel(modelID string) string { ...@@ -216,54 +198,9 @@ func getNormalizedCodexModel(modelID string) string {
return "" return ""
} }
// getOpenCodeCachedPrompt returns the prompt text published at url, served
// through an on-disk cache located in the directory given by codexCachePath.
//
// cacheFileName holds the prompt body and metaFileName holds an
// opencodeCacheMetadata record. While the metadata says the cache was checked
// less than codexCacheTTL ago, the cached body is returned without any network
// access. Otherwise a conditional request is made with the stored ETag:
//
//   - 304 Not Modified: the cached body is returned and LastChecked is
//     refreshed so the TTL window starts over.
//   - 2xx with a non-empty body: the cache and metadata are rewritten and the
//     new body is returned.
//   - anything else (including transport errors): the cached body is returned,
//     which is the empty string when nothing has been cached yet.
//
// Cache writes are best-effort; a failed write only means the next call will
// fetch again.
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
	cacheDir := codexCachePath("")
	if cacheDir == "" {
		return ""
	}
	cacheFile := filepath.Join(cacheDir, cacheFileName)
	metaFile := filepath.Join(cacheDir, metaFileName)

	// readFile yields "" on failure, so a missing cache file just means
	// "nothing cached".
	cachedContent, _ := readFile(cacheFile)

	var meta opencodeCacheMetadata
	if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
		if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
			return cachedContent
		}
	}

	content, etag, status, err := fetchWithETag(url, meta.ETag)
	if err != nil {
		return cachedContent
	}

	switch {
	case status == http.StatusNotModified && cachedContent != "":
		// The upstream copy is unchanged. Record that we checked; without
		// this, every call after the first TTL expiry would hit the network
		// again until the content changed.
		meta.LastChecked = time.Now().UnixMilli()
		_ = writeJSON(metaFile, meta)
		return cachedContent
	case status >= 200 && status < 300 && content != "":
		now := time.Now()
		_ = writeFile(cacheFile, content)
		meta = opencodeCacheMetadata{
			ETag:        etag,
			LastFetch:   now.UTC().Format(time.RFC3339),
			LastChecked: now.UnixMilli(),
		}
		_ = writeJSON(metaFile, meta)
		return content
	}
	return cachedContent
}
func getOpenCodeCodexHeader() string { func getOpenCodeCodexHeader() string {
// 优先从 opencode 仓库缓存获取指令。 // 兼容保留:历史上这里会从 opencode 仓库拉取 codex_header.txt。
opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json") // 现在我们与 Codex CLI 一致,直接使用仓库内置的 instructions,避免读写缓存与外网依赖。
// 若 opencode 指令可用,直接返回。
if opencodeInstructions != "" {
return opencodeInstructions
}
// 否则回退使用本地 Codex CLI 指令。
return getCodexCLIInstructions() return getCodexCLIInstructions()
} }
...@@ -281,8 +218,8 @@ func GetCodexCLIInstructions() string { ...@@ -281,8 +218,8 @@ func GetCodexCLIInstructions() string {
} }
// applyInstructions 处理 instructions 字段 // applyInstructions 处理 instructions 字段
// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令) // isCodexCLI=true: 仅补充缺失的 instructions(使用内置 Codex CLI 指令)
// isCodexCLI=false: 优先使用 opencode 指令覆盖 // isCodexCLI=false: 优先使用内置 Codex CLI 指令覆盖
func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
if isCodexCLI { if isCodexCLI {
return applyCodexCLIInstructions(reqBody) return applyCodexCLIInstructions(reqBody)
...@@ -291,13 +228,13 @@ func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { ...@@ -291,13 +228,13 @@ func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
} }
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions // applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
// 仅在 instructions 为空时添加 opencode 指令 // 仅在 instructions 为空时添加内置 Codex CLI 指令(不依赖 opencode 缓存/回源)
func applyCodexCLIInstructions(reqBody map[string]any) bool { func applyCodexCLIInstructions(reqBody map[string]any) bool {
if !isInstructionsEmpty(reqBody) { if !isInstructionsEmpty(reqBody) {
return false // 已有有效 instructions,不修改 return false // 已有有效 instructions,不修改
} }
instructions := strings.TrimSpace(getOpenCodeCodexHeader()) instructions := strings.TrimSpace(getCodexCLIInstructions())
if instructions != "" { if instructions != "" {
reqBody["instructions"] = instructions reqBody["instructions"] = instructions
return true return true
...@@ -306,8 +243,8 @@ func applyCodexCLIInstructions(reqBody map[string]any) bool { ...@@ -306,8 +243,8 @@ func applyCodexCLIInstructions(reqBody map[string]any) bool {
return false return false
} }
// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令 // applyOpenCodeInstructions 为非 Codex CLI 请求应用内置 Codex CLI 指令(兼容历史函数名)
// 优先使用 opencode 指令覆盖 // 优先使用内置 Codex CLI 指令覆盖
func applyOpenCodeInstructions(reqBody map[string]any) bool { func applyOpenCodeInstructions(reqBody map[string]any) bool {
instructions := strings.TrimSpace(getOpenCodeCodexHeader()) instructions := strings.TrimSpace(getOpenCodeCodexHeader())
existingInstructions, _ := reqBody["instructions"].(string) existingInstructions, _ := reqBody["instructions"].(string)
...@@ -489,85 +426,3 @@ func normalizeCodexTools(reqBody map[string]any) bool { ...@@ -489,85 +426,3 @@ func normalizeCodexTools(reqBody map[string]any) bool {
return modified return modified
} }
// codexCachePath resolves a location inside the opencode cache directory,
// <home>/.opencode/cache. An empty filename yields the directory itself;
// otherwise the filename is joined onto it. It returns "" when the user's
// home directory cannot be determined.
func codexCachePath(filename string) string {
	homeDir, homeErr := os.UserHomeDir()
	if homeErr != nil {
		return ""
	}
	cacheRoot := filepath.Join(homeDir, ".opencode", "cache")
	if filename != "" {
		return filepath.Join(cacheRoot, filename)
	}
	return cacheRoot
}
// readFile loads the whole file at path as a string. The boolean reports
// success; it is false (with an empty string) when path is empty or the file
// cannot be read.
func readFile(path string) (string, bool) {
	if path != "" {
		if raw, readErr := os.ReadFile(path); readErr == nil {
			return string(raw), true
		}
	}
	return "", false
}
// writeFile stores content at path with mode 0o644, first creating any missing
// parent directories with mode 0o755. An empty path is rejected with an error;
// directory-creation and write failures are returned unchanged.
func writeFile(path, content string) error {
	if len(path) == 0 {
		return fmt.Errorf("empty cache path")
	}
	parentDir := filepath.Dir(path)
	if mkErr := os.MkdirAll(parentDir, 0o755); mkErr != nil {
		return mkErr
	}
	return os.WriteFile(path, []byte(content), 0o644)
}
// loadJSON reads the file at path and decodes its JSON content into target.
// It reports true only when both the read and the decode succeed.
func loadJSON(path string, target any) bool {
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return false
	}
	return json.Unmarshal(raw, target) == nil
}
// writeJSON encodes value as JSON and stores it at path with mode 0o644.
// Missing parent directories are created (mode 0o755) before the value is
// encoded, so they exist even if encoding then fails. An empty path is
// rejected; all other errors are returned unchanged.
func writeJSON(path string, value any) error {
	if len(path) == 0 {
		return fmt.Errorf("empty json path")
	}
	if mkErr := os.MkdirAll(filepath.Dir(path), 0o755); mkErr != nil {
		return mkErr
	}
	encoded, encErr := json.Marshal(value)
	if encErr != nil {
		return encErr
	}
	return os.WriteFile(path, encoded, 0o644)
}
// fetchWithETag performs a GET on url and returns, in order: the response body,
// the response's ETag header, the HTTP status code, and an error.
//
// When etag is non-empty it is sent as If-None-Match so the server may answer
// 304 Not Modified. Non-2xx statuses are not treated as errors; callers must
// inspect the returned status code. If the request cannot be built or sent,
// the status code is 0. If the body cannot be read, the status code is still
// reported alongside the error.
func fetchWithETag(url, etag string) (string, string, int, error) {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return "", "", 0, err
	}
	req.Header.Set("User-Agent", "sub2api-codex")
	if etag != "" {
		req.Header.Set("If-None-Match", etag)
	}

	// http.DefaultClient has no timeout, so a stalled upstream would block the
	// caller forever. The zero Transport still resolves to
	// http.DefaultTransport, so connection pooling is unaffected.
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return "", "", 0, err
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", "", resp.StatusCode, err
	}
	return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
}
package service package service
import ( import (
"encoding/json"
"os"
"path/filepath"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
// 续链场景:保留 item_reference 与 id,但不再强制 store=true。 // 续链场景:保留 item_reference 与 id,但不再强制 store=true。
setupCodexCache(t)
reqBody := map[string]any{ reqBody := map[string]any{
"model": "gpt-5.2", "model": "gpt-5.2",
...@@ -48,7 +43,6 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { ...@@ -48,7 +43,6 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
// 续链场景:显式 store=false 不再强制为 true,保持 false。 // 续链场景:显式 store=false 不再强制为 true,保持 false。
setupCodexCache(t)
reqBody := map[string]any{ reqBody := map[string]any{
"model": "gpt-5.1", "model": "gpt-5.1",
...@@ -68,7 +62,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { ...@@ -68,7 +62,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) { func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
// 显式 store=true 也会强制为 false。 // 显式 store=true 也会强制为 false。
setupCodexCache(t)
reqBody := map[string]any{ reqBody := map[string]any{
"model": "gpt-5.1", "model": "gpt-5.1",
...@@ -88,7 +81,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) { ...@@ -88,7 +81,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) { func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
// 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。 // 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。
setupCodexCache(t)
reqBody := map[string]any{ reqBody := map[string]any{
"model": "gpt-5.1", "model": "gpt-5.1",
...@@ -130,8 +122,6 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) { ...@@ -130,8 +122,6 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
} }
func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) { func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) {
setupCodexCache(t)
reqBody := map[string]any{ reqBody := map[string]any{
"model": "gpt-5.1", "model": "gpt-5.1",
"tools": []any{ "tools": []any{
...@@ -162,7 +152,6 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction ...@@ -162,7 +152,6 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
// 空 input 应保持为空且不触发异常。 // 空 input 应保持为空且不触发异常。
setupCodexCache(t)
reqBody := map[string]any{ reqBody := map[string]any{
"model": "gpt-5.1", "model": "gpt-5.1",
...@@ -187,88 +176,27 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { ...@@ -187,88 +176,27 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
for input, expected := range cases { for input, expected := range cases {
require.Equal(t, expected, normalizeCodexModel(input)) require.Equal(t, expected, normalizeCodexModel(input))
} }
} }
func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) { func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) {
// Codex CLI 场景:已有 instructions 时保持不变 // Codex CLI 场景:已有 instructions 时不修改
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"instructions": "user custom instructions",
"input": []any{},
}
result := applyCodexOAuthTransform(reqBody, true)
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
require.Equal(t, "user custom instructions", instructions)
// instructions 未变,但其他字段(如 store、stream)可能被修改
require.True(t, result.Modified)
}
func TestApplyCodexOAuthTransform_CodexCLI_AddsInstructionsWhenEmpty(t *testing.T) {
// Codex CLI 场景:无 instructions 时补充内置指令
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"input": []any{},
}
result := applyCodexOAuthTransform(reqBody, true)
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
require.NotEmpty(t, instructions)
require.True(t, result.Modified)
}
func TestApplyCodexOAuthTransform_NonCodexCLI_UsesOpenCodeInstructions(t *testing.T) {
// 非 Codex CLI 场景:使用 opencode 指令(缓存中有 header)
setupCodexCache(t)
reqBody := map[string]any{ reqBody := map[string]any{
"model": "gpt-5.1", "model": "gpt-5.1",
"input": []any{}, "instructions": "existing instructions",
} }
result := applyCodexOAuthTransform(reqBody, false) result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true
instructions, ok := reqBody["instructions"].(string) instructions, ok := reqBody["instructions"].(string)
require.True(t, ok) require.True(t, ok)
require.Equal(t, "header", instructions) // setupCodexCache 设置的缓存内容 require.Equal(t, "existing instructions", instructions)
require.True(t, result.Modified) // Modified 仍可能为 true(因为其他字段被修改),但 instructions 应保持不变
} _ = result
func setupCodexCache(t *testing.T) {
t.Helper()
// 使用临时 HOME 避免触发网络拉取 header。
// Windows 使用 USERPROFILE,Unix 使用 HOME。
tempDir := t.TempDir()
t.Setenv("HOME", tempDir)
t.Setenv("USERPROFILE", tempDir)
cacheDir := filepath.Join(tempDir, ".opencode", "cache")
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
meta := map[string]any{
"etag": "",
"lastFetch": time.Now().UTC().Format(time.RFC3339),
"lastChecked": time.Now().UnixMilli(),
}
data, err := json.Marshal(meta)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
} }
func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) { func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) {
// Codex CLI 场景:无 instructions 时补充默认值 // Codex CLI 场景:无 instructions 时补充默认值
setupCodexCache(t)
reqBody := map[string]any{ reqBody := map[string]any{
"model": "gpt-5.1", "model": "gpt-5.1",
...@@ -284,8 +212,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T ...@@ -284,8 +212,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T
} }
func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) { func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) {
// 非 Codex CLI 场景:使用 opencode 指令覆盖 // 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖
setupCodexCache(t)
reqBody := map[string]any{ reqBody := map[string]any{
"model": "gpt-5.1", "model": "gpt-5.1",
......
...@@ -24,6 +24,8 @@ import ( ...@@ -24,6 +24,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
) )
const ( const (
...@@ -765,7 +767,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -765,7 +767,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
bodyModified := false bodyModified := false
originalModel := reqModel originalModel := reqModel
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
// 对所有请求执行模型映射(包含 Codex CLI)。 // 对所有请求执行模型映射(包含 Codex CLI)。
mappedModel := account.GetMappedModel(reqModel) mappedModel := account.GetMappedModel(reqModel)
...@@ -969,6 +971,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -969,6 +971,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
} }
if usage == nil {
usage = &OpenAIUsage{}
}
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
return &OpenAIForwardResult{ return &OpenAIForwardResult{
...@@ -1053,6 +1059,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. ...@@ -1053,6 +1059,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
req.Header.Set("user-agent", customUA) req.Header.Set("user-agent", customUA)
} }
// 若开启 ForceCodexCLI,则强制将上游 User-Agent 伪装为 Codex CLI。
// 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
req.Header.Set("user-agent", "codex_cli_rs/0.98.0")
}
// Ensure required headers exist // Ensure required headers exist
if req.Header.Get("content-type") == "" { if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json") req.Header.Set("content-type", "application/json")
...@@ -1233,7 +1245,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1233,7 +1245,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize maxLineSize = s.cfg.Gateway.MaxLineSize
} }
scanner.Buffer(make([]byte, 64*1024), maxLineSize) scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
type scanEvent struct { type scanEvent struct {
line string line string
...@@ -1252,7 +1265,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1252,7 +1265,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
} }
var lastReadAt int64 var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() { go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events) defer close(events)
for scanner.Scan() { for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
...@@ -1263,7 +1277,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1263,7 +1277,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err}) _ = sendEvent(scanEvent{err: err})
} }
}() }(scanBuf)
defer close(done) defer close(done)
streamInterval := time.Duration(0) streamInterval := time.Duration(0)
...@@ -1442,31 +1456,22 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st ...@@ -1442,31 +1456,22 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
return line return line
} }
var event map[string]any // 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化
if err := json.Unmarshal([]byte(data), &event); err != nil { if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel {
return line newData, err := sjson.Set(data, "model", toModel)
}
// Replace model in response
if m, ok := event["model"].(string); ok && m == fromModel {
event["model"] = toModel
newData, err := json.Marshal(event)
if err != nil { if err != nil {
return line return line
} }
return "data: " + string(newData) return "data: " + newData
} }
// Check nested response // 检查嵌套的 response.model 字段
if response, ok := event["response"].(map[string]any); ok { if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel {
if m, ok := response["model"].(string); ok && m == fromModel { newData, err := sjson.Set(data, "response.model", toModel)
response["model"] = toModel
newData, err := json.Marshal(event)
if err != nil { if err != nil {
return line return line
} }
return "data: " + string(newData) return "data: " + newData
}
} }
return line return line
...@@ -1686,23 +1691,15 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro ...@@ -1686,23 +1691,15 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro
} }
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any // 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化
if err := json.Unmarshal(body, &resp); err != nil { if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
return body newBody, err := sjson.SetBytes(body, "model", toModel)
}
model, ok := resp["model"].(string)
if !ok || model != fromModel {
return body
}
resp["model"] = toModel
newBody, err := json.Marshal(resp)
if err != nil { if err != nil {
return body return body
} }
return newBody return newBody
}
return body
} }
// OpenAIRecordUsageInput input for recording usage // OpenAIRecordUsageInput input for recording usage
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
) )
type stubOpenAIAccountRepo struct { type stubOpenAIAccountRepo struct {
...@@ -1082,6 +1083,43 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) { ...@@ -1082,6 +1083,43 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {
} }
} }
// TestOpenAIStreamingReuseScannerBufferAndStillWorks streams a single
// response.completed event through handleStreamingResponse and checks that
// input, output and cached token counts are still extracted.
func TestOpenAIStreamingReuseScannerBufferAndStillWorks(t *testing.T) {
	gin.SetMode(gin.TestMode)

	gateway := &OpenAIGatewayService{
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				StreamDataIntervalTimeout: 0,
				StreamKeepaliveInterval:   0,
				MaxLineSize:               defaultMaxLineSize,
			},
		},
	}

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	bodyReader, bodyWriter := io.Pipe()
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Body:       bodyReader,
		Header:     http.Header{},
	}

	// The event is written from a goroutine because io.Pipe writes block
	// until they are read.
	completedEvent := []byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"input_tokens_details\":{\"cached_tokens\":3}}}}\n\n")
	go func() {
		defer func() { _ = bodyWriter.Close() }()
		_, _ = bodyWriter.Write(completedEvent)
	}()

	result, err := gateway.handleStreamingResponse(ginCtx.Request.Context(), upstream, ginCtx, &Account{ID: 1}, time.Now(), "model", "model")
	_ = bodyReader.Close()

	require.NoError(t, err)
	require.NotNil(t, result)
	require.NotNil(t, result.usage)
	require.Equal(t, 1, result.usage.InputTokens)
	require.Equal(t, 2, result.usage.OutputTokens)
	require.Equal(t, 3, result.usage.CacheReadInputTokens)
}
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) { func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
cfg := &config.Config{ cfg := &config.Config{
...@@ -1165,3 +1203,226 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) { ...@@ -1165,3 +1203,226 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
t.Fatalf("expected non-allowlisted host to fail") t.Fatalf("expected non-allowlisted host to fail")
} }
} }
// ==================== P1-08 fix: tests for the model-replacement performance optimization ====================

// TestReplaceModelInSSELine is a table-driven check of replaceModelInSSELine.
// Each case supplies one SSE line, the model name to look for, the replacement
// name, and the exact line expected back.
func TestReplaceModelInSSELine(t *testing.T) {
	type sseLineCase struct {
		name  string
		input string
		src   string
		dst   string
		want  string
	}

	cases := []sseLineCase{
		{
			name: "顶层 model 字段替换", src: "gpt-4o", dst: "my-custom-model",
			input: `data: {"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
			want:  `data: {"id":"chatcmpl-123","model":"my-custom-model","choices":[]}`,
		},
		{
			name: "嵌套 response.model 替换", src: "gpt-4o", dst: "my-model",
			input: `data: {"type":"response","response":{"id":"resp-1","model":"gpt-4o","output":[]}}`,
			want:  `data: {"type":"response","response":{"id":"resp-1","model":"my-model","output":[]}}`,
		},
		{
			name: "model 不匹配时不替换", src: "gpt-4o", dst: "my-model",
			input: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
			want:  `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
		},
		{
			name: "无 model 字段时不替换", src: "gpt-4o", dst: "my-model",
			input: `data: {"id":"chatcmpl-123","choices":[]}`,
			want:  `data: {"id":"chatcmpl-123","choices":[]}`,
		},
		{
			// The payload after "data: " is empty (note the trailing space).
			name: "空 data 行", src: "gpt-4o", dst: "my-model",
			input: `data: `,
			want:  `data: `,
		},
		{
			name: "[DONE] 行", src: "gpt-4o", dst: "my-model",
			input: `data: [DONE]`,
			want:  `data: [DONE]`,
		},
		{
			name: "非 data: 前缀行", src: "gpt-4o", dst: "my-model",
			input: `event: message`,
			want:  `event: message`,
		},
		{
			name: "非法 JSON 不替换", src: "gpt-4o", dst: "my-model",
			input: `data: {invalid json}`,
			want:  `data: {invalid json}`,
		},
		{
			// Input has no space after "data:"; the expected output has one.
			name: "无空格 data: 格式", src: "gpt-4o", dst: "my-model",
			input: `data:{"id":"x","model":"gpt-4o"}`,
			want:  `data: {"id":"x","model":"my-model"}`,
		},
		{
			name: "model 名含特殊字符", src: "org/model-v2.1-beta", dst: "custom/alias",
			input: `data: {"model":"org/model-v2.1-beta"}`,
			want:  `data: {"model":"custom/alias"}`,
		},
		{
			name: "空行", src: "gpt-4o", dst: "my-model",
			input: "",
			want:  "",
		},
		{
			name: "保持其他字段不变", src: "gpt-4o", dst: "alias",
			input: `data: {"id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
			want:  `data: {"id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
		},
		{
			// Both a top-level and a nested model are present; only the
			// top-level one is expected to change.
			name: "顶层优先于嵌套:同时存在两个 model", src: "gpt-4o", dst: "replaced",
			input: `data: {"model":"gpt-4o","response":{"model":"gpt-4o"}}`,
			want:  `data: {"model":"replaced","response":{"model":"gpt-4o"}}`,
		},
	}

	gateway := &OpenAIGatewayService{}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.want, gateway.replaceModelInSSELine(tc.input, tc.src, tc.dst))
		})
	}
}
// TestReplaceModelInSSEBody is a table-driven check of replaceModelInSSEBody
// over whole multi-line SSE payloads: every expected body is compared exactly,
// including blank separator lines and the trailing newline.
func TestReplaceModelInSSEBody(t *testing.T) {
	type sseBodyCase struct {
		name    string
		payload string
		src     string
		dst     string
		want    string
	}

	cases := []sseBodyCase{
		{
			name: "多行 SSE body 替换", src: "gpt-4o", dst: "alias",
			payload: "data: {\"model\":\"gpt-4o\",\"choices\":[]}\n\ndata: {\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
			want:    "data: {\"model\":\"alias\",\"choices\":[]}\n\ndata: {\"model\":\"alias\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
		},
		{
			name: "无需替换的 body", src: "gpt-4o", dst: "alias",
			payload: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
			want:    "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
		},
		{
			name: "混合 event 和 data 行", src: "gpt-4o", dst: "alias",
			payload: "event: message\ndata: {\"model\":\"gpt-4o\"}\n\n",
			want:    "event: message\ndata: {\"model\":\"alias\"}\n\n",
		},
		{
			name: "空 body", src: "gpt-4o", dst: "alias",
			payload: "",
			want:    "",
		},
	}

	gateway := &OpenAIGatewayService{}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.want, gateway.replaceModelInSSEBody(tc.payload, tc.src, tc.dst))
		})
	}
}
// TestReplaceModelInResponseBody is a table-driven check of
// replaceModelInResponseBody on non-streaming JSON bodies. The method takes
// and returns bytes; the comparison is done on the string form.
func TestReplaceModelInResponseBody(t *testing.T) {
	type responseBodyCase struct {
		name    string
		payload string
		src     string
		dst     string
		want    string
	}

	cases := []responseBodyCase{
		{
			name: "替换顶层 model", src: "gpt-4o", dst: "alias",
			payload: `{"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
			want:    `{"id":"chatcmpl-123","model":"alias","choices":[]}`,
		},
		{
			name: "model 不匹配不替换", src: "gpt-4o", dst: "alias",
			payload: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
			want:    `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
		},
		{
			name: "无 model 字段不替换", src: "gpt-4o", dst: "alias",
			payload: `{"id":"chatcmpl-123","choices":[]}`,
			want:    `{"id":"chatcmpl-123","choices":[]}`,
		},
		{
			name: "非法 JSON 返回原值", src: "gpt-4o", dst: "alias",
			payload: `not json`,
			want:    `not json`,
		},
		{
			name: "空 body 返回原值", src: "gpt-4o", dst: "alias",
			payload: ``,
			want:    ``,
		},
		{
			name: "保持嵌套结构不变", src: "gpt-4o", dst: "alias",
			payload: `{"model":"gpt-4o","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
			want:    `{"model":"alias","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
		},
	}

	gateway := &OpenAIGatewayService{}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			rewritten := gateway.replaceModelInResponseBody([]byte(tc.payload), tc.src, tc.dst)
			require.Equal(t, tc.want, string(rewritten))
		})
	}
}
package service
import "sync"
// sseScannerBuf64KSize is the length in bytes of one pooled buffer (64 KiB).
const sseScannerBuf64KSize = 64 << 10

// sseScannerBuf64K is the fixed-size array type stored in the pool. The pool
// hands out pointers to it, so Get/Put move a single pointer rather than a
// slice header.
type sseScannerBuf64K [sseScannerBuf64KSize]byte

// sseScannerBuf64KPool recycles 64 KiB buffers; a fresh zeroed array is
// allocated whenever the pool has nothing to hand out.
var sseScannerBuf64KPool = sync.Pool{
	New: func() any { return &sseScannerBuf64K{} },
}

// getSSEScannerBuf64K takes a buffer from the pool. The contents are not
// cleared, so a reused buffer may still hold bytes from its previous user.
func getSSEScannerBuf64K() *sseScannerBuf64K {
	pooled := sseScannerBuf64KPool.Get()
	return pooled.(*sseScannerBuf64K)
}

// putSSEScannerBuf64K hands a buffer back to the pool. A nil pointer is
// ignored so callers can release unconditionally.
func putSSEScannerBuf64K(buf *sseScannerBuf64K) {
	if buf != nil {
		sseScannerBuf64KPool.Put(buf)
	}
}
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestSSEScannerBuf64KPool_GetPutDoesNotPanic takes a buffer from the pool,
// checks its length, writes to it, and returns it; it then confirms that
// returning a nil pointer is tolerated.
func TestSSEScannerBuf64KPool_GetPutDoesNotPanic(t *testing.T) {
	pooled := getSSEScannerBuf64K()
	require.NotNil(t, pooled)

	view := pooled[:]
	require.Equal(t, sseScannerBuf64KSize, len(view))

	pooled[0] = 1
	putSSEScannerBuf64K(pooled)

	// A nil argument is allowed and must not panic.
	putSSEScannerBuf64K(nil)
}
...@@ -4,10 +4,15 @@ import ( ...@@ -4,10 +4,15 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"math/rand/v2"
"strconv"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/dgraph-io/ristretto"
"golang.org/x/sync/singleflight"
) )
// MaxExpiresAt is the maximum allowed expiration date (year 2099) // MaxExpiresAt is the maximum allowed expiration date (year 2099)
...@@ -35,15 +40,76 @@ type SubscriptionService struct { ...@@ -35,15 +40,76 @@ type SubscriptionService struct {
groupRepo GroupRepository groupRepo GroupRepository
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
// L1 缓存:加速中间件热路径的订阅查询
subCacheL1 *ristretto.Cache
subCacheGroup singleflight.Group
subCacheTTL time.Duration
subCacheJitter int // 抖动百分比
} }
// NewSubscriptionService 创建订阅服务 // NewSubscriptionService 创建订阅服务
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService { func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, cfg *config.Config) *SubscriptionService {
return &SubscriptionService{ svc := &SubscriptionService{
groupRepo: groupRepo, groupRepo: groupRepo,
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
} }
svc.initSubCache(cfg)
return svc
}
// initSubCache sets up the in-process (L1) subscription cache from cfg.
//
// The cache stays disabled (subCacheL1 remains nil) when cfg is nil, when
// either L1Size or L1TTLSeconds is not positive, or when ristretto fails to
// construct the cache; the last case is logged as a warning.
func (s *SubscriptionService) initSubCache(cfg *config.Config) {
	if cfg == nil {
		return
	}

	settings := cfg.SubscriptionCache
	if settings.L1Size <= 0 || settings.L1TTLSeconds <= 0 {
		return
	}

	// Every entry is stored with cost 1, so MaxCost is the entry capacity;
	// the counter table is sized at ten times that.
	capacity := int64(settings.L1Size)
	l1, err := ristretto.NewCache(&ristretto.Config{
		NumCounters: capacity * 10,
		MaxCost:     capacity,
		BufferItems: 64,
	})
	if err != nil {
		log.Printf("Warning: failed to init subscription L1 cache: %v", err)
		return
	}

	s.subCacheL1 = l1
	s.subCacheTTL = time.Duration(settings.L1TTLSeconds) * time.Second
	s.subCacheJitter = settings.JitterPercent
}
// subCacheKey builds the L1 cache key "sub:<userID>:<groupID>" for a
// user/group pair. It is called on the hot path, so it formats with strconv
// instead of fmt.Sprintf.
func subCacheKey(userID, groupID int64) string {
	// 4 ("sub:") + 20 (longest int64) + 1 (':') + 20 (longest int64).
	key := make([]byte, 0, 45)
	key = append(key, "sub:"...)
	key = strconv.AppendInt(key, userID, 10)
	key = append(key, ':')
	key = strconv.AppendInt(key, groupID, 10)
	return string(key)
}
// jitteredTTL spreads ttl by a random factor so that entries written together
// do not all expire together.
//
// With a jitter percentage p (clamped to at most 100) the result is ttl scaled
// by a factor drawn uniformly from [1-p/100, 1+p/100). ttl is returned
// unchanged when it is not positive, when jitter is disabled (p <= 0), or when
// the drawn factor is not positive.
func (s *SubscriptionService) jitteredTTL(ttl time.Duration) time.Duration {
	percent := s.subCacheJitter
	if ttl <= 0 || percent <= 0 {
		return ttl
	}
	if percent > 100 {
		percent = 100
	}

	spread := float64(percent) / 100
	scale := 1 - spread + rand.Float64()*(2*spread)
	if scale > 0 {
		return time.Duration(float64(ttl) * scale)
	}
	return ttl
}
// InvalidateSubCache 失效指定用户+分组的订阅 L1 缓存
func (s *SubscriptionService) InvalidateSubCache(userID, groupID int64) {
if s.subCacheL1 == nil {
return
}
s.subCacheL1.Del(subCacheKey(userID, groupID))
} }
// AssignSubscriptionInput 分配订阅输入 // AssignSubscriptionInput 分配订阅输入
...@@ -81,6 +147,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass ...@@ -81,6 +147,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
} }
// 失效订阅缓存 // 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil { if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID userID, groupID := input.UserID, input.GroupID
go func() { go func() {
...@@ -167,6 +234,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in ...@@ -167,6 +234,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
} }
// 失效订阅缓存 // 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil { if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID userID, groupID := input.UserID, input.GroupID
go func() { go func() {
...@@ -188,6 +256,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in ...@@ -188,6 +256,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
} }
// 失效订阅缓存 // 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil { if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID userID, groupID := input.UserID, input.GroupID
go func() { go func() {
...@@ -297,6 +366,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti ...@@ -297,6 +366,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
} }
// 失效订阅缓存 // 失效订阅缓存
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil { if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID userID, groupID := sub.UserID, sub.GroupID
go func() { go func() {
...@@ -363,6 +433,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti ...@@ -363,6 +433,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
} }
// 失效订阅缓存 // 失效订阅缓存
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil { if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID userID, groupID := sub.UserID, sub.GroupID
go func() { go func() {
...@@ -381,12 +452,39 @@ func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubsc ...@@ -381,12 +452,39 @@ func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubsc
} }
// GetActiveSubscription 获取用户对特定分组的有效订阅 // GetActiveSubscription 获取用户对特定分组的有效订阅
// 使用 L1 缓存 + singleflight 加速中间件热路径。
// 返回缓存对象的浅拷贝,调用方可安全修改字段而不会污染缓存或触发 data race。
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) { func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) {
key := subCacheKey(userID, groupID)
// L1 缓存命中:返回浅拷贝
if s.subCacheL1 != nil {
if v, ok := s.subCacheL1.Get(key); ok {
if sub, ok := v.(*UserSubscription); ok {
cp := *sub
return &cp, nil
}
}
}
// singleflight 防止并发击穿
value, err, _ := s.subCacheGroup.Do(key, func() (any, error) {
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID) sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil { if err != nil {
return nil, ErrSubscriptionNotFound return nil, ErrSubscriptionNotFound
} }
// 写入 L1 缓存
if s.subCacheL1 != nil {
_ = s.subCacheL1.SetWithTTL(key, sub, 1, s.jitteredTTL(s.subCacheTTL))
}
return sub, nil return sub, nil
})
if err != nil {
return nil, err
}
// singleflight 返回的也是缓存指针,需要浅拷贝
cp := *value.(*UserSubscription)
return &cp, nil
} }
// ListUserSubscriptions 获取用户的所有订阅 // ListUserSubscriptions 获取用户的所有订阅
...@@ -521,10 +619,13 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use ...@@ -521,10 +619,13 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use
needsInvalidateCache = true needsInvalidateCache = true
} }
// 如果有窗口被重置,失效 Redis 缓存以保持一致性 // 如果有窗口被重置,失效缓存以保持一致性
if needsInvalidateCache && s.billingCacheService != nil { if needsInvalidateCache {
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil {
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID) _ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
} }
}
return nil return nil
} }
...@@ -544,6 +645,78 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSub ...@@ -544,6 +645,78 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSub
return nil return nil
} }
// ValidateAndCheckLimits combines status validation and usage-limit checks for
// the middleware hot path. It works purely on the in-memory sub value and
// performs no repository calls.
//
// When a daily/weekly/monthly window is due for reset, the matching usage
// field on sub is zeroed in memory so the limit check does not reject the
// request; persisting the reset is left to DoWindowMaintenance.
//
// needsMaintenance reports that a window needs resetting or has not been
// activated yet. err is ErrSubscriptionExpired, ErrSubscriptionSuspended, or
// one of the Err*LimitExceeded values; it is nil when the request may proceed.
func (s *SubscriptionService) ValidateAndCheckLimits(sub *UserSubscription, group *Group) (needsMaintenance bool, err error) {
	// 1. Subscription status.
	switch {
	case sub.Status == SubscriptionStatusExpired:
		return false, ErrSubscriptionExpired
	case sub.Status == SubscriptionStatusSuspended:
		return false, ErrSubscriptionSuspended
	case sub.IsExpired():
		return false, ErrSubscriptionExpired
	}

	// 2. Clear usage for windows that have rolled over (in memory only).
	if sub.NeedsDailyReset() {
		needsMaintenance = true
		sub.DailyUsageUSD = 0
	}
	if sub.NeedsWeeklyReset() {
		needsMaintenance = true
		sub.WeeklyUsageUSD = 0
	}
	if sub.NeedsMonthlyReset() {
		needsMaintenance = true
		sub.MonthlyUsageUSD = 0
	}
	if !sub.IsWindowActivated() {
		needsMaintenance = true
	}

	// 3. Usage limits, checked daily first, then weekly, then monthly.
	switch {
	case !sub.CheckDailyLimit(group, 0):
		return needsMaintenance, ErrDailyLimitExceeded
	case !sub.CheckWeeklyLimit(group, 0):
		return needsMaintenance, ErrWeeklyLimitExceeded
	case !sub.CheckMonthlyLimit(group, 0):
		return needsMaintenance, ErrMonthlyLimitExceeded
	}
	return needsMaintenance, nil
}
// DoWindowMaintenance activates and resets subscription windows, intended to
// be run asynchronously. It uses its own 5-second context derived from
// context.Background, so cancelling the originating request does not stop it.
//
// Per the original author's note, this is only meant to be called after
// ValidateAndCheckLimits returned needsMaintenance=true; since that function
// rejects expired subscriptions first, no expiry handling is done here.
// Failures are logged and not returned.
func (s *SubscriptionService) DoWindowMaintenance(sub *UserSubscription) {
	maintenanceCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Activate windows on first use.
	if !sub.IsWindowActivated() {
		activateErr := s.CheckAndActivateWindow(maintenanceCtx, sub)
		if activateErr != nil {
			log.Printf("Failed to activate subscription windows: %v", activateErr)
		}
	}

	// Reset any windows that have rolled over.
	resetErr := s.CheckAndResetWindows(maintenanceCtx, sub)
	if resetErr != nil {
		log.Printf("Failed to reset subscription windows: %v", resetErr)
	}

	// Drop the L1 entry so later requests reload the updated subscription.
	s.InvalidateSubCache(sub.UserID, sub.GroupID)
}
// RecordUsage 记录使用量到订阅 // RecordUsage 记录使用量到订阅
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error { func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD) return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD)
......
...@@ -316,8 +316,8 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star ...@@ -316,8 +316,8 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
} }
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys. // GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs) stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err) return nil, fmt.Errorf("get batch api key usage stats: %w", err)
} }
......
...@@ -3,6 +3,8 @@ package service ...@@ -3,6 +3,8 @@ package service
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
...@@ -62,13 +64,15 @@ type ChangePasswordRequest struct { ...@@ -62,13 +64,15 @@ type ChangePasswordRequest struct {
type UserService struct { type UserService struct {
userRepo UserRepository userRepo UserRepository
authCacheInvalidator APIKeyAuthCacheInvalidator authCacheInvalidator APIKeyAuthCacheInvalidator
billingCache BillingCache
} }
// NewUserService 创建用户服务实例 // NewUserService 创建用户服务实例
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService { func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService {
return &UserService{ return &UserService{
userRepo: userRepo, userRepo: userRepo,
authCacheInvalidator: authCacheInvalidator, authCacheInvalidator: authCacheInvalidator,
billingCache: billingCache,
} }
} }
...@@ -183,6 +187,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl ...@@ -183,6 +187,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
if s.authCacheInvalidator != nil { if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
} }
if s.billingCache != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
}
}()
}
return nil return nil
} }
......
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// --- mock: UserRepository ---
// mockUserRepo is a UserRepository test double. Only UpdateBalance is
// configurable; every other method returns a fixed zero-ish result.
type mockUserRepo struct {
	// updateBalanceErr is returned by UpdateBalance when updateBalanceFn is nil.
	updateBalanceErr error
	// updateBalanceFn, when set, fully replaces UpdateBalance's behaviour.
	updateBalanceFn func(ctx context.Context, id int64, amount float64) error
}

func (m *mockUserRepo) Create(context.Context, *User) error {
	return nil
}

func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) {
	return &User{}, nil
}

func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) {
	return &User{}, nil
}

func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) {
	return &User{}, nil
}

func (m *mockUserRepo) Update(context.Context, *User) error {
	return nil
}

func (m *mockUserRepo) Delete(context.Context, int64) error {
	return nil
}

func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
	return nil, nil, nil
}

func (m *mockUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
	return nil, nil, nil
}

// UpdateBalance delegates to updateBalanceFn when one is configured and
// otherwise returns updateBalanceErr.
func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
	if m.updateBalanceFn == nil {
		return m.updateBalanceErr
	}
	return m.updateBalanceFn(ctx, id, amount)
}

func (m *mockUserRepo) DeductBalance(context.Context, int64, float64) error {
	return nil
}

func (m *mockUserRepo) UpdateConcurrency(context.Context, int64, int) error {
	return nil
}

func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) {
	return false, nil
}

func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
	return 0, nil
}

func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error {
	return nil
}

func (m *mockUserRepo) EnableTotp(context.Context, int64) error {
	return nil
}

func (m *mockUserRepo) DisableTotp(context.Context, int64) error {
	return nil
}
// --- mock: APIKeyAuthCacheInvalidator ---
// mockAuthCacheInvalidator is an APIKeyAuthCacheInvalidator test double that
// records the user IDs passed to InvalidateAuthCacheByUserID. The by-key and
// by-group methods are no-ops.
type mockAuthCacheInvalidator struct {
	mu                 sync.Mutex // guards invalidatedUserIDs
	invalidatedUserIDs []int64
}

func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByKey(context.Context, string) {
}

func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByGroupID(context.Context, int64) {
}

// InvalidateAuthCacheByUserID appends userID to the recorded list under mu.
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) {
	m.mu.Lock()
	m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
	m.mu.Unlock()
}
// --- mock: BillingCache ---
// mockBillingCache is a BillingCache test double. It records calls to
// InvalidateUserBalance; every other method is a stub returning zero values.
type mockBillingCache struct {
	// invalidateErr is returned from InvalidateUserBalance.
	invalidateErr error
	// invalidateCallCount is the number of completed InvalidateUserBalance
	// recordings; tests poll it to wait for the asynchronous call.
	invalidateCallCount atomic.Int64
	// invalidatedUserIDs lists the user IDs seen so far; guarded by mu.
	invalidatedUserIDs []int64
	mu                 sync.Mutex
}

func (m *mockBillingCache) GetUserBalance(context.Context, int64) (float64, error) { return 0, nil }
func (m *mockBillingCache) SetUserBalance(context.Context, int64, float64) error   { return nil }
func (m *mockBillingCache) DeductUserBalance(context.Context, int64, float64) error {
	return nil
}

// InvalidateUserBalance records userID and returns invalidateErr.
//
// The ID is appended before the counter is incremented. Tests wait for the
// counter and then read invalidatedUserIDs; incrementing first would let a
// test observe count==1 while the slice is still empty.
func (m *mockBillingCache) InvalidateUserBalance(_ context.Context, userID int64) error {
	m.mu.Lock()
	m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
	m.mu.Unlock()
	m.invalidateCallCount.Add(1)
	return m.invalidateErr
}

func (m *mockBillingCache) GetSubscriptionCache(context.Context, int64, int64) (*SubscriptionCacheData, error) {
	return nil, nil
}

func (m *mockBillingCache) SetSubscriptionCache(context.Context, int64, int64, *SubscriptionCacheData) error {
	return nil
}

func (m *mockBillingCache) UpdateSubscriptionUsage(context.Context, int64, int64, float64) error {
	return nil
}

func (m *mockBillingCache) InvalidateSubscriptionCache(context.Context, int64, int64) error {
	return nil
}
// --- 测试 ---
func TestUpdateBalance_Success(t *testing.T) {
repo := &mockUserRepo{}
cache := &mockBillingCache{}
svc := NewUserService(repo, nil, cache)
err := svc.UpdateBalance(context.Background(), 42, 100.0)
require.NoError(t, err)
// 等待异步 goroutine 完成
require.Eventually(t, func() bool {
return cache.invalidateCallCount.Load() == 1
}, 2*time.Second, 10*time.Millisecond, "应异步调用 InvalidateUserBalance")
cache.mu.Lock()
defer cache.mu.Unlock()
require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存")
}
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
repo := &mockUserRepo{}
svc := NewUserService(repo, nil, nil) // billingCache = nil
err := svc.UpdateBalance(context.Background(), 1, 50.0)
require.NoError(t, err, "billingCache 为 nil 时不应 panic")
}
func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
repo := &mockUserRepo{}
cache := &mockBillingCache{invalidateErr: errors.New("redis connection refused")}
svc := NewUserService(repo, nil, cache)
err := svc.UpdateBalance(context.Background(), 99, 200.0)
require.NoError(t, err, "缓存失效失败不应影响主流程返回值")
// 等待异步 goroutine 完成(即使失败也应调用)
require.Eventually(t, func() bool {
return cache.invalidateCallCount.Load() == 1
}, 2*time.Second, 10*time.Millisecond, "即使失败也应调用 InvalidateUserBalance")
}
func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) {
repo := &mockUserRepo{updateBalanceErr: errors.New("database error")}
cache := &mockBillingCache{}
svc := NewUserService(repo, nil, cache)
err := svc.UpdateBalance(context.Background(), 1, 100.0)
require.Error(t, err, "repo 失败时应返回错误")
require.Contains(t, err.Error(), "update balance")
// repo 失败时不应触发缓存失效
time.Sleep(100 * time.Millisecond)
require.Equal(t, int64(0), cache.invalidateCallCount.Load(),
"repo 失败时不应调用 InvalidateUserBalance")
}
func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) {
repo := &mockUserRepo{}
auth := &mockAuthCacheInvalidator{}
cache := &mockBillingCache{}
svc := NewUserService(repo, auth, cache)
err := svc.UpdateBalance(context.Background(), 77, 300.0)
require.NoError(t, err)
// 验证 auth cache 同步失效
auth.mu.Lock()
require.Equal(t, []int64{77}, auth.invalidatedUserIDs)
auth.mu.Unlock()
// 验证 billing cache 异步失效
require.Eventually(t, func() bool {
return cache.invalidateCallCount.Load() == 1
}, 2*time.Second, 10*time.Millisecond)
}
// TestNewUserService_FieldsAssignment checks that NewUserService stores each
// of its three arguments in the matching UserService field.
func TestNewUserService_FieldsAssignment(t *testing.T) {
	var (
		userRepo  = &mockUserRepo{}
		authCache = &mockAuthCacheInvalidator{}
		billing   = &mockBillingCache{}
	)

	svc := NewUserService(userRepo, authCache, billing)

	require.NotNil(t, svc)
	require.Equal(t, userRepo, svc.userRepo)
	require.Equal(t, authCache, svc.authCacheInvalidator)
	require.Equal(t, billing, svc.billingCache)
}
...@@ -58,13 +58,67 @@ TZ=Asia/Shanghai ...@@ -58,13 +58,67 @@ TZ=Asia/Shanghai
POSTGRES_USER=sub2api POSTGRES_USER=sub2api
POSTGRES_PASSWORD=change_this_secure_password POSTGRES_PASSWORD=change_this_secure_password
POSTGRES_DB=sub2api POSTGRES_DB=sub2api
# PostgreSQL 监听端口(同时用于 PG 服务端和应用连接,默认 5432)
DATABASE_PORT=5432
# -----------------------------------------------------------------------------
# PostgreSQL 服务端参数(可选;主要用于 deploy/docker-compose-aicodex.yml)
# -----------------------------------------------------------------------------
# POSTGRES_MAX_CONNECTIONS:PostgreSQL 服务端允许的最大连接数。
# 必须 >=(所有 Sub2API 实例的 DATABASE_MAX_OPEN_CONNS 之和)+ 预留余量(例如 20%)。
POSTGRES_MAX_CONNECTIONS=1024
# POSTGRES_SHARED_BUFFERS:PostgreSQL 用于缓存数据页的共享内存。
# 常见建议:物理内存的 10%~25%(容器内存受限时请按实际限制调整)。
# 8GB 内存容器参考:1GB。
POSTGRES_SHARED_BUFFERS=1GB
# POSTGRES_EFFECTIVE_CACHE_SIZE:查询规划器“假设可用的 OS 缓存大小”(不等于实际分配)。
# 常见建议:物理内存的 50%~75%。
# 8GB 内存容器参考:4GB~6GB(本示例取 4GB)。
POSTGRES_EFFECTIVE_CACHE_SIZE=4GB
# POSTGRES_MAINTENANCE_WORK_MEM:维护操作内存(VACUUM/CREATE INDEX 等)。
# 值越大维护越快,但会占用更多内存。
# 8GB 内存容器参考:128MB。
POSTGRES_MAINTENANCE_WORK_MEM=128MB
# -----------------------------------------------------------------------------
# PostgreSQL 连接池参数(可选,默认与程序内置一致)
# -----------------------------------------------------------------------------
# 说明:
# - 这些参数控制 Sub2API 进程到 PostgreSQL 的连接池大小(不是 PostgreSQL 自身的 max_connections)。
# - 多实例/多副本部署时,总连接上限约等于:实例数 * DATABASE_MAX_OPEN_CONNS。
# - 连接池过大可能导致:数据库连接耗尽、内存占用上升、上下文切换增多,反而变慢。
# - 建议结合 PostgreSQL 的 max_connections 与机器规格逐步调优:
# 通常把应用总连接上限控制在 max_connections 的 50%~80% 更稳妥。
#
# DATABASE_MAX_OPEN_CONNS:最大打开连接数(活跃+空闲),达到后新请求会等待可用连接。
# 典型范围:50~500(取决于 DB 规格、实例数、SQL 复杂度)。
DATABASE_MAX_OPEN_CONNS=256
# DATABASE_MAX_IDLE_CONNS:最大空闲连接数(热连接),建议 <= MAX_OPEN。
# 太小会频繁建连增加延迟;太大会长期占用数据库资源。
DATABASE_MAX_IDLE_CONNS=128
# DATABASE_CONN_MAX_LIFETIME_MINUTES:单个连接最大存活时间(单位:分钟)。
# 用于避免连接长期不重建导致的中间件/LB/NAT 异常或服务端重启后的“僵尸连接”。
# 设置为 0 表示不限制(一般不建议生产环境)。
DATABASE_CONN_MAX_LIFETIME_MINUTES=30
# DATABASE_CONN_MAX_IDLE_TIME_MINUTES:空闲连接最大存活时间(单位:分钟)。
# 超过该时间的空闲连接会被回收,防止长时间闲置占用连接数。
# 设置为 0 表示不限制(一般不建议生产环境)。
DATABASE_CONN_MAX_IDLE_TIME_MINUTES=5
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Redis Configuration # Redis Configuration
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Redis 监听端口(同时用于应用连接和 Redis 服务端,默认 6379)
REDIS_PORT=6379
# Leave empty for no password (default for local development) # Leave empty for no password (default for local development)
REDIS_PASSWORD= REDIS_PASSWORD=
REDIS_DB=0 REDIS_DB=0
# Redis 服务端最大客户端连接数(可选;主要用于 deploy/docker-compose-aicodex.yml)
REDIS_MAXCLIENTS=50000
# Redis 连接池大小(默认 1024)
REDIS_POOL_SIZE=4096
# Redis 最小空闲连接数(默认 10)
REDIS_MIN_IDLE_CONNS=256
REDIS_ENABLE_TLS=false REDIS_ENABLE_TLS=false
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
...@@ -119,6 +173,19 @@ RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10 ...@@ -119,6 +173,19 @@ RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10
# Gateway Scheduling (Optional) # Gateway Scheduling (Optional)
# 调度缓存与受控回源配置(缓存就绪且命中时不读 DB) # 调度缓存与受控回源配置(缓存就绪且命中时不读 DB)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Force Codex CLI mode: treat all /openai/v1/responses requests as Codex CLI.
# 强制按 Codex CLI 处理 /openai/v1/responses 请求(用于网关未透传/改写 User-Agent 的兜底)。
#
# 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。
#
# 默认:false
GATEWAY_FORCE_CODEX_CLI=false
# 上游连接池:每主机最大连接数(默认 1024;流式/HTTP1.1 场景可调大,如 2400/4096)
GATEWAY_MAX_CONNS_PER_HOST=2048
# 上游连接池:最大空闲连接总数(默认 2560;账号/代理隔离 + 高并发场景可调大)
GATEWAY_MAX_IDLE_CONNS=8192
# 上游连接池:每主机最大空闲连接(默认 120)
GATEWAY_MAX_IDLE_CONNS_PER_HOST=4096
# 粘性会话最大排队长度 # 粘性会话最大排队长度
GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING=3 GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING=3
# 粘性会话等待超时(时间段,例如 45s) # 粘性会话等待超时(时间段,例如 45s)
......
...@@ -20,6 +20,10 @@ server: ...@@ -20,6 +20,10 @@ server:
# Mode: "debug" for development, "release" for production # Mode: "debug" for development, "release" for production
# 运行模式:"debug" 用于开发,"release" 用于生产环境 # 运行模式:"debug" 用于开发,"release" 用于生产环境
mode: "release" mode: "release"
# Frontend base URL used to generate external links in emails (e.g. password reset)
# 用于生成邮件中的外部链接(例如:重置密码链接)的前端基础地址
# Example: "https://example.com"
frontend_url: ""
# Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies. # Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies.
# 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。 # 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。
trusted_proxies: [] trusted_proxies: []
...@@ -108,9 +112,9 @@ security: ...@@ -108,9 +112,9 @@ security:
# 白名单禁用时是否允许 http:// URL(默认: false,要求 https) # 白名单禁用时是否允许 http:// URL(默认: false,要求 https)
allow_insecure_http: true allow_insecure_http: true
response_headers: response_headers:
# Enable configurable response header filtering (disable to use default allowlist) # Enable configurable response header filtering (default: true)
# 启用可配置的响应头过滤(禁用则使用默认白名单 # 启用可配置的响应头过滤(默认启用,过滤上游敏感响应头
enabled: false enabled: true
# Extra allowed response headers from upstream # Extra allowed response headers from upstream
# 额外允许的上游响应头 # 额外允许的上游响应头
additional_allowed: [] additional_allowed: []
...@@ -151,17 +155,22 @@ gateway: ...@@ -151,17 +155,22 @@ gateway:
# - account_proxy: Isolate by account+proxy combination (default, finest granularity) # - account_proxy: Isolate by account+proxy combination (default, finest granularity)
# - account_proxy: 按账户+代理组合隔离(默认,最细粒度) # - account_proxy: 按账户+代理组合隔离(默认,最细粒度)
connection_pool_isolation: "account_proxy" connection_pool_isolation: "account_proxy"
# Force Codex CLI mode: treat all /openai/v1/responses requests as Codex CLI.
# 强制按 Codex CLI 处理 /openai/v1/responses 请求(用于网关未透传/改写 User-Agent 的兜底)。
#
# 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。
force_codex_cli: false
# HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults) # HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults)
# HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值) # HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值)
# Max idle connections across all hosts # Max idle connections across all hosts
# 所有主机的最大空闲连接数 # 所有主机的最大空闲连接数
max_idle_conns: 240 max_idle_conns: 2560
# Max idle connections per host # Max idle connections per host
# 每个主机的最大空闲连接数 # 每个主机的最大空闲连接数
max_idle_conns_per_host: 120 max_idle_conns_per_host: 120
# Max connections per host # Max connections per host
# 每个主机的最大连接数 # 每个主机的最大连接数
max_conns_per_host: 240 max_conns_per_host: 1024
# Idle connection timeout (seconds) # Idle connection timeout (seconds)
# 空闲连接超时时间(秒) # 空闲连接超时时间(秒)
idle_conn_timeout_seconds: 90 idle_conn_timeout_seconds: 90
...@@ -381,9 +390,22 @@ database: ...@@ -381,9 +390,22 @@ database:
# Database name # Database name
# 数据库名称 # 数据库名称
dbname: "sub2api" dbname: "sub2api"
# SSL mode: disable, require, verify-ca, verify-full # SSL mode: disable, prefer, require, verify-ca, verify-full
# SSL 模式:disable(禁用), require(要求), verify-ca(验证CA), verify-full(完全验证) # SSL 模式:disable(禁用), prefer(优先加密,默认), require(要求), verify-ca(验证CA), verify-full(完全验证)
sslmode: "disable" # 默认值为 "prefer",数据库支持 SSL 时自动使用加密连接,不支持时回退明文
sslmode: "prefer"
# Max open connections (高并发场景建议 256+,需配合 PostgreSQL max_connections 调整)
# 最大打开连接数
max_open_conns: 256
# Max idle connections (建议为 max_open_conns 的 50%,减少频繁建连开销)
# 最大空闲连接数
max_idle_conns: 128
# Connection max lifetime (minutes)
# 连接最大存活时间(分钟)
conn_max_lifetime_minutes: 30
# Connection max idle time (minutes)
# 空闲连接最大存活时间(分钟)
conn_max_idle_time_minutes: 5
# ============================================================================= # =============================================================================
# Redis Configuration # Redis Configuration
...@@ -402,6 +424,12 @@ redis: ...@@ -402,6 +424,12 @@ redis:
# Database number (0-15) # Database number (0-15)
# 数据库编号(0-15) # 数据库编号(0-15)
db: 0 db: 0
# Connection pool size (max concurrent connections)
# 连接池大小(最大并发连接数)
pool_size: 1024
# Minimum number of idle connections (高并发场景建议 128+,保持足够热连接)
# 最小空闲连接数
min_idle_conns: 128
# Enable TLS/SSL connection # Enable TLS/SSL connection
# 是否启用 TLS/SSL 连接 # 是否启用 TLS/SSL 连接
enable_tls: false enable_tls: false
......
# =============================================================================
# Sub2API Docker Compose Host Configuration (Local Build)
# =============================================================================
# Quick Start:
# 1. Copy .env.example to .env and configure
# 2. docker-compose -f docker-compose-host.yml up -d --build
# 3. Check logs: docker-compose -f docker-compose-host.yml logs -f sub2api
# 4. Access: http://localhost:8080
#
# This configuration builds the image from source (Dockerfile in project root).
# All configuration is done via environment variables.
# No Setup Wizard needed - the system auto-initializes on first run.
# =============================================================================
services:
# ===========================================================================
# Sub2API Application
# ===========================================================================
sub2api:
#image: weishaw/sub2api:latest
image: yangjianbo/aicodex2api:latest
build:
context: ..
dockerfile: Dockerfile
container_name: sub2api
restart: unless-stopped
network_mode: host
ulimits:
nofile:
soft: 800000
hard: 800000
volumes:
# Data persistence (config.yaml will be auto-generated here)
- sub2api_data:/app/data
# Mount custom config.yaml (optional, overrides auto-generated config)
#- ./config.yaml:/app/data/config.yaml:ro
environment:
# =======================================================================
# Auto Setup (REQUIRED for Docker deployment)
# =======================================================================
- AUTO_SETUP=true
# =======================================================================
# Server Configuration
# =======================================================================
- SERVER_HOST=0.0.0.0
- SERVER_PORT=8080
- SERVER_MODE=${SERVER_MODE:-release}
- RUN_MODE=${RUN_MODE:-standard}
# =======================================================================
# Database Configuration (PostgreSQL)
# =======================================================================
# Host networking: set DATABASE_HOST/DATABASE_PORT to reach a PostgreSQL instance on the host or on an external server
- DATABASE_HOST=${DATABASE_HOST:-127.0.0.1}
- DATABASE_PORT=${DATABASE_PORT:-5432}
- DATABASE_USER=${POSTGRES_USER:-sub2api}
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
- DATABASE_SSLMODE=disable
- DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50}
- DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10}
- DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30}
- DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5}
# =======================================================================
# Gateway Configuration
# =======================================================================
- GATEWAY_FORCE_CODEX_CLI=${GATEWAY_FORCE_CODEX_CLI:-false}
- GATEWAY_MAX_IDLE_CONNS=${GATEWAY_MAX_IDLE_CONNS:-2560}
- GATEWAY_MAX_IDLE_CONNS_PER_HOST=${GATEWAY_MAX_IDLE_CONNS_PER_HOST:-120}
- GATEWAY_MAX_CONNS_PER_HOST=${GATEWAY_MAX_CONNS_PER_HOST:-8192}
# =======================================================================
# Redis Configuration
# =======================================================================
# Host networking: set REDIS_HOST/REDIS_PORT to reach a Redis instance on the host or on an external server
- REDIS_HOST=${REDIS_HOST:-127.0.0.1}
- REDIS_PORT=${REDIS_PORT:-6379}
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
- REDIS_DB=${REDIS_DB:-0}
- REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024}
- REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10}
- REDIS_ENABLE_TLS=${REDIS_ENABLE_TLS:-false}
# =======================================================================
# Admin Account (auto-created on first run)
# =======================================================================
- ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local}
- ADMIN_PASSWORD=${ADMIN_PASSWORD:-}
# =======================================================================
# JWT Configuration
# =======================================================================
# Leave empty to auto-generate (recommended)
- JWT_SECRET=${JWT_SECRET:-}
- JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24}
# =======================================================================
# TOTP (2FA) Configuration
# =======================================================================
# IMPORTANT: Set a fixed encryption key for TOTP secrets. If left empty,
# a random key will be generated on each startup, causing all existing
# TOTP configurations to become invalid (users won't be able to log in
# with 2FA).
# Generate a secure key: openssl rand -hex 32
- TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-}
# =======================================================================
# Timezone Configuration
# This affects ALL time operations in the application:
# - Database timestamps
# - Usage statistics "today" boundary
# - Subscription expiry times
# - Log timestamps
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
# =======================================================================
- TZ=${TZ:-Asia/Shanghai}
# =======================================================================
# Gemini OAuth Configuration (for Gemini accounts)
# =======================================================================
- GEMINI_OAUTH_CLIENT_ID=${GEMINI_OAUTH_CLIENT_ID:-}
- GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-}
- GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-}
- GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-}
# =======================================================================
# Security Configuration (URL Allowlist)
# =======================================================================
# Allow private IP addresses for CRS sync (for internal deployments)
- SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=${SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS:-true}
depends_on:
postgres:
condition: service_healthy
redis:
condition: service_healthy
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 30s
# ===========================================================================
# PostgreSQL Database
# ===========================================================================
postgres:
image: postgres:18-alpine
container_name: sub2api-postgres
restart: unless-stopped
network_mode: host
ulimits:
nofile:
soft: 800000
hard: 800000
volumes:
- postgres_data:/var/lib/postgresql/data
environment:
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
- TZ=${TZ:-Asia/Shanghai}
command:
- "postgres"
- "-c"
- "listen_addresses=127.0.0.1"
# Listen port: keep in sync with the application's DATABASE_PORT.
# 监听端口:与应用侧 DATABASE_PORT 保持一致。
- "-c"
- "port=${DATABASE_PORT:-5432}"
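# Connection limit: tune together with the application's DATABASE_MAX_OPEN_CONNS.
# Note: an overly large max_connections can significantly increase memory usage and context-switching overhead.
# 连接数上限:需要结合应用侧 DATABASE_MAX_OPEN_CONNS 调整。
# 注意:max_connections 过大可能导致内存占用与上下文切换开销显著上升。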
- "-c"
- "max_connections=${POSTGRES_MAX_CONNECTIONS:-1024}"
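# Typical memory parameters (tune to the machine's memory; if unsure, keep the defaults or increase in small steps).
# 典型内存参数(建议结合机器内存调优;不确定就保持默认或小步调大)。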
- "-c"
- "shared_buffers=${POSTGRES_SHARED_BUFFERS:-1GB}"
- "-c"
- "effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-6GB}"
- "-c"
- "maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-128MB}"
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api} -p ${DATABASE_PORT:-5432}"]
interval: 10s
timeout: 5s
retries: 5
start_period: 10s
# Note: PostgreSQL listens on 127.0.0.1 only (see listen_addresses above), so it is not reachable from other hosts by default.
# ===========================================================================
# Redis Cache
# ===========================================================================
redis:
image: redis:8-alpine
container_name: sub2api-redis
restart: unless-stopped
network_mode: host
ulimits:
nofile:
soft: 100000
hard: 100000
volumes:
- redis_data:/data
command: >
redis-server
--bind 127.0.0.1
--port ${REDIS_PORT:-6379}
--maxclients ${REDIS_MAXCLIENTS:-50000}
--save 60 1
--appendonly yes
--appendfsync everysec
${REDIS_PASSWORD:+--requirepass ${REDIS_PASSWORD}}
environment:
- TZ=${TZ:-Asia/Shanghai}
# REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag)
- REDISCLI_AUTH=${REDIS_PASSWORD:-}
healthcheck:
test: ["CMD-SHELL", "redis-cli -p ${REDIS_PORT:-6379} -a \"$REDISCLI_AUTH\" ping | grep -q PONG || redis-cli -p ${REDIS_PORT:-6379} ping | grep -q PONG"]
interval: 10s
timeout: 5s
retries: 5
start_period: 5s
# =============================================================================
# Volumes
# =============================================================================
volumes:
sub2api_data:
driver: local
postgres_data:
driver: local
redis_data:
driver: local
...@@ -57,6 +57,10 @@ services: ...@@ -57,6 +57,10 @@ services:
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} - DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api} - DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
- DATABASE_SSLMODE=disable - DATABASE_SSLMODE=disable
- DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50}
- DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10}
- DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30}
- DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5}
# ======================================================================= # =======================================================================
# Redis Configuration # Redis Configuration
...@@ -65,6 +69,8 @@ services: ...@@ -65,6 +69,8 @@ services:
- REDIS_PORT=6379 - REDIS_PORT=6379
- REDIS_PASSWORD=${REDIS_PASSWORD:-} - REDIS_PASSWORD=${REDIS_PASSWORD:-}
- REDIS_DB=${REDIS_DB:-0} - REDIS_DB=${REDIS_DB:-0}
- REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024}
- REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10}
# ======================================================================= # =======================================================================
# Admin Account (auto-created on first run) # Admin Account (auto-created on first run)
......
...@@ -62,6 +62,10 @@ services: ...@@ -62,6 +62,10 @@ services:
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} - DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api} - DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
- DATABASE_SSLMODE=disable - DATABASE_SSLMODE=disable
- DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50}
- DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10}
- DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30}
- DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5}
# ======================================================================= # =======================================================================
# Redis Configuration # Redis Configuration
...@@ -70,6 +74,8 @@ services: ...@@ -70,6 +74,8 @@ services:
- REDIS_PORT=6379 - REDIS_PORT=6379
- REDIS_PASSWORD=${REDIS_PASSWORD:-} - REDIS_PASSWORD=${REDIS_PASSWORD:-}
- REDIS_DB=${REDIS_DB:-0} - REDIS_DB=${REDIS_DB:-0}
- REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024}
- REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10}
- REDIS_ENABLE_TLS=${REDIS_ENABLE_TLS:-false} - REDIS_ENABLE_TLS=${REDIS_ENABLE_TLS:-false}
# ======================================================================= # =======================================================================
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment