Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a14dfb76
Commit
a14dfb76
authored
Feb 07, 2026
by
yangjianbo
Browse files
Merge branch 'dev-release'
parents
f3605ddc
2588fa6a
Changes
62
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/antigravity_gateway_service_test.go
View file @
a14dfb76
...
...
@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
...
...
@@ -391,3 +392,37 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
require
.
Equal
(
t
,
http
.
StatusServiceUnavailable
,
failoverErr
.
StatusCode
)
require
.
True
(
t
,
failoverErr
.
ForceCacheBilling
,
"ForceCacheBilling should be true for sticky session switch"
)
}
func
TestAntigravityStreamUpstreamResponse_UsageAndFirstToken
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
writer
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
writer
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{},
Body
:
pr
}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
usage
\"
:{
\"
input_tokens
\"
:1,
\"
output_tokens
\"
:2,
\"
cache_read_input_tokens
\"
:3,
\"
cache_creation_input_tokens
\"
:4}}
\n
"
))
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
usage
\"
:{
\"
output_tokens
\"
:5}}
\n
"
))
}()
svc
:=
&
AntigravityGatewayService
{}
start
:=
time
.
Now
()
.
Add
(
-
10
*
time
.
Millisecond
)
usage
,
firstTokenMs
:=
svc
.
streamUpstreamResponse
(
c
,
resp
,
start
)
_
=
pr
.
Close
()
require
.
NotNil
(
t
,
usage
)
require
.
Equal
(
t
,
1
,
usage
.
InputTokens
)
// 第二次事件覆盖 output_tokens
require
.
Equal
(
t
,
5
,
usage
.
OutputTokens
)
require
.
Equal
(
t
,
3
,
usage
.
CacheReadInputTokens
)
require
.
Equal
(
t
,
4
,
usage
.
CacheCreationInputTokens
)
if
firstTokenMs
==
nil
{
t
.
Fatalf
(
"expected firstTokenMs to be set"
)
}
// 确保有透传输出
require
.
True
(
t
,
strings
.
Contains
(
writer
.
Body
.
String
(),
"data:"
))
}
backend/internal/service/api_key_auth_cache_impl.go
View file @
a14dfb76
...
...
@@ -6,8 +6,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"math/rand"
"sync"
"math/rand/v2"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
...
...
@@ -23,12 +22,6 @@ type apiKeyAuthCacheConfig struct {
singleflight
bool
}
var
(
jitterRandMu
sync
.
Mutex
// 认证缓存抖动使用独立随机源,避免全局 Seed
jitterRand
=
rand
.
New
(
rand
.
NewSource
(
time
.
Now
()
.
UnixNano
()))
)
func
newAPIKeyAuthCacheConfig
(
cfg
*
config
.
Config
)
apiKeyAuthCacheConfig
{
if
cfg
==
nil
{
return
apiKeyAuthCacheConfig
{}
...
...
@@ -56,6 +49,8 @@ func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
return
c
.
negativeTTL
>
0
}
// jitterTTL 为缓存 TTL 添加抖动,避免多个请求在同一时刻同时过期触发集中回源。
// 这里直接使用 rand/v2 的顶层函数:并发安全,无需全局互斥锁。
func
(
c
apiKeyAuthCacheConfig
)
jitterTTL
(
ttl
time
.
Duration
)
time
.
Duration
{
if
ttl
<=
0
{
return
ttl
...
...
@@ -68,9 +63,7 @@ func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
percent
=
100
}
delta
:=
float64
(
percent
)
/
100
jitterRandMu
.
Lock
()
randVal
:=
jitterRand
.
Float64
()
jitterRandMu
.
Unlock
()
randVal
:=
rand
.
Float64
()
factor
:=
1
-
delta
+
randVal
*
(
2
*
delta
)
if
factor
<=
0
{
return
ttl
...
...
backend/internal/service/dashboard_service.go
View file @
a14dfb76
...
...
@@ -319,16 +319,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
return
trend
,
nil
}
func
(
s
*
DashboardService
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetBatchUserUsageStats
(
ctx
,
userIDs
)
func
(
s
*
DashboardService
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetBatchUserUsageStats
(
ctx
,
userIDs
,
startTime
,
endTime
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get batch user usage stats: %w"
,
err
)
}
return
stats
,
nil
}
func
(
s
*
DashboardService
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchAPIKeyUsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetBatchAPIKeyUsageStats
(
ctx
,
apiKeyIDs
)
func
(
s
*
DashboardService
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchAPIKeyUsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetBatchAPIKeyUsageStats
(
ctx
,
apiKeyIDs
,
startTime
,
endTime
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get batch api key usage stats: %w"
,
err
)
}
...
...
backend/internal/service/gateway_service.go
View file @
a14dfb76
...
...
@@ -4145,7 +4145,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
type scanEvent struct {
line string
...
...
@@ -4164,7 +4165,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go
func
()
{
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
...
...
@@ -4175,7 +4177,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(
scanBuf
)
defer close(done)
streamInterval := time.Duration(0)
...
...
@@ -4481,24 +4483,16 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
// replaceModelInResponseBody 替换响应体中的model字段
// 使用 gjson/sjson 精确替换,避免全量 JSON 反序列化
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var
resp
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
return
body
}
model
,
ok
:=
resp
[
"model"
]
.
(
string
)
if
!
ok
||
model
!=
fromModel
{
return
body
}
resp
[
"model"
]
=
toModel
newBody
,
err
:=
json
.
Marshal
(
resp
)
if
err
!=
nil
{
return
body
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
newBody, err := sjson.SetBytes(body, "model", toModel)
if err != nil {
return body
}
return newBody
}
return
newBody
return body
}
// RecordUsageInput 记录使用量的输入参数
...
...
backend/internal/service/gateway_service_streaming_test.go
0 → 100644
View file @
a14dfb76
package
service
import
(
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
0
,
MaxLineSize
:
defaultMaxLineSize
,
},
}
svc
:=
&
GatewayService
{
cfg
:
cfg
,
rateLimitService
:
&
RateLimitService
{},
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Header
:
http
.
Header
{},
Body
:
pr
}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
// Minimal SSE event to trigger parseSSEUsage
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
message_start
\"
,
\"
message
\"
:{
\"
usage
\"
:{
\"
input_tokens
\"
:3}}}
\n\n
"
))
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
message_delta
\"
,
\"
usage
\"
:{
\"
output_tokens
\"
:7}}
\n\n
"
))
_
,
_
=
pw
.
Write
([]
byte
(
"data: [DONE]
\n\n
"
))
}()
result
,
err
:=
svc
.
handleStreamingResponse
(
context
.
Background
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"model"
,
"model"
,
nil
,
false
)
_
=
pr
.
Close
()
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
usage
)
require
.
Equal
(
t
,
3
,
result
.
usage
.
InputTokens
)
require
.
Equal
(
t
,
7
,
result
.
usage
.
OutputTokens
)
}
backend/internal/service/openai_codex_transform.go
View file @
a14dfb76
...
...
@@ -2,19 +2,7 @@ package service
import
(
_
"embed"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
const
(
opencodeCodexHeaderURL
=
"https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
codexCacheTTL
=
15
*
time
.
Minute
)
//go:embed prompts/codex_cli_instructions.md
...
...
@@ -77,12 +65,6 @@ type codexTransformResult struct {
PromptCacheKey
string
}
type
opencodeCacheMetadata
struct
{
ETag
string
`json:"etag"`
LastFetch
string
`json:"lastFetch,omitempty"`
LastChecked
int64
`json:"lastChecked"`
}
func
applyCodexOAuthTransform
(
reqBody
map
[
string
]
any
,
isCodexCLI
bool
)
codexTransformResult
{
result
:=
codexTransformResult
{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。
...
...
@@ -216,54 +198,9 @@ func getNormalizedCodexModel(modelID string) string {
return
""
}
func
getOpenCodeCachedPrompt
(
url
,
cacheFileName
,
metaFileName
string
)
string
{
cacheDir
:=
codexCachePath
(
""
)
if
cacheDir
==
""
{
return
""
}
cacheFile
:=
filepath
.
Join
(
cacheDir
,
cacheFileName
)
metaFile
:=
filepath
.
Join
(
cacheDir
,
metaFileName
)
var
cachedContent
string
if
content
,
ok
:=
readFile
(
cacheFile
);
ok
{
cachedContent
=
content
}
var
meta
opencodeCacheMetadata
if
loadJSON
(
metaFile
,
&
meta
)
&&
meta
.
LastChecked
>
0
&&
cachedContent
!=
""
{
if
time
.
Since
(
time
.
UnixMilli
(
meta
.
LastChecked
))
<
codexCacheTTL
{
return
cachedContent
}
}
content
,
etag
,
status
,
err
:=
fetchWithETag
(
url
,
meta
.
ETag
)
if
err
==
nil
&&
status
==
http
.
StatusNotModified
&&
cachedContent
!=
""
{
return
cachedContent
}
if
err
==
nil
&&
status
>=
200
&&
status
<
300
&&
content
!=
""
{
_
=
writeFile
(
cacheFile
,
content
)
meta
=
opencodeCacheMetadata
{
ETag
:
etag
,
LastFetch
:
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
),
LastChecked
:
time
.
Now
()
.
UnixMilli
(),
}
_
=
writeJSON
(
metaFile
,
meta
)
return
content
}
return
cachedContent
}
func
getOpenCodeCodexHeader
()
string
{
// 优先从 opencode 仓库缓存获取指令。
opencodeInstructions
:=
getOpenCodeCachedPrompt
(
opencodeCodexHeaderURL
,
"opencode-codex-header.txt"
,
"opencode-codex-header-meta.json"
)
// 若 opencode 指令可用,直接返回。
if
opencodeInstructions
!=
""
{
return
opencodeInstructions
}
// 否则回退使用本地 Codex CLI 指令。
// 兼容保留:历史上这里会从 opencode 仓库拉取 codex_header.txt。
// 现在我们与 Codex CLI 一致,直接使用仓库内置的 instructions,避免读写缓存与外网依赖。
return
getCodexCLIInstructions
()
}
...
...
@@ -281,8 +218,8 @@ func GetCodexCLIInstructions() string {
}
// applyInstructions 处理 instructions 字段
// isCodexCLI=true: 仅补充缺失的 instructions(使用
opencode
指令)
// isCodexCLI=false: 优先使用
opencode
指令覆盖
// isCodexCLI=true: 仅补充缺失的 instructions(使用
内置 Codex CLI
指令)
// isCodexCLI=false: 优先使用
内置 Codex CLI
指令覆盖
func
applyInstructions
(
reqBody
map
[
string
]
any
,
isCodexCLI
bool
)
bool
{
if
isCodexCLI
{
return
applyCodexCLIInstructions
(
reqBody
)
...
...
@@ -291,13 +228,13 @@ func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
}
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
// 仅在 instructions 为空时添加 opencode
指令
// 仅在 instructions 为空时添加
内置 Codex CLI 指令(不依赖
opencode
缓存/回源)
func
applyCodexCLIInstructions
(
reqBody
map
[
string
]
any
)
bool
{
if
!
isInstructionsEmpty
(
reqBody
)
{
return
false
// 已有有效 instructions,不修改
}
instructions
:=
strings
.
TrimSpace
(
get
OpenCodeCodexHeader
())
instructions
:=
strings
.
TrimSpace
(
get
CodexCLIInstructions
())
if
instructions
!=
""
{
reqBody
[
"instructions"
]
=
instructions
return
true
...
...
@@ -306,8 +243,8 @@ func applyCodexCLIInstructions(reqBody map[string]any) bool {
return
false
}
// applyOpenCodeInstructions 为非 Codex CLI 请求应用
opencode 指令
// 优先使用
opencode
指令覆盖
// applyOpenCodeInstructions 为非 Codex CLI 请求应用
内置 Codex CLI 指令(兼容历史函数名)
// 优先使用
内置 Codex CLI
指令覆盖
func
applyOpenCodeInstructions
(
reqBody
map
[
string
]
any
)
bool
{
instructions
:=
strings
.
TrimSpace
(
getOpenCodeCodexHeader
())
existingInstructions
,
_
:=
reqBody
[
"instructions"
]
.
(
string
)
...
...
@@ -489,85 +426,3 @@ func normalizeCodexTools(reqBody map[string]any) bool {
return
modified
}
func
codexCachePath
(
filename
string
)
string
{
home
,
err
:=
os
.
UserHomeDir
()
if
err
!=
nil
{
return
""
}
cacheDir
:=
filepath
.
Join
(
home
,
".opencode"
,
"cache"
)
if
filename
==
""
{
return
cacheDir
}
return
filepath
.
Join
(
cacheDir
,
filename
)
}
func
readFile
(
path
string
)
(
string
,
bool
)
{
if
path
==
""
{
return
""
,
false
}
data
,
err
:=
os
.
ReadFile
(
path
)
if
err
!=
nil
{
return
""
,
false
}
return
string
(
data
),
true
}
func
writeFile
(
path
,
content
string
)
error
{
if
path
==
""
{
return
fmt
.
Errorf
(
"empty cache path"
)
}
if
err
:=
os
.
MkdirAll
(
filepath
.
Dir
(
path
),
0
o755
);
err
!=
nil
{
return
err
}
return
os
.
WriteFile
(
path
,
[]
byte
(
content
),
0
o644
)
}
func
loadJSON
(
path
string
,
target
any
)
bool
{
data
,
err
:=
os
.
ReadFile
(
path
)
if
err
!=
nil
{
return
false
}
if
err
:=
json
.
Unmarshal
(
data
,
target
);
err
!=
nil
{
return
false
}
return
true
}
func
writeJSON
(
path
string
,
value
any
)
error
{
if
path
==
""
{
return
fmt
.
Errorf
(
"empty json path"
)
}
if
err
:=
os
.
MkdirAll
(
filepath
.
Dir
(
path
),
0
o755
);
err
!=
nil
{
return
err
}
data
,
err
:=
json
.
Marshal
(
value
)
if
err
!=
nil
{
return
err
}
return
os
.
WriteFile
(
path
,
data
,
0
o644
)
}
func
fetchWithETag
(
url
,
etag
string
)
(
string
,
string
,
int
,
error
)
{
req
,
err
:=
http
.
NewRequest
(
http
.
MethodGet
,
url
,
nil
)
if
err
!=
nil
{
return
""
,
""
,
0
,
err
}
req
.
Header
.
Set
(
"User-Agent"
,
"sub2api-codex"
)
if
etag
!=
""
{
req
.
Header
.
Set
(
"If-None-Match"
,
etag
)
}
resp
,
err
:=
http
.
DefaultClient
.
Do
(
req
)
if
err
!=
nil
{
return
""
,
""
,
0
,
err
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
return
""
,
""
,
resp
.
StatusCode
,
err
}
return
string
(
body
),
resp
.
Header
.
Get
(
"etag"
),
resp
.
StatusCode
,
nil
}
backend/internal/service/openai_codex_transform_test.go
View file @
a14dfb76
package
service
import
(
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func
TestApplyCodexOAuthTransform_ToolContinuationPreservesInput
(
t
*
testing
.
T
)
{
// 续链场景:保留 item_reference 与 id,但不再强制 store=true。
setupCodexCache
(
t
)
reqBody
:=
map
[
string
]
any
{
"model"
:
"gpt-5.2"
,
...
...
@@ -48,7 +43,6 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
func
TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved
(
t
*
testing
.
T
)
{
// 续链场景:显式 store=false 不再强制为 true,保持 false。
setupCodexCache
(
t
)
reqBody
:=
map
[
string
]
any
{
"model"
:
"gpt-5.1"
,
...
...
@@ -68,7 +62,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
func
TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse
(
t
*
testing
.
T
)
{
// 显式 store=true 也会强制为 false。
setupCodexCache
(
t
)
reqBody
:=
map
[
string
]
any
{
"model"
:
"gpt-5.1"
,
...
...
@@ -88,7 +81,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
func
TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs
(
t
*
testing
.
T
)
{
// 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。
setupCodexCache
(
t
)
reqBody
:=
map
[
string
]
any
{
"model"
:
"gpt-5.1"
,
...
...
@@ -130,8 +122,6 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
}
func
TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools
(
t
*
testing
.
T
)
{
setupCodexCache
(
t
)
reqBody
:=
map
[
string
]
any
{
"model"
:
"gpt-5.1"
,
"tools"
:
[]
any
{
...
...
@@ -162,7 +152,6 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
func
TestApplyCodexOAuthTransform_EmptyInput
(
t
*
testing
.
T
)
{
// 空 input 应保持为空且不触发异常。
setupCodexCache
(
t
)
reqBody
:=
map
[
string
]
any
{
"model"
:
"gpt-5.1"
,
...
...
@@ -187,88 +176,27 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
for
input
,
expected
:=
range
cases
{
require
.
Equal
(
t
,
expected
,
normalizeCodexModel
(
input
))
}
}
func
TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions
(
t
*
testing
.
T
)
{
// Codex CLI 场景:已有 instructions 时保持不变
setupCodexCache
(
t
)
// Codex CLI 场景:已有 instructions 时不修改
reqBody
:=
map
[
string
]
any
{
"model"
:
"gpt-5.1"
,
"instructions"
:
"user custom instructions"
,
"input"
:
[]
any
{},
}
result
:=
applyCodexOAuthTransform
(
reqBody
,
true
)
instructions
,
ok
:=
reqBody
[
"instructions"
]
.
(
string
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"user custom instructions"
,
instructions
)
// instructions 未变,但其他字段(如 store、stream)可能被修改
require
.
True
(
t
,
result
.
Modified
)
}
func
TestApplyCodexOAuthTransform_CodexCLI_AddsInstructionsWhenEmpty
(
t
*
testing
.
T
)
{
// Codex CLI 场景:无 instructions 时补充内置指令
setupCodexCache
(
t
)
reqBody
:=
map
[
string
]
any
{
"model"
:
"gpt-5.1"
,
"input"
:
[]
any
{},
"instructions"
:
"existing instructions"
,
}
result
:=
applyCodexOAuthTransform
(
reqBody
,
true
)
instructions
,
ok
:=
reqBody
[
"instructions"
]
.
(
string
)
require
.
True
(
t
,
ok
)
require
.
NotEmpty
(
t
,
instructions
)
require
.
True
(
t
,
result
.
Modified
)
}
func
TestApplyCodexOAuthTransform_NonCodexCLI_UsesOpenCodeInstructions
(
t
*
testing
.
T
)
{
// 非 Codex CLI 场景:使用 opencode 指令(缓存中有 header)
setupCodexCache
(
t
)
reqBody
:=
map
[
string
]
any
{
"model"
:
"gpt-5.1"
,
"input"
:
[]
any
{},
}
result
:=
applyCodexOAuthTransform
(
reqBody
,
false
)
result
:=
applyCodexOAuthTransform
(
reqBody
,
true
)
// isCodexCLI=true
instructions
,
ok
:=
reqBody
[
"instructions"
]
.
(
string
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"header"
,
instructions
)
// setupCodexCache 设置的缓存内容
require
.
True
(
t
,
result
.
Modified
)
}
func
setupCodexCache
(
t
*
testing
.
T
)
{
t
.
Helper
()
// 使用临时 HOME 避免触发网络拉取 header。
// Windows 使用 USERPROFILE,Unix 使用 HOME。
tempDir
:=
t
.
TempDir
()
t
.
Setenv
(
"HOME"
,
tempDir
)
t
.
Setenv
(
"USERPROFILE"
,
tempDir
)
cacheDir
:=
filepath
.
Join
(
tempDir
,
".opencode"
,
"cache"
)
require
.
NoError
(
t
,
os
.
MkdirAll
(
cacheDir
,
0
o755
))
require
.
NoError
(
t
,
os
.
WriteFile
(
filepath
.
Join
(
cacheDir
,
"opencode-codex-header.txt"
),
[]
byte
(
"header"
),
0
o644
))
meta
:=
map
[
string
]
any
{
"etag"
:
""
,
"lastFetch"
:
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
),
"lastChecked"
:
time
.
Now
()
.
UnixMilli
(),
}
data
,
err
:=
json
.
Marshal
(
meta
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
os
.
WriteFile
(
filepath
.
Join
(
cacheDir
,
"opencode-codex-header-meta.json"
),
data
,
0
o644
))
require
.
Equal
(
t
,
"existing instructions"
,
instructions
)
// Modified 仍可能为 true(因为其他字段被修改),但 instructions 应保持不变
_
=
result
}
func
TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty
(
t
*
testing
.
T
)
{
// Codex CLI 场景:无 instructions 时补充默认值
setupCodexCache
(
t
)
reqBody
:=
map
[
string
]
any
{
"model"
:
"gpt-5.1"
,
...
...
@@ -284,8 +212,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T
}
func
TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions
(
t
*
testing
.
T
)
{
// 非 Codex CLI 场景:使用 opencode 指令覆盖
setupCodexCache
(
t
)
// 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖
reqBody
:=
map
[
string
]
any
{
"model"
:
"gpt-5.1"
,
...
...
backend/internal/service/openai_gateway_service.go
View file @
a14dfb76
...
...
@@ -24,6 +24,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const
(
...
...
@@ -765,7 +767,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
bodyModified
:=
false
originalModel
:=
reqModel
isCodexCLI
:=
openai
.
IsCodexCLIRequest
(
c
.
GetHeader
(
"User-Agent"
))
isCodexCLI
:=
openai
.
IsCodexCLIRequest
(
c
.
GetHeader
(
"User-Agent"
))
||
(
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
ForceCodexCLI
)
// 对所有请求执行模型映射(包含 Codex CLI)。
mappedModel
:=
account
.
GetMappedModel
(
reqModel
)
...
...
@@ -969,6 +971,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
}
if
usage
==
nil
{
usage
=
&
OpenAIUsage
{}
}
reasoningEffort
:=
extractOpenAIReasoningEffort
(
reqBody
,
originalModel
)
return
&
OpenAIForwardResult
{
...
...
@@ -1053,6 +1059,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
req
.
Header
.
Set
(
"user-agent"
,
customUA
)
}
// 若开启 ForceCodexCLI,则强制将上游 User-Agent 伪装为 Codex CLI。
// 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
ForceCodexCLI
{
req
.
Header
.
Set
(
"user-agent"
,
"codex_cli_rs/0.98.0"
)
}
// Ensure required headers exist
if
req
.
Header
.
Get
(
"content-type"
)
==
""
{
req
.
Header
.
Set
(
"content-type"
,
"application/json"
)
...
...
@@ -1233,7 +1245,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
scanBuf
:=
getSSEScannerBuf64K
()
scanner
.
Buffer
(
scanBuf
[
:
0
],
maxLineSize
)
type
scanEvent
struct
{
line
string
...
...
@@ -1252,7 +1265,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
go
func
(
scanBuf
*
sseScannerBuf64K
)
{
defer
putSSEScannerBuf64K
(
scanBuf
)
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
...
...
@@ -1263,7 +1277,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
}(
scanBuf
)
defer
close
(
done
)
streamInterval
:=
time
.
Duration
(
0
)
...
...
@@ -1442,31 +1456,22 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
return
line
}
var
event
map
[
string
]
any
if
err
:=
json
.
Unmarshal
([]
byte
(
data
),
&
event
);
err
!=
nil
{
return
line
}
// Replace model in response
if
m
,
ok
:=
event
[
"model"
]
.
(
string
);
ok
&&
m
==
fromModel
{
event
[
"model"
]
=
toModel
newData
,
err
:=
json
.
Marshal
(
event
)
// 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化
if
m
:=
gjson
.
Get
(
data
,
"model"
);
m
.
Exists
()
&&
m
.
Str
==
fromModel
{
newData
,
err
:=
sjson
.
Set
(
data
,
"model"
,
toModel
)
if
err
!=
nil
{
return
line
}
return
"data: "
+
string
(
newData
)
return
"data: "
+
newData
}
// Check nested response
if
response
,
ok
:=
event
[
"response"
]
.
(
map
[
string
]
any
);
ok
{
if
m
,
ok
:=
response
[
"model"
]
.
(
string
);
ok
&&
m
==
fromModel
{
response
[
"model"
]
=
toModel
newData
,
err
:=
json
.
Marshal
(
event
)
if
err
!=
nil
{
return
line
}
return
"data: "
+
string
(
newData
)
// 检查嵌套的 response.model 字段
if
m
:=
gjson
.
Get
(
data
,
"response.model"
);
m
.
Exists
()
&&
m
.
Str
==
fromModel
{
newData
,
err
:=
sjson
.
Set
(
data
,
"response.model"
,
toModel
)
if
err
!=
nil
{
return
line
}
return
"data: "
+
newData
}
return
line
...
...
@@ -1686,23 +1691,15 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro
}
func
(
s
*
OpenAIGatewayService
)
replaceModelInResponseBody
(
body
[]
byte
,
fromModel
,
toModel
string
)
[]
byte
{
var
resp
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
resp
);
err
!=
nil
{
return
body
}
model
,
ok
:=
resp
[
"model"
]
.
(
string
)
if
!
ok
||
model
!=
fromModel
{
return
body
}
resp
[
"model"
]
=
toModel
newBody
,
err
:=
json
.
Marshal
(
resp
)
if
err
!=
nil
{
return
body
// 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化
if
m
:=
gjson
.
GetBytes
(
body
,
"model"
);
m
.
Exists
()
&&
m
.
Str
==
fromModel
{
newBody
,
err
:=
sjson
.
SetBytes
(
body
,
"model"
,
toModel
)
if
err
!=
nil
{
return
body
}
return
newBody
}
return
newBody
return
body
}
// OpenAIRecordUsageInput input for recording usage
...
...
backend/internal/service/openai_gateway_service_test.go
View file @
a14dfb76
...
...
@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type
stubOpenAIAccountRepo
struct
{
...
...
@@ -1082,6 +1083,43 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {
}
}
func
TestOpenAIStreamingReuseScannerBufferAndStillWorks
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
0
,
StreamKeepaliveInterval
:
0
,
MaxLineSize
:
defaultMaxLineSize
,
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{},
}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
response.completed
\"
,
\"
response
\"
:{
\"
usage
\"
:{
\"
input_tokens
\"
:1,
\"
output_tokens
\"
:2,
\"
input_tokens_details
\"
:{
\"
cached_tokens
\"
:3}}}}
\n\n
"
))
}()
result
,
err
:=
svc
.
handleStreamingResponse
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"model"
,
"model"
)
_
=
pr
.
Close
()
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
usage
)
require
.
Equal
(
t
,
1
,
result
.
usage
.
InputTokens
)
require
.
Equal
(
t
,
2
,
result
.
usage
.
OutputTokens
)
require
.
Equal
(
t
,
3
,
result
.
usage
.
CacheReadInputTokens
)
}
func
TestOpenAIInvalidBaseURLWhenAllowlistDisabled
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
...
...
@@ -1165,3 +1203,226 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
t
.
Fatalf
(
"expected non-allowlisted host to fail"
)
}
}
// ==================== P1-08 修复:model 替换性能优化测试 ====================
func
TestReplaceModelInSSELine
(
t
*
testing
.
T
)
{
svc
:=
&
OpenAIGatewayService
{}
tests
:=
[]
struct
{
name
string
line
string
from
string
to
string
expected
string
}{
{
name
:
"顶层 model 字段替换"
,
line
:
`data: {"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`
,
from
:
"gpt-4o"
,
to
:
"my-custom-model"
,
expected
:
`data: {"id":"chatcmpl-123","model":"my-custom-model","choices":[]}`
,
},
{
name
:
"嵌套 response.model 替换"
,
line
:
`data: {"type":"response","response":{"id":"resp-1","model":"gpt-4o","output":[]}}`
,
from
:
"gpt-4o"
,
to
:
"my-model"
,
expected
:
`data: {"type":"response","response":{"id":"resp-1","model":"my-model","output":[]}}`
,
},
{
name
:
"model 不匹配时不替换"
,
line
:
`data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`
,
from
:
"gpt-4o"
,
to
:
"my-model"
,
expected
:
`data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`
,
},
{
name
:
"无 model 字段时不替换"
,
line
:
`data: {"id":"chatcmpl-123","choices":[]}`
,
from
:
"gpt-4o"
,
to
:
"my-model"
,
expected
:
`data: {"id":"chatcmpl-123","choices":[]}`
,
},
{
name
:
"空 data 行"
,
line
:
`data: `
,
from
:
"gpt-4o"
,
to
:
"my-model"
,
expected
:
`data: `
,
},
{
name
:
"[DONE] 行"
,
line
:
`data: [DONE]`
,
from
:
"gpt-4o"
,
to
:
"my-model"
,
expected
:
`data: [DONE]`
,
},
{
name
:
"非 data: 前缀行"
,
line
:
`event: message`
,
from
:
"gpt-4o"
,
to
:
"my-model"
,
expected
:
`event: message`
,
},
{
name
:
"非法 JSON 不替换"
,
line
:
`data: {invalid json}`
,
from
:
"gpt-4o"
,
to
:
"my-model"
,
expected
:
`data: {invalid json}`
,
},
{
name
:
"无空格 data: 格式"
,
line
:
`data:{"id":"x","model":"gpt-4o"}`
,
from
:
"gpt-4o"
,
to
:
"my-model"
,
expected
:
`data: {"id":"x","model":"my-model"}`
,
},
{
name
:
"model 名含特殊字符"
,
line
:
`data: {"model":"org/model-v2.1-beta"}`
,
from
:
"org/model-v2.1-beta"
,
to
:
"custom/alias"
,
expected
:
`data: {"model":"custom/alias"}`
,
},
{
name
:
"空行"
,
line
:
""
,
from
:
"gpt-4o"
,
to
:
"my-model"
,
expected
:
""
,
},
{
name
:
"保持其他字段不变"
,
line
:
`data: {"id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`
,
from
:
"gpt-4o"
,
to
:
"alias"
,
expected
:
`data: {"id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`
,
},
{
name
:
"顶层优先于嵌套:同时存在两个 model"
,
line
:
`data: {"model":"gpt-4o","response":{"model":"gpt-4o"}}`
,
from
:
"gpt-4o"
,
to
:
"replaced"
,
expected
:
`data: {"model":"replaced","response":{"model":"gpt-4o"}}`
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
svc
.
replaceModelInSSELine
(
tt
.
line
,
tt
.
from
,
tt
.
to
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
func
TestReplaceModelInSSEBody
(
t
*
testing
.
T
)
{
svc
:=
&
OpenAIGatewayService
{}
tests
:=
[]
struct
{
name
string
body
string
from
string
to
string
expected
string
}{
{
name
:
"多行 SSE body 替换"
,
body
:
"data: {
\"
model
\"
:
\"
gpt-4o
\"
,
\"
choices
\"
:[]}
\n\n
data: {
\"
model
\"
:
\"
gpt-4o
\"
,
\"
choices
\"
:[{
\"
delta
\"
:{
\"
content
\"
:
\"
hi
\"
}}]}
\n\n
data: [DONE]
\n
"
,
from
:
"gpt-4o"
,
to
:
"alias"
,
expected
:
"data: {
\"
model
\"
:
\"
alias
\"
,
\"
choices
\"
:[]}
\n\n
data: {
\"
model
\"
:
\"
alias
\"
,
\"
choices
\"
:[{
\"
delta
\"
:{
\"
content
\"
:
\"
hi
\"
}}]}
\n\n
data: [DONE]
\n
"
,
},
{
name
:
"无需替换的 body"
,
body
:
"data: {
\"
model
\"
:
\"
gpt-3.5-turbo
\"
}
\n\n
data: [DONE]
\n
"
,
from
:
"gpt-4o"
,
to
:
"alias"
,
expected
:
"data: {
\"
model
\"
:
\"
gpt-3.5-turbo
\"
}
\n\n
data: [DONE]
\n
"
,
},
{
name
:
"混合 event 和 data 行"
,
body
:
"event: message
\n
data: {
\"
model
\"
:
\"
gpt-4o
\"
}
\n\n
"
,
from
:
"gpt-4o"
,
to
:
"alias"
,
expected
:
"event: message
\n
data: {
\"
model
\"
:
\"
alias
\"
}
\n\n
"
,
},
{
name
:
"空 body"
,
body
:
""
,
from
:
"gpt-4o"
,
to
:
"alias"
,
expected
:
""
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
svc
.
replaceModelInSSEBody
(
tt
.
body
,
tt
.
from
,
tt
.
to
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
func
TestReplaceModelInResponseBody
(
t
*
testing
.
T
)
{
svc
:=
&
OpenAIGatewayService
{}
tests
:=
[]
struct
{
name
string
body
string
from
string
to
string
expected
string
}{
{
name
:
"替换顶层 model"
,
body
:
`{"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`
,
from
:
"gpt-4o"
,
to
:
"alias"
,
expected
:
`{"id":"chatcmpl-123","model":"alias","choices":[]}`
,
},
{
name
:
"model 不匹配不替换"
,
body
:
`{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`
,
from
:
"gpt-4o"
,
to
:
"alias"
,
expected
:
`{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`
,
},
{
name
:
"无 model 字段不替换"
,
body
:
`{"id":"chatcmpl-123","choices":[]}`
,
from
:
"gpt-4o"
,
to
:
"alias"
,
expected
:
`{"id":"chatcmpl-123","choices":[]}`
,
},
{
name
:
"非法 JSON 返回原值"
,
body
:
`not json`
,
from
:
"gpt-4o"
,
to
:
"alias"
,
expected
:
`not json`
,
},
{
name
:
"空 body 返回原值"
,
body
:
``
,
from
:
"gpt-4o"
,
to
:
"alias"
,
expected
:
``
,
},
{
name
:
"保持嵌套结构不变"
,
body
:
`{"model":"gpt-4o","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`
,
from
:
"gpt-4o"
,
to
:
"alias"
,
expected
:
`{"model":"alias","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
svc
.
replaceModelInResponseBody
([]
byte
(
tt
.
body
),
tt
.
from
,
tt
.
to
)
require
.
Equal
(
t
,
tt
.
expected
,
string
(
got
))
})
}
}
backend/internal/service/sse_scanner_buffer_pool.go
0 → 100644
View file @
a14dfb76
package
service
import
"sync"
const
sseScannerBuf64KSize
=
64
*
1024
type
sseScannerBuf64K
[
sseScannerBuf64KSize
]
byte
var
sseScannerBuf64KPool
=
sync
.
Pool
{
New
:
func
()
any
{
return
new
(
sseScannerBuf64K
)
},
}
func
getSSEScannerBuf64K
()
*
sseScannerBuf64K
{
return
sseScannerBuf64KPool
.
Get
()
.
(
*
sseScannerBuf64K
)
}
func
putSSEScannerBuf64K
(
buf
*
sseScannerBuf64K
)
{
if
buf
==
nil
{
return
}
sseScannerBuf64KPool
.
Put
(
buf
)
}
backend/internal/service/sse_scanner_buffer_pool_test.go
0 → 100644
View file @
a14dfb76
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
// TestSSEScannerBuf64KPool_GetPutDoesNotPanic 验证缓冲池的基本取/还行为:
// 取出的缓冲区非 nil 且长度为 64KB;写入后归还、以及归还 nil 均不应 panic。
func TestSSEScannerBuf64KPool_GetPutDoesNotPanic(t *testing.T) {
	buf := getSSEScannerBuf64K()
	require.NotNil(t, buf)
	require.Equal(t, sseScannerBuf64KSize, len(buf[:]))
	buf[0] = 1
	putSSEScannerBuf64K(buf)
	// 允许传入 nil,确保不会 panic
	putSSEScannerBuf64K(nil)
}
backend/internal/service/subscription_service.go
View file @
a14dfb76
...
...
@@ -4,10 +4,15 @@ import (
"context"
"fmt"
"log"
"math/rand/v2"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/dgraph-io/ristretto"
"golang.org/x/sync/singleflight"
)
// MaxExpiresAt is the maximum allowed expiration date (year 2099)
...
...
@@ -35,15 +40,76 @@ type SubscriptionService struct {
groupRepo
GroupRepository
userSubRepo
UserSubscriptionRepository
billingCacheService
*
BillingCacheService
// L1 缓存:加速中间件热路径的订阅查询
subCacheL1
*
ristretto
.
Cache
subCacheGroup
singleflight
.
Group
subCacheTTL
time
.
Duration
subCacheJitter
int
// 抖动百分比
}
// NewSubscriptionService 创建订阅服务
func
NewSubscriptionService
(
groupRepo
GroupRepository
,
userSubRepo
UserSubscriptionRepository
,
billingCacheService
*
BillingCacheService
)
*
SubscriptionService
{
return
&
SubscriptionService
{
func
NewSubscriptionService
(
groupRepo
GroupRepository
,
userSubRepo
UserSubscriptionRepository
,
billingCacheService
*
BillingCacheService
,
cfg
*
config
.
Config
)
*
SubscriptionService
{
svc
:=
&
SubscriptionService
{
groupRepo
:
groupRepo
,
userSubRepo
:
userSubRepo
,
billingCacheService
:
billingCacheService
,
}
svc
.
initSubCache
(
cfg
)
return
svc
}
// initSubCache 初始化订阅 L1 缓存。
// cfg 为 nil,或缓存容量/TTL 未配置(<=0)时直接跳过,服务退化为每次回源。
// 初始化失败仅记录告警并降级(不启用 L1),不影响服务启动。
func (s *SubscriptionService) initSubCache(cfg *config.Config) {
	if cfg == nil {
		return
	}
	sc := cfg.SubscriptionCache
	if sc.L1Size <= 0 || sc.L1TTLSeconds <= 0 {
		return
	}
	cache, err := ristretto.NewCache(&ristretto.Config{
		// NumCounters 取最大条目数的 10 倍(ristretto 官方推荐比例)
		NumCounters: int64(sc.L1Size) * 10,
		// 每条缓存 cost 记 1(见 SetWithTTL 调用),故 MaxCost 即最大条目数
		MaxCost:     int64(sc.L1Size),
		BufferItems: 64,
	})
	if err != nil {
		log.Printf("Warning: failed to init subscription L1 cache: %v", err)
		return
	}
	s.subCacheL1 = cache
	s.subCacheTTL = time.Duration(sc.L1TTLSeconds) * time.Second
	s.subCacheJitter = sc.JitterPercent
}
// subCacheKey 生成订阅缓存 key(热路径,避免 fmt.Sprintf 开销)。
// 格式:"sub:<userID>:<groupID>"。
func subCacheKey(userID, groupID int64) string {
	uid := strconv.FormatInt(userID, 10)
	gid := strconv.FormatInt(groupID, 10)
	return "sub:" + uid + ":" + gid
}
// jitteredTTL 为 TTL 添加抖动,避免集中过期
func
(
s
*
SubscriptionService
)
jitteredTTL
(
ttl
time
.
Duration
)
time
.
Duration
{
if
ttl
<=
0
||
s
.
subCacheJitter
<=
0
{
return
ttl
}
pct
:=
s
.
subCacheJitter
if
pct
>
100
{
pct
=
100
}
delta
:=
float64
(
pct
)
/
100
factor
:=
1
-
delta
+
rand
.
Float64
()
*
(
2
*
delta
)
if
factor
<=
0
{
return
ttl
}
return
time
.
Duration
(
float64
(
ttl
)
*
factor
)
}
// InvalidateSubCache 失效指定用户+分组的订阅 L1 缓存
func
(
s
*
SubscriptionService
)
InvalidateSubCache
(
userID
,
groupID
int64
)
{
if
s
.
subCacheL1
==
nil
{
return
}
s
.
subCacheL1
.
Del
(
subCacheKey
(
userID
,
groupID
))
}
// AssignSubscriptionInput 分配订阅输入
...
...
@@ -81,6 +147,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
}
// 失效订阅缓存
s
.
InvalidateSubCache
(
input
.
UserID
,
input
.
GroupID
)
if
s
.
billingCacheService
!=
nil
{
userID
,
groupID
:=
input
.
UserID
,
input
.
GroupID
go
func
()
{
...
...
@@ -167,6 +234,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 失效订阅缓存
s
.
InvalidateSubCache
(
input
.
UserID
,
input
.
GroupID
)
if
s
.
billingCacheService
!=
nil
{
userID
,
groupID
:=
input
.
UserID
,
input
.
GroupID
go
func
()
{
...
...
@@ -188,6 +256,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 失效订阅缓存
s
.
InvalidateSubCache
(
input
.
UserID
,
input
.
GroupID
)
if
s
.
billingCacheService
!=
nil
{
userID
,
groupID
:=
input
.
UserID
,
input
.
GroupID
go
func
()
{
...
...
@@ -297,6 +366,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
}
// 失效订阅缓存
s
.
InvalidateSubCache
(
sub
.
UserID
,
sub
.
GroupID
)
if
s
.
billingCacheService
!=
nil
{
userID
,
groupID
:=
sub
.
UserID
,
sub
.
GroupID
go
func
()
{
...
...
@@ -363,6 +433,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
}
// 失效订阅缓存
s
.
InvalidateSubCache
(
sub
.
UserID
,
sub
.
GroupID
)
if
s
.
billingCacheService
!=
nil
{
userID
,
groupID
:=
sub
.
UserID
,
sub
.
GroupID
go
func
()
{
...
...
@@ -381,12 +452,39 @@ func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubsc
}
// GetActiveSubscription 获取用户对特定分组的有效订阅
// 使用 L1 缓存 + singleflight 加速中间件热路径。
// 返回缓存对象的浅拷贝,调用方可安全修改字段而不会污染缓存或触发 data race。
func
(
s
*
SubscriptionService
)
GetActiveSubscription
(
ctx
context
.
Context
,
userID
,
groupID
int64
)
(
*
UserSubscription
,
error
)
{
sub
,
err
:=
s
.
userSubRepo
.
GetActiveByUserIDAndGroupID
(
ctx
,
userID
,
groupID
)
key
:=
subCacheKey
(
userID
,
groupID
)
// L1 缓存命中:返回浅拷贝
if
s
.
subCacheL1
!=
nil
{
if
v
,
ok
:=
s
.
subCacheL1
.
Get
(
key
);
ok
{
if
sub
,
ok
:=
v
.
(
*
UserSubscription
);
ok
{
cp
:=
*
sub
return
&
cp
,
nil
}
}
}
// singleflight 防止并发击穿
value
,
err
,
_
:=
s
.
subCacheGroup
.
Do
(
key
,
func
()
(
any
,
error
)
{
sub
,
err
:=
s
.
userSubRepo
.
GetActiveByUserIDAndGroupID
(
ctx
,
userID
,
groupID
)
if
err
!=
nil
{
return
nil
,
ErrSubscriptionNotFound
}
// 写入 L1 缓存
if
s
.
subCacheL1
!=
nil
{
_
=
s
.
subCacheL1
.
SetWithTTL
(
key
,
sub
,
1
,
s
.
jitteredTTL
(
s
.
subCacheTTL
))
}
return
sub
,
nil
})
if
err
!=
nil
{
return
nil
,
E
rr
SubscriptionNotFound
return
nil
,
e
rr
}
return
sub
,
nil
// singleflight 返回的也是缓存指针,需要浅拷贝
cp
:=
*
value
.
(
*
UserSubscription
)
return
&
cp
,
nil
}
// ListUserSubscriptions 获取用户的所有订阅
...
...
@@ -521,9 +619,12 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use
needsInvalidateCache
=
true
}
// 如果有窗口被重置,失效 Redis 缓存以保持一致性
if
needsInvalidateCache
&&
s
.
billingCacheService
!=
nil
{
_
=
s
.
billingCacheService
.
InvalidateSubscription
(
ctx
,
sub
.
UserID
,
sub
.
GroupID
)
// 如果有窗口被重置,失效缓存以保持一致性
if
needsInvalidateCache
{
s
.
InvalidateSubCache
(
sub
.
UserID
,
sub
.
GroupID
)
if
s
.
billingCacheService
!=
nil
{
_
=
s
.
billingCacheService
.
InvalidateSubscription
(
ctx
,
sub
.
UserID
,
sub
.
GroupID
)
}
}
return
nil
...
...
@@ -544,6 +645,78 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSub
return
nil
}
// ValidateAndCheckLimits 合并验证+限额检查(中间件热路径专用)。
// 仅做内存检查,不触发 DB 写入。窗口重置的 DB 写入由 DoWindowMaintenance 异步完成。
// 返回 needsMaintenance 表示是否需要异步执行窗口维护。
// 注意:会原地把过期窗口的用量字段清零(修改传入的 sub)。
func (s *SubscriptionService) ValidateAndCheckLimits(sub *UserSubscription, group *Group) (needsMaintenance bool, err error) {
	// 1. 验证订阅状态(显式状态字段 + 按时间判定的 IsExpired 双重检查)
	if sub.Status == SubscriptionStatusExpired {
		return false, ErrSubscriptionExpired
	}
	if sub.Status == SubscriptionStatusSuspended {
		return false, ErrSubscriptionSuspended
	}
	if sub.IsExpired() {
		return false, ErrSubscriptionExpired
	}
	// 2. 内存中修正过期窗口的用量,确保 CheckUsageLimits 不会误拒绝用户
	//    实际的 DB 窗口重置由 DoWindowMaintenance 异步完成
	if sub.NeedsDailyReset() {
		sub.DailyUsageUSD = 0
		needsMaintenance = true
	}
	if sub.NeedsWeeklyReset() {
		sub.WeeklyUsageUSD = 0
		needsMaintenance = true
	}
	if sub.NeedsMonthlyReset() {
		sub.MonthlyUsageUSD = 0
		needsMaintenance = true
	}
	// 窗口未激活也需要异步维护(首次使用时激活)
	if !sub.IsWindowActivated() {
		needsMaintenance = true
	}
	// 3. 检查用量限额(第二个参数传 0;presumably 表示"额外增量"——待结合
	//    CheckDailyLimit 实现确认)
	if !sub.CheckDailyLimit(group, 0) {
		return needsMaintenance, ErrDailyLimitExceeded
	}
	if !sub.CheckWeeklyLimit(group, 0) {
		return needsMaintenance, ErrWeeklyLimitExceeded
	}
	if !sub.CheckMonthlyLimit(group, 0) {
		return needsMaintenance, ErrMonthlyLimitExceeded
	}
	return needsMaintenance, nil
}
// DoWindowMaintenance 异步执行窗口维护(激活+重置)。
// 使用独立 context(5s 超时),不受请求取消影响。
// 注意:此方法仅在 ValidateAndCheckLimits 返回 needsMaintenance=true 时调用,
// 而 IsExpired()=true 的订阅在 ValidateAndCheckLimits 中已被拦截返回错误,
// 因此进入此方法的订阅一定未过期,无需处理过期状态同步。
// 维护失败仅记录日志,不向调用方传播。
func (s *SubscriptionService) DoWindowMaintenance(sub *UserSubscription) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	// 激活窗口(首次使用时)
	if !sub.IsWindowActivated() {
		if err := s.CheckAndActivateWindow(ctx, sub); err != nil {
			log.Printf("Failed to activate subscription windows: %v", err)
		}
	}
	// 重置过期窗口
	if err := s.CheckAndResetWindows(ctx, sub); err != nil {
		log.Printf("Failed to reset subscription windows: %v", err)
	}
	// 失效 L1 缓存,确保后续请求拿到更新后的数据
	s.InvalidateSubCache(sub.UserID, sub.GroupID)
}
// RecordUsage 记录使用量到订阅
func
(
s
*
SubscriptionService
)
RecordUsage
(
ctx
context
.
Context
,
subscriptionID
int64
,
costUSD
float64
)
error
{
return
s
.
userSubRepo
.
IncrementUsage
(
ctx
,
subscriptionID
,
costUSD
)
...
...
backend/internal/service/usage_service.go
View file @
a14dfb76
...
...
@@ -316,8 +316,8 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
}
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
func
(
s
*
UsageService
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchAPIKeyUsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetBatchAPIKeyUsageStats
(
ctx
,
apiKeyIDs
)
func
(
s
*
UsageService
)
GetBatchAPIKeyUsageStats
(
ctx
context
.
Context
,
apiKeyIDs
[]
int64
,
startTime
,
endTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
BatchAPIKeyUsageStats
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetBatchAPIKeyUsageStats
(
ctx
,
apiKeyIDs
,
startTime
,
endTime
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get batch api key usage stats: %w"
,
err
)
}
...
...
backend/internal/service/user_service.go
View file @
a14dfb76
...
...
@@ -3,6 +3,8 @@ package service
import
(
"context"
"fmt"
"log"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
...
...
@@ -62,13 +64,15 @@ type ChangePasswordRequest struct {
type
UserService
struct
{
userRepo
UserRepository
authCacheInvalidator
APIKeyAuthCacheInvalidator
billingCache
BillingCache
}
// NewUserService 创建用户服务实例
func
NewUserService
(
userRepo
UserRepository
,
authCacheInvalidator
APIKeyAuthCacheInvalidator
)
*
UserService
{
func
NewUserService
(
userRepo
UserRepository
,
authCacheInvalidator
APIKeyAuthCacheInvalidator
,
billingCache
BillingCache
)
*
UserService
{
return
&
UserService
{
userRepo
:
userRepo
,
authCacheInvalidator
:
authCacheInvalidator
,
billingCache
:
billingCache
,
}
}
...
...
@@ -183,6 +187,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByUserID
(
ctx
,
userID
)
}
if
s
.
billingCache
!=
nil
{
go
func
()
{
cacheCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
if
err
:=
s
.
billingCache
.
InvalidateUserBalance
(
cacheCtx
,
userID
);
err
!=
nil
{
log
.
Printf
(
"invalidate user balance cache failed: user_id=%d err=%v"
,
userID
,
err
)
}
}()
}
return
nil
}
...
...
backend/internal/service/user_service_test.go
0 → 100644
View file @
a14dfb76
//go:build unit
package
service
import
(
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// --- mock: UserRepository ---
type
mockUserRepo
struct
{
updateBalanceErr
error
updateBalanceFn
func
(
ctx
context
.
Context
,
id
int64
,
amount
float64
)
error
}
func
(
m
*
mockUserRepo
)
Create
(
context
.
Context
,
*
User
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
GetByID
(
context
.
Context
,
int64
)
(
*
User
,
error
)
{
return
&
User
{},
nil
}
func
(
m
*
mockUserRepo
)
GetByEmail
(
context
.
Context
,
string
)
(
*
User
,
error
)
{
return
&
User
{},
nil
}
func
(
m
*
mockUserRepo
)
GetFirstAdmin
(
context
.
Context
)
(
*
User
,
error
)
{
return
&
User
{},
nil
}
func
(
m
*
mockUserRepo
)
Update
(
context
.
Context
,
*
User
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
Delete
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
List
(
context
.
Context
,
pagination
.
PaginationParams
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockUserRepo
)
ListWithFilters
(
context
.
Context
,
pagination
.
PaginationParams
,
UserListFilters
)
([]
User
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
// UpdateBalance 优先调用 updateBalanceFn 自定义桩;未设置时返回预设的 updateBalanceErr。
func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
	if m.updateBalanceFn != nil {
		return m.updateBalanceFn(ctx, id, amount)
	}
	return m.updateBalanceErr
}
func
(
m
*
mockUserRepo
)
DeductBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
UpdateConcurrency
(
context
.
Context
,
int64
,
int
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
ExistsByEmail
(
context
.
Context
,
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
m
*
mockUserRepo
)
RemoveGroupFromAllowedGroups
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
m
*
mockUserRepo
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
EnableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
DisableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
// --- mock: APIKeyAuthCacheInvalidator ---
type
mockAuthCacheInvalidator
struct
{
invalidatedUserIDs
[]
int64
mu
sync
.
Mutex
}
func
(
m
*
mockAuthCacheInvalidator
)
InvalidateAuthCacheByKey
(
context
.
Context
,
string
)
{}
func
(
m
*
mockAuthCacheInvalidator
)
InvalidateAuthCacheByGroupID
(
context
.
Context
,
int64
)
{}
func
(
m
*
mockAuthCacheInvalidator
)
InvalidateAuthCacheByUserID
(
_
context
.
Context
,
userID
int64
)
{
m
.
mu
.
Lock
()
defer
m
.
mu
.
Unlock
()
m
.
invalidatedUserIDs
=
append
(
m
.
invalidatedUserIDs
,
userID
)
}
// --- mock: BillingCache ---
type
mockBillingCache
struct
{
invalidateErr
error
invalidateCallCount
atomic
.
Int64
invalidatedUserIDs
[]
int64
mu
sync
.
Mutex
}
func
(
m
*
mockBillingCache
)
GetUserBalance
(
context
.
Context
,
int64
)
(
float64
,
error
)
{
return
0
,
nil
}
func
(
m
*
mockBillingCache
)
SetUserBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
func
(
m
*
mockBillingCache
)
DeductUserBalance
(
context
.
Context
,
int64
,
float64
)
error
{
return
nil
}
// InvalidateUserBalance 记录调用次数与 userID,并返回预设的 invalidateErr
// (用于模拟缓存失效失败)。计数用 atomic、ID 列表用互斥锁保护,
// 以便在异步 goroutine 调用场景下与测试断言并发安全。
func (m *mockBillingCache) InvalidateUserBalance(_ context.Context, userID int64) error {
	m.invalidateCallCount.Add(1)
	m.mu.Lock()
	defer m.mu.Unlock()
	m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
	return m.invalidateErr
}
func
(
m
*
mockBillingCache
)
GetSubscriptionCache
(
context
.
Context
,
int64
,
int64
)
(
*
SubscriptionCacheData
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockBillingCache
)
SetSubscriptionCache
(
context
.
Context
,
int64
,
int64
,
*
SubscriptionCacheData
)
error
{
return
nil
}
func
(
m
*
mockBillingCache
)
UpdateSubscriptionUsage
(
context
.
Context
,
int64
,
int64
,
float64
)
error
{
return
nil
}
func
(
m
*
mockBillingCache
)
InvalidateSubscriptionCache
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
// --- 测试 ---
// TestUpdateBalance_Success 验证余额更新成功后,会异步失效对应用户的计费缓存,
// 且失效的 userID 与入参一致。
func TestUpdateBalance_Success(t *testing.T) {
	repo := &mockUserRepo{}
	cache := &mockBillingCache{}
	svc := NewUserService(repo, nil, cache)
	err := svc.UpdateBalance(context.Background(), 42, 100.0)
	require.NoError(t, err)
	// 等待异步 goroutine 完成
	require.Eventually(t, func() bool {
		return cache.invalidateCallCount.Load() == 1
	}, 2*time.Second, 10*time.Millisecond, "应异步调用 InvalidateUserBalance")
	cache.mu.Lock()
	defer cache.mu.Unlock()
	require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存")
}
func
TestUpdateBalance_NilBillingCache_NoPanic
(
t
*
testing
.
T
)
{
repo
:=
&
mockUserRepo
{}
svc
:=
NewUserService
(
repo
,
nil
,
nil
)
// billingCache = nil
err
:=
svc
.
UpdateBalance
(
context
.
Background
(),
1
,
50.0
)
require
.
NoError
(
t
,
err
,
"billingCache 为 nil 时不应 panic"
)
}
func
TestUpdateBalance_CacheFailure_DoesNotAffectReturn
(
t
*
testing
.
T
)
{
repo
:=
&
mockUserRepo
{}
cache
:=
&
mockBillingCache
{
invalidateErr
:
errors
.
New
(
"redis connection refused"
)}
svc
:=
NewUserService
(
repo
,
nil
,
cache
)
err
:=
svc
.
UpdateBalance
(
context
.
Background
(),
99
,
200.0
)
require
.
NoError
(
t
,
err
,
"缓存失效失败不应影响主流程返回值"
)
// 等待异步 goroutine 完成(即使失败也应调用)
require
.
Eventually
(
t
,
func
()
bool
{
return
cache
.
invalidateCallCount
.
Load
()
==
1
},
2
*
time
.
Second
,
10
*
time
.
Millisecond
,
"即使失败也应调用 InvalidateUserBalance"
)
}
func
TestUpdateBalance_RepoError_ReturnsError
(
t
*
testing
.
T
)
{
repo
:=
&
mockUserRepo
{
updateBalanceErr
:
errors
.
New
(
"database error"
)}
cache
:=
&
mockBillingCache
{}
svc
:=
NewUserService
(
repo
,
nil
,
cache
)
err
:=
svc
.
UpdateBalance
(
context
.
Background
(),
1
,
100.0
)
require
.
Error
(
t
,
err
,
"repo 失败时应返回错误"
)
require
.
Contains
(
t
,
err
.
Error
(),
"update balance"
)
// repo 失败时不应触发缓存失效
time
.
Sleep
(
100
*
time
.
Millisecond
)
require
.
Equal
(
t
,
int64
(
0
),
cache
.
invalidateCallCount
.
Load
(),
"repo 失败时不应调用 InvalidateUserBalance"
)
}
func
TestUpdateBalance_WithAuthCacheInvalidator
(
t
*
testing
.
T
)
{
repo
:=
&
mockUserRepo
{}
auth
:=
&
mockAuthCacheInvalidator
{}
cache
:=
&
mockBillingCache
{}
svc
:=
NewUserService
(
repo
,
auth
,
cache
)
err
:=
svc
.
UpdateBalance
(
context
.
Background
(),
77
,
300.0
)
require
.
NoError
(
t
,
err
)
// 验证 auth cache 同步失效
auth
.
mu
.
Lock
()
require
.
Equal
(
t
,
[]
int64
{
77
},
auth
.
invalidatedUserIDs
)
auth
.
mu
.
Unlock
()
// 验证 billing cache 异步失效
require
.
Eventually
(
t
,
func
()
bool
{
return
cache
.
invalidateCallCount
.
Load
()
==
1
},
2
*
time
.
Second
,
10
*
time
.
Millisecond
)
}
func
TestNewUserService_FieldsAssignment
(
t
*
testing
.
T
)
{
repo
:=
&
mockUserRepo
{}
auth
:=
&
mockAuthCacheInvalidator
{}
cache
:=
&
mockBillingCache
{}
svc
:=
NewUserService
(
repo
,
auth
,
cache
)
require
.
NotNil
(
t
,
svc
)
require
.
Equal
(
t
,
repo
,
svc
.
userRepo
)
require
.
Equal
(
t
,
auth
,
svc
.
authCacheInvalidator
)
require
.
Equal
(
t
,
cache
,
svc
.
billingCache
)
}
deploy/.env.example
View file @
a14dfb76
...
...
@@ -58,13 +58,67 @@ TZ=Asia/Shanghai
POSTGRES_USER=sub2api
POSTGRES_PASSWORD=change_this_secure_password
POSTGRES_DB=sub2api
# PostgreSQL 监听端口(同时用于 PG 服务端和应用连接,默认 5432)
DATABASE_PORT=5432
# -----------------------------------------------------------------------------
# PostgreSQL 服务端参数(可选;主要用于 deploy/docker-compose-aicodex.yml)
# -----------------------------------------------------------------------------
# POSTGRES_MAX_CONNECTIONS:PostgreSQL 服务端允许的最大连接数。
# 必须 >=(所有 Sub2API 实例的 DATABASE_MAX_OPEN_CONNS 之和)+ 预留余量(例如 20%)。
POSTGRES_MAX_CONNECTIONS=1024
# POSTGRES_SHARED_BUFFERS:PostgreSQL 用于缓存数据页的共享内存。
# 常见建议:物理内存的 10%~25%(容器内存受限时请按实际限制调整)。
# 8GB 内存容器参考:1GB。
POSTGRES_SHARED_BUFFERS=1GB
# POSTGRES_EFFECTIVE_CACHE_SIZE:查询规划器“假设可用的 OS 缓存大小”(不等于实际分配)。
# 常见建议:物理内存的 50%~75%。
# 8GB 内存容器参考:4GB~6GB(与下方默认值及容器实际内存限制保持一致)。
POSTGRES_EFFECTIVE_CACHE_SIZE=4GB
# POSTGRES_MAINTENANCE_WORK_MEM:维护操作内存(VACUUM/CREATE INDEX 等)。
# 值越大维护越快,但会占用更多内存。
# 8GB 内存容器参考:128MB。
POSTGRES_MAINTENANCE_WORK_MEM=128MB
# -----------------------------------------------------------------------------
# PostgreSQL 连接池参数(可选,默认与程序内置一致)
# -----------------------------------------------------------------------------
# 说明:
# - 这些参数控制 Sub2API 进程到 PostgreSQL 的连接池大小(不是 PostgreSQL 自身的 max_connections)。
# - 多实例/多副本部署时,总连接上限约等于:实例数 * DATABASE_MAX_OPEN_CONNS。
# - 连接池过大可能导致:数据库连接耗尽、内存占用上升、上下文切换增多,反而变慢。
# - 建议结合 PostgreSQL 的 max_connections 与机器规格逐步调优:
# 通常把应用总连接上限控制在 max_connections 的 50%~80% 更稳妥。
#
# DATABASE_MAX_OPEN_CONNS:最大打开连接数(活跃+空闲),达到后新请求会等待可用连接。
# 典型范围:50~500(取决于 DB 规格、实例数、SQL 复杂度)。
DATABASE_MAX_OPEN_CONNS=256
# DATABASE_MAX_IDLE_CONNS:最大空闲连接数(热连接),建议 <= MAX_OPEN。
# 太小会频繁建连增加延迟;太大会长期占用数据库资源。
DATABASE_MAX_IDLE_CONNS=128
# DATABASE_CONN_MAX_LIFETIME_MINUTES:单个连接最大存活时间(单位:分钟)。
# 用于避免连接长期不重建导致的中间件/LB/NAT 异常或服务端重启后的“僵尸连接”。
# 设置为 0 表示不限制(一般不建议生产环境)。
DATABASE_CONN_MAX_LIFETIME_MINUTES=30
# DATABASE_CONN_MAX_IDLE_TIME_MINUTES:空闲连接最大存活时间(单位:分钟)。
# 超过该时间的空闲连接会被回收,防止长时间闲置占用连接数。
# 设置为 0 表示不限制(一般不建议生产环境)。
DATABASE_CONN_MAX_IDLE_TIME_MINUTES=5
# -----------------------------------------------------------------------------
# Redis Configuration
# -----------------------------------------------------------------------------
# Redis 监听端口(同时用于应用连接和 Redis 服务端,默认 6379)
REDIS_PORT=6379
# Leave empty for no password (default for local development)
REDIS_PASSWORD=
REDIS_DB=0
# Redis 服务端最大客户端连接数(可选;主要用于 deploy/docker-compose-aicodex.yml)
REDIS_MAXCLIENTS=50000
# Redis 连接池大小(默认 1024)
REDIS_POOL_SIZE=4096
# Redis 最小空闲连接数(默认 10)
REDIS_MIN_IDLE_CONNS=256
REDIS_ENABLE_TLS=false
# -----------------------------------------------------------------------------
...
...
@@ -119,6 +173,19 @@ RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10
# Gateway Scheduling (Optional)
# 调度缓存与受控回源配置(缓存就绪且命中时不读 DB)
# -----------------------------------------------------------------------------
# Force Codex CLI mode: treat all /openai/v1/responses requests as Codex CLI.
# 强制按 Codex CLI 处理 /openai/v1/responses 请求(用于网关未透传/改写 User-Agent 的兜底)。
#
# 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。
#
# 默认:false
GATEWAY_FORCE_CODEX_CLI=false
# 上游连接池:每主机最大连接数(默认 1024;流式/HTTP1.1 场景可调大,如 2400/4096)
GATEWAY_MAX_CONNS_PER_HOST=2048
# 上游连接池:最大空闲连接总数(默认 2560;账号/代理隔离 + 高并发场景可调大)
GATEWAY_MAX_IDLE_CONNS=8192
# 上游连接池:每主机最大空闲连接(默认 120)
GATEWAY_MAX_IDLE_CONNS_PER_HOST=4096
# 粘性会话最大排队长度
GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING=3
# 粘性会话等待超时(时间段,例如 45s)
...
...
deploy/config.example.yaml
View file @
a14dfb76
...
...
@@ -20,6 +20,10 @@ server:
# Mode: "debug" for development, "release" for production
# 运行模式:"debug" 用于开发,"release" 用于生产环境
mode
:
"
release"
# Frontend base URL used to generate external links in emails (e.g. password reset)
# 用于生成邮件中的外部链接(例如:重置密码链接)的前端基础地址
# Example: "https://example.com"
frontend_url
:
"
"
# Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies.
# 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。
trusted_proxies
:
[]
...
...
@@ -108,9 +112,9 @@ security:
# 白名单禁用时是否允许 http:// URL(默认: false,要求 https)
allow_insecure_http
:
true
response_headers
:
# Enable configurable response header filtering (d
isable to use default allowlist
)
# 启用可配置的响应头过滤(
禁用则使用默认白名单
)
enabled
:
fals
e
# Enable configurable response header filtering (d
efault: true
)
# 启用可配置的响应头过滤(
默认启用,过滤上游敏感响应头
)
enabled
:
tru
e
# Extra allowed response headers from upstream
# 额外允许的上游响应头
additional_allowed
:
[]
...
...
@@ -151,17 +155,22 @@ gateway:
# - account_proxy: Isolate by account+proxy combination (default, finest granularity)
# - account_proxy: 按账户+代理组合隔离(默认,最细粒度)
connection_pool_isolation
:
"
account_proxy"
# Force Codex CLI mode: treat all /openai/v1/responses requests as Codex CLI.
# 强制按 Codex CLI 处理 /openai/v1/responses 请求(用于网关未透传/改写 User-Agent 的兜底)。
#
# 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。
force_codex_cli
:
false
# HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults)
# HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值)
# Max idle connections across all hosts
# 所有主机的最大空闲连接数
max_idle_conns
:
2
4
0
max_idle_conns
:
2
56
0
# Max idle connections per host
# 每个主机的最大空闲连接数
max_idle_conns_per_host
:
120
# Max connections per host
# 每个主机的最大连接数
max_conns_per_host
:
24
0
max_conns_per_host
:
10
24
# Idle connection timeout (seconds)
# 空闲连接超时时间(秒)
idle_conn_timeout_seconds
:
90
...
...
@@ -381,9 +390,22 @@ database:
# Database name
# 数据库名称
dbname
:
"
sub2api"
# SSL mode: disable, require, verify-ca, verify-full
# SSL 模式:disable(禁用), require(要求), verify-ca(验证CA), verify-full(完全验证)
sslmode
:
"
disable"
# SSL mode: disable, prefer, require, verify-ca, verify-full
# SSL 模式:disable(禁用), prefer(优先加密,默认), require(要求), verify-ca(验证CA), verify-full(完全验证)
# 默认值为 "prefer",数据库支持 SSL 时自动使用加密连接,不支持时回退明文
sslmode
:
"
prefer"
# Max open connections (高并发场景建议 256+,需配合 PostgreSQL max_connections 调整)
# 最大打开连接数
max_open_conns
:
256
# Max idle connections (建议为 max_open_conns 的 50%,减少频繁建连开销)
# 最大空闲连接数
max_idle_conns
:
128
# Connection max lifetime (minutes)
# 连接最大存活时间(分钟)
conn_max_lifetime_minutes
:
30
# Connection max idle time (minutes)
# 空闲连接最大存活时间(分钟)
conn_max_idle_time_minutes
:
5
# =============================================================================
# Redis Configuration
...
...
@@ -402,6 +424,12 @@ redis:
# Database number (0-15)
# 数据库编号(0-15)
db
:
0
# Connection pool size (max concurrent connections)
# 连接池大小(最大并发连接数)
pool_size
:
1024
# Minimum number of idle connections (高并发场景建议 128+,保持足够热连接)
# 最小空闲连接数
min_idle_conns
:
128
# Enable TLS/SSL connection
# 是否启用 TLS/SSL 连接
enable_tls
:
false
...
...
deploy/docker-compose-aicodex.yml
0 → 100644
View file @
a14dfb76
# =============================================================================
# Sub2API Docker Compose Host Configuration (Local Build)
# =============================================================================
# Quick Start:
# 1. Copy .env.example to .env and configure
#   2. docker-compose -f docker-compose-aicodex.yml up -d --build
#   3. Check logs: docker-compose -f docker-compose-aicodex.yml logs -f sub2api
# 4. Access: http://localhost:8080
#
# This configuration builds the image from source (Dockerfile in project root).
# All configuration is done via environment variables.
# No Setup Wizard needed - the system auto-initializes on first run.
# =============================================================================
services
:
# ===========================================================================
# Sub2API Application
# ===========================================================================
sub2api
:
#image: weishaw/sub2api:latest
image
:
yangjianbo/aicodex2api:latest
build
:
context
:
..
dockerfile
:
Dockerfile
container_name
:
sub2api
restart
:
unless-stopped
network_mode
:
host
ulimits
:
nofile
:
soft
:
800000
hard
:
800000
volumes
:
# Data persistence (config.yaml will be auto-generated here)
-
sub2api_data:/app/data
# Mount custom config.yaml (optional, overrides auto-generated config)
#- ./config.yaml:/app/data/config.yaml:ro
environment
:
# =======================================================================
# Auto Setup (REQUIRED for Docker deployment)
# =======================================================================
-
AUTO_SETUP=true
# =======================================================================
# Server Configuration
# =======================================================================
-
SERVER_HOST=0.0.0.0
-
SERVER_PORT=8080
-
SERVER_MODE=${SERVER_MODE:-release}
-
RUN_MODE=${RUN_MODE:-standard}
# =======================================================================
# Database Configuration (PostgreSQL)
# =======================================================================
# Using host network: point to host/external DB by DATABASE_HOST/DATABASE_PORT
-
DATABASE_HOST=${DATABASE_HOST:-127.0.0.1}
-
DATABASE_PORT=${DATABASE_PORT:-5432}
-
DATABASE_USER=${POSTGRES_USER:-sub2api}
-
DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
-
DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
-
DATABASE_SSLMODE=disable
-
DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50}
-
DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10}
-
DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30}
-
DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5}
# =======================================================================
# Gateway Configuration
# =======================================================================
-
GATEWAY_FORCE_CODEX_CLI=${GATEWAY_FORCE_CODEX_CLI:-false}
-
GATEWAY_MAX_IDLE_CONNS=${GATEWAY_MAX_IDLE_CONNS:-2560}
-
GATEWAY_MAX_IDLE_CONNS_PER_HOST=${GATEWAY_MAX_IDLE_CONNS_PER_HOST:-120}
-
GATEWAY_MAX_CONNS_PER_HOST=${GATEWAY_MAX_CONNS_PER_HOST:-8192}
# =======================================================================
# Redis Configuration
# =======================================================================
# Using host network: point to host/external Redis by REDIS_HOST/REDIS_PORT
-
REDIS_HOST=${REDIS_HOST:-127.0.0.1}
-
REDIS_PORT=${REDIS_PORT:-6379}
-
REDIS_PASSWORD=${REDIS_PASSWORD:-}
-
REDIS_DB=${REDIS_DB:-0}
-
REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024}
-
REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10}
-
REDIS_ENABLE_TLS=${REDIS_ENABLE_TLS:-false}
# =======================================================================
# Admin Account (auto-created on first run)
# =======================================================================
-
ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local}
-
ADMIN_PASSWORD=${ADMIN_PASSWORD:-}
# =======================================================================
# JWT Configuration
# =======================================================================
# Leave empty to auto-generate (recommended)
-
JWT_SECRET=${JWT_SECRET:-}
-
JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24}
# =======================================================================
# TOTP (2FA) Configuration
# =======================================================================
# IMPORTANT: Set a fixed encryption key for TOTP secrets. If left empty,
# a random key will be generated on each startup, causing all existing
# TOTP configurations to become invalid (users won't be able to login
# with 2FA).
# Generate a secure key: openssl rand -hex 32
-
TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-}
# =======================================================================
# Timezone Configuration
# This affects ALL time operations in the application:
# - Database timestamps
# - Usage statistics "today" boundary
# - Subscription expiry times
# - Log timestamps
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
# =======================================================================
-
TZ=${TZ:-Asia/Shanghai}
# =======================================================================
# Gemini OAuth Configuration (for Gemini accounts)
# =======================================================================
-
GEMINI_OAUTH_CLIENT_ID=${GEMINI_OAUTH_CLIENT_ID:-}
-
GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-}
-
GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-}
-
GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-}
# =======================================================================
# Security Configuration (URL Allowlist)
# =======================================================================
# Allow private IP addresses for CRS sync (for internal deployments)
-
SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=${SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS:-true}
depends_on
:
postgres
:
condition
:
service_healthy
redis
:
condition
:
service_healthy
healthcheck
:
test
:
[
"
CMD"
,
"
curl"
,
"
-f"
,
"
http://localhost:8080/health"
]
interval
:
30s
timeout
:
10s
retries
:
3
start_period
:
30s
# ===========================================================================
# PostgreSQL Database
# ===========================================================================
postgres
:
image
:
postgres:18-alpine
container_name
:
sub2api-postgres
restart
:
unless-stopped
network_mode
:
host
ulimits
:
nofile
:
soft
:
800000
hard
:
800000
volumes
:
-
postgres_data:/var/lib/postgresql/data
environment
:
-
POSTGRES_USER=${POSTGRES_USER:-sub2api}
-
POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
-
POSTGRES_DB=${POSTGRES_DB:-sub2api}
-
TZ=${TZ:-Asia/Shanghai}
command
:
-
"
postgres"
-
"
-c"
-
"
listen_addresses=127.0.0.1"
# 监听端口:与应用侧 DATABASE_PORT 保持一致。
-
"
-c"
-
"
port=${DATABASE_PORT:-5432}"
# 连接数上限:需要结合应用侧 DATABASE_MAX_OPEN_CONNS 调整。
# 注意:max_connections 过大可能导致内存占用与上下文切换开销显著上升。
-
"
-c"
-
"
max_connections=${POSTGRES_MAX_CONNECTIONS:-1024}"
# 典型内存参数(建议结合机器内存调优;不确定就保持默认或小步调大)。
-
"
-c"
-
"
shared_buffers=${POSTGRES_SHARED_BUFFERS:-1GB}"
-
"
-c"
-
"
effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-6GB}"
-
"
-c"
-
"
maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-128MB}"
healthcheck
:
test
:
[
"
CMD-SHELL"
,
"
pg_isready
-U
${POSTGRES_USER:-sub2api}
-d
${POSTGRES_DB:-sub2api}
-p
${DATABASE_PORT:-5432}"
]
interval
:
10s
timeout
:
5s
retries
:
5
start_period
:
10s
# Note: bound to localhost only; not exposed to external network by default.
# ===========================================================================
# Redis Cache
# ===========================================================================
redis
:
image
:
redis:8-alpine
container_name
:
sub2api-redis
restart
:
unless-stopped
network_mode
:
host
ulimits
:
nofile
:
soft
:
100000
hard
:
100000
volumes
:
-
redis_data:/data
command
:
>
redis-server
--bind 127.0.0.1
--port ${REDIS_PORT:-6379}
--maxclients ${REDIS_MAXCLIENTS:-50000}
--save 60 1
--appendonly yes
--appendfsync everysec
${REDIS_PASSWORD:+--requirepass ${REDIS_PASSWORD}}
environment
:
-
TZ=${TZ:-Asia/Shanghai}
# REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag)
-
REDISCLI_AUTH=${REDIS_PASSWORD:-}
healthcheck
:
test
:
[
"
CMD-SHELL"
,
"
redis-cli
-p
${REDIS_PORT:-6379}
-a
\"
$REDISCLI_AUTH
\"
ping
|
grep
-q
PONG
||
redis-cli
-p
${REDIS_PORT:-6379}
ping
|
grep
-q
PONG"
]
interval
:
10s
timeout
:
5s
retries
:
5
start_period
:
5s
# =============================================================================
# Volumes
# =============================================================================
volumes
:
sub2api_data
:
driver
:
local
postgres_data
:
driver
:
local
redis_data
:
driver
:
local
deploy/docker-compose-test.yml
View file @
a14dfb76
...
...
@@ -57,6 +57,10 @@ services:
-
DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
-
DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
-
DATABASE_SSLMODE=disable
-
DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50}
-
DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10}
-
DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30}
-
DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5}
# =======================================================================
# Redis Configuration
...
...
@@ -65,6 +69,8 @@ services:
-
REDIS_PORT=6379
-
REDIS_PASSWORD=${REDIS_PASSWORD:-}
-
REDIS_DB=${REDIS_DB:-0}
-
REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024}
-
REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10}
# =======================================================================
# Admin Account (auto-created on first run)
...
...
deploy/docker-compose.local.yml
View file @
a14dfb76
...
...
@@ -62,6 +62,10 @@ services:
-
DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
-
DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
-
DATABASE_SSLMODE=disable
-
DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50}
-
DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10}
-
DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30}
-
DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5}
# =======================================================================
# Redis Configuration
...
...
@@ -70,6 +74,8 @@ services:
-
REDIS_PORT=6379
-
REDIS_PASSWORD=${REDIS_PASSWORD:-}
-
REDIS_DB=${REDIS_DB:-0}
-
REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024}
-
REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10}
-
REDIS_ENABLE_TLS=${REDIS_ENABLE_TLS:-false}
# =======================================================================
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment