Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a89477dd
Commit
a89477dd
authored
Feb 22, 2026
by
yangjianbo
Browse files
perf(gateway): 优化热点路径并补齐高覆盖测试
parent
2f520c8d
Changes
16
Expand all
Hide whitespace changes
Inline
Side-by-side
.gitignore
View file @
a89477dd
...
@@ -121,7 +121,6 @@ AGENTS.md
...
@@ -121,7 +121,6 @@ AGENTS.md
scripts
scripts
.code-review-state
.code-review-state
openspec/
openspec/
docs/
code-reviews/
code-reviews/
AGENTS.md
AGENTS.md
backend/cmd/server/server
backend/cmd/server/server
...
...
backend/internal/config/config.go
View file @
a89477dd
...
@@ -423,6 +423,11 @@ type GatewayConfig struct {
...
@@ -423,6 +423,11 @@ type GatewayConfig struct {
// UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker)
// UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker)
UsageRecord
GatewayUsageRecordConfig
`mapstructure:"usage_record"`
UsageRecord
GatewayUsageRecordConfig
`mapstructure:"usage_record"`
// UserGroupRateCacheTTLSeconds: 用户分组倍率热路径缓存 TTL(秒)
UserGroupRateCacheTTLSeconds
int
`mapstructure:"user_group_rate_cache_ttl_seconds"`
// ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL(秒)
ModelsListCacheTTLSeconds
int
`mapstructure:"models_list_cache_ttl_seconds"`
}
}
// GatewayUsageRecordConfig 使用量记录异步队列配置
// GatewayUsageRecordConfig 使用量记录异步队列配置
...
@@ -1175,6 +1180,8 @@ func setDefaults() {
...
@@ -1175,6 +1180,8 @@ func setDefaults() {
viper
.
SetDefault
(
"gateway.usage_record.auto_scale_down_step"
,
16
)
viper
.
SetDefault
(
"gateway.usage_record.auto_scale_down_step"
,
16
)
viper
.
SetDefault
(
"gateway.usage_record.auto_scale_check_interval_seconds"
,
3
)
viper
.
SetDefault
(
"gateway.usage_record.auto_scale_check_interval_seconds"
,
3
)
viper
.
SetDefault
(
"gateway.usage_record.auto_scale_cooldown_seconds"
,
10
)
viper
.
SetDefault
(
"gateway.usage_record.auto_scale_cooldown_seconds"
,
10
)
viper
.
SetDefault
(
"gateway.user_group_rate_cache_ttl_seconds"
,
30
)
viper
.
SetDefault
(
"gateway.models_list_cache_ttl_seconds"
,
15
)
// TLS指纹伪装配置(默认关闭,需要账号级别单独启用)
// TLS指纹伪装配置(默认关闭,需要账号级别单独启用)
viper
.
SetDefault
(
"gateway.tls_fingerprint.enabled"
,
true
)
viper
.
SetDefault
(
"gateway.tls_fingerprint.enabled"
,
true
)
viper
.
SetDefault
(
"concurrency.ping_interval"
,
10
)
viper
.
SetDefault
(
"concurrency.ping_interval"
,
10
)
...
@@ -1751,6 +1758,12 @@ func (c *Config) Validate() error {
...
@@ -1751,6 +1758,12 @@ func (c *Config) Validate() error {
return
fmt
.
Errorf
(
"gateway.usage_record.auto_scale_cooldown_seconds must be non-negative"
)
return
fmt
.
Errorf
(
"gateway.usage_record.auto_scale_cooldown_seconds must be non-negative"
)
}
}
}
}
if
c
.
Gateway
.
UserGroupRateCacheTTLSeconds
<=
0
{
return
fmt
.
Errorf
(
"gateway.user_group_rate_cache_ttl_seconds must be positive"
)
}
if
c
.
Gateway
.
ModelsListCacheTTLSeconds
<
10
||
c
.
Gateway
.
ModelsListCacheTTLSeconds
>
30
{
return
fmt
.
Errorf
(
"gateway.models_list_cache_ttl_seconds must be between 10-30"
)
}
if
c
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
<=
0
{
if
c
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
<=
0
{
return
fmt
.
Errorf
(
"gateway.scheduling.sticky_session_max_waiting must be positive"
)
return
fmt
.
Errorf
(
"gateway.scheduling.sticky_session_max_waiting must be positive"
)
}
}
...
...
backend/internal/config/config_test.go
View file @
a89477dd
...
@@ -1010,6 +1010,16 @@ func TestValidateConfigErrors(t *testing.T) {
...
@@ -1010,6 +1010,16 @@ func TestValidateConfigErrors(t *testing.T) {
mutate
:
func
(
c
*
Config
)
{
c
.
Gateway
.
UsageRecord
.
AutoScaleCheckIntervalSeconds
=
0
},
mutate
:
func
(
c
*
Config
)
{
c
.
Gateway
.
UsageRecord
.
AutoScaleCheckIntervalSeconds
=
0
},
wantErr
:
"gateway.usage_record.auto_scale_check_interval_seconds"
,
wantErr
:
"gateway.usage_record.auto_scale_check_interval_seconds"
,
},
},
{
name
:
"gateway user group rate cache ttl"
,
mutate
:
func
(
c
*
Config
)
{
c
.
Gateway
.
UserGroupRateCacheTTLSeconds
=
0
},
wantErr
:
"gateway.user_group_rate_cache_ttl_seconds"
,
},
{
name
:
"gateway models list cache ttl range"
,
mutate
:
func
(
c
*
Config
)
{
c
.
Gateway
.
ModelsListCacheTTLSeconds
=
31
},
wantErr
:
"gateway.models_list_cache_ttl_seconds"
,
},
{
{
name
:
"gateway scheduling sticky waiting"
,
name
:
"gateway scheduling sticky waiting"
,
mutate
:
func
(
c
*
Config
)
{
c
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
=
0
},
mutate
:
func
(
c
*
Config
)
{
c
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
=
0
},
...
...
backend/internal/handler/admin/account_handler_passthrough_test.go
View file @
a89477dd
...
@@ -64,4 +64,3 @@ func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testi
...
@@ -64,4 +64,3 @@ func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testi
require
.
NotNil
(
t
,
created
.
Extra
)
require
.
NotNil
(
t
,
created
.
Extra
)
require
.
Equal
(
t
,
true
,
created
.
Extra
[
"anthropic_passthrough"
])
require
.
Equal
(
t
,
true
,
created
.
Extra
[
"anthropic_passthrough"
])
}
}
backend/internal/handler/gateway_handler.go
View file @
a89477dd
...
@@ -243,6 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -243,6 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var
sessionBoundAccountID
int64
var
sessionBoundAccountID
int64
if
sessionKey
!=
""
{
if
sessionKey
!=
""
{
sessionBoundAccountID
,
_
=
h
.
gatewayService
.
GetCachedSessionAccountID
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
)
sessionBoundAccountID
,
_
=
h
.
gatewayService
.
GetCachedSessionAccountID
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
)
if
sessionBoundAccountID
>
0
{
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
PrefetchedStickyAccountID
,
sessionBoundAccountID
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
}
}
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession
:=
sessionKey
!=
""
&&
sessionBoundAccountID
>
0
hasBoundSession
:=
sessionKey
!=
""
&&
sessionBoundAccountID
>
0
...
...
backend/internal/handler/gateway_helper.go
View file @
a89477dd
...
@@ -6,6 +6,7 @@ import (
...
@@ -6,6 +6,7 @@ import (
"fmt"
"fmt"
"math/rand/v2"
"math/rand/v2"
"net/http"
"net/http"
"strings"
"sync"
"sync"
"time"
"time"
...
@@ -20,14 +21,28 @@ var claudeCodeValidator = service.NewClaudeCodeValidator()
...
@@ -20,14 +21,28 @@ var claudeCodeValidator = service.NewClaudeCodeValidator()
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
// 返回更新后的 context
// 返回更新后的 context
func
SetClaudeCodeClientContext
(
c
*
gin
.
Context
,
body
[]
byte
)
{
func
SetClaudeCodeClientContext
(
c
*
gin
.
Context
,
body
[]
byte
)
{
// 解析请求体为 map
if
c
==
nil
||
c
.
Request
==
nil
{
var
bodyMap
map
[
string
]
any
return
if
len
(
body
)
>
0
{
}
_
=
json
.
Unmarshal
(
body
,
&
bodyMap
)
// Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。
if
!
claudeCodeValidator
.
ValidateUserAgent
(
c
.
GetHeader
(
"User-Agent"
))
{
ctx
:=
service
.
SetClaudeCodeClient
(
c
.
Request
.
Context
(),
false
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
return
}
}
// 验证是否为 Claude Code 客户端
isClaudeCode
:=
false
isClaudeCode
:=
claudeCodeValidator
.
Validate
(
c
.
Request
,
bodyMap
)
if
!
strings
.
Contains
(
c
.
Request
.
URL
.
Path
,
"messages"
)
{
// 与 Validate 行为一致:非 messages 路径 UA 命中即可视为 Claude Code 客户端。
isClaudeCode
=
true
}
else
{
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
var
bodyMap
map
[
string
]
any
if
len
(
body
)
>
0
{
_
=
json
.
Unmarshal
(
body
,
&
bodyMap
)
}
isClaudeCode
=
claudeCodeValidator
.
Validate
(
c
.
Request
,
bodyMap
)
}
// 更新 request context
// 更新 request context
ctx
:=
service
.
SetClaudeCodeClient
(
c
.
Request
.
Context
(),
isClaudeCode
)
ctx
:=
service
.
SetClaudeCodeClient
(
c
.
Request
.
Context
(),
isClaudeCode
)
...
@@ -223,21 +238,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
...
@@ -223,21 +238,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
ctx
,
cancel
:=
context
.
WithTimeout
(
c
.
Request
.
Context
(),
timeout
)
ctx
,
cancel
:=
context
.
WithTimeout
(
c
.
Request
.
Context
(),
timeout
)
defer
cancel
()
defer
cancel
()
// Try immediate acquire first (avoid unnecessary wait)
var
result
*
service
.
AcquireResult
var
err
error
if
slotType
==
"user"
{
result
,
err
=
h
.
concurrencyService
.
AcquireUserSlot
(
ctx
,
id
,
maxConcurrency
)
}
else
{
result
,
err
=
h
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
id
,
maxConcurrency
)
}
if
err
!=
nil
{
return
nil
,
err
}
if
result
.
Acquired
{
return
result
.
ReleaseFunc
,
nil
}
// Determine if ping is needed (streaming + ping format defined)
// Determine if ping is needed (streaming + ping format defined)
needPing
:=
isStream
&&
h
.
pingFormat
!=
""
needPing
:=
isStream
&&
h
.
pingFormat
!=
""
...
...
backend/internal/handler/gateway_helper_hotpath_test.go
0 → 100644
View file @
a89477dd
package
handler
import
(
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type
helperConcurrencyCacheStub
struct
{
mu
sync
.
Mutex
accountSeq
[]
bool
userSeq
[]
bool
accountAcquireCalls
int
userAcquireCalls
int
accountReleaseCalls
int
userReleaseCalls
int
}
func
(
s
*
helperConcurrencyCacheStub
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
accountAcquireCalls
++
if
len
(
s
.
accountSeq
)
==
0
{
return
false
,
nil
}
v
:=
s
.
accountSeq
[
0
]
s
.
accountSeq
=
s
.
accountSeq
[
1
:
]
return
v
,
nil
}
func
(
s
*
helperConcurrencyCacheStub
)
ReleaseAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
requestID
string
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
accountReleaseCalls
++
return
nil
}
func
(
s
*
helperConcurrencyCacheStub
)
GetAccountConcurrency
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
s
*
helperConcurrencyCacheStub
)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
s
*
helperConcurrencyCacheStub
)
DecrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
)
error
{
return
nil
}
func
(
s
*
helperConcurrencyCacheStub
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
s
*
helperConcurrencyCacheStub
)
AcquireUserSlot
(
ctx
context
.
Context
,
userID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
userAcquireCalls
++
if
len
(
s
.
userSeq
)
==
0
{
return
false
,
nil
}
v
:=
s
.
userSeq
[
0
]
s
.
userSeq
=
s
.
userSeq
[
1
:
]
return
v
,
nil
}
func
(
s
*
helperConcurrencyCacheStub
)
ReleaseUserSlot
(
ctx
context
.
Context
,
userID
int64
,
requestID
string
)
error
{
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
userReleaseCalls
++
return
nil
}
func
(
s
*
helperConcurrencyCacheStub
)
GetUserConcurrency
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
s
*
helperConcurrencyCacheStub
)
IncrementWaitCount
(
ctx
context
.
Context
,
userID
int64
,
maxWait
int
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
s
*
helperConcurrencyCacheStub
)
DecrementWaitCount
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
nil
}
func
(
s
*
helperConcurrencyCacheStub
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
service
.
AccountWithConcurrency
)
(
map
[
int64
]
*
service
.
AccountLoadInfo
,
error
)
{
out
:=
make
(
map
[
int64
]
*
service
.
AccountLoadInfo
,
len
(
accounts
))
for
_
,
acc
:=
range
accounts
{
out
[
acc
.
ID
]
=
&
service
.
AccountLoadInfo
{
AccountID
:
acc
.
ID
}
}
return
out
,
nil
}
func
(
s
*
helperConcurrencyCacheStub
)
GetUsersLoadBatch
(
ctx
context
.
Context
,
users
[]
service
.
UserWithConcurrency
)
(
map
[
int64
]
*
service
.
UserLoadInfo
,
error
)
{
out
:=
make
(
map
[
int64
]
*
service
.
UserLoadInfo
,
len
(
users
))
for
_
,
user
:=
range
users
{
out
[
user
.
ID
]
=
&
service
.
UserLoadInfo
{
UserID
:
user
.
ID
}
}
return
out
,
nil
}
func
(
s
*
helperConcurrencyCacheStub
)
CleanupExpiredAccountSlots
(
ctx
context
.
Context
,
accountID
int64
)
error
{
return
nil
}
func
newHelperTestContext
(
method
,
path
string
)
(
*
gin
.
Context
,
*
httptest
.
ResponseRecorder
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
method
,
path
,
nil
)
return
c
,
rec
}
func
validClaudeCodeBodyJSON
()
[]
byte
{
return
[]
byte
(
`{
"model":"claude-3-5-sonnet-20241022",
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
}`
)
}
func
TestSetClaudeCodeClientContext_FastPathAndStrictPath
(
t
*
testing
.
T
)
{
t
.
Run
(
"non_cli_user_agent_sets_false"
,
func
(
t
*
testing
.
T
)
{
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"curl/8.6.0"
)
SetClaudeCodeClientContext
(
c
,
validClaudeCodeBodyJSON
())
require
.
False
(
t
,
service
.
IsClaudeCodeClient
(
c
.
Request
.
Context
()))
})
t
.
Run
(
"cli_non_messages_path_sets_true"
,
func
(
t
*
testing
.
T
)
{
c
,
_
:=
newHelperTestContext
(
http
.
MethodGet
,
"/v1/models"
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"claude-cli/1.0.1"
)
SetClaudeCodeClientContext
(
c
,
nil
)
require
.
True
(
t
,
service
.
IsClaudeCodeClient
(
c
.
Request
.
Context
()))
})
t
.
Run
(
"cli_messages_path_valid_body_sets_true"
,
func
(
t
*
testing
.
T
)
{
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"claude-cli/1.0.1"
)
c
.
Request
.
Header
.
Set
(
"X-App"
,
"claude-code"
)
c
.
Request
.
Header
.
Set
(
"anthropic-beta"
,
"message-batches-2024-09-24"
)
c
.
Request
.
Header
.
Set
(
"anthropic-version"
,
"2023-06-01"
)
SetClaudeCodeClientContext
(
c
,
validClaudeCodeBodyJSON
())
require
.
True
(
t
,
service
.
IsClaudeCodeClient
(
c
.
Request
.
Context
()))
})
t
.
Run
(
"cli_messages_path_invalid_body_sets_false"
,
func
(
t
*
testing
.
T
)
{
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
c
.
Request
.
Header
.
Set
(
"User-Agent"
,
"claude-cli/1.0.1"
)
// 缺少严格校验所需 header + body 字段
SetClaudeCodeClientContext
(
c
,
[]
byte
(
`{"model":"x"}`
))
require
.
False
(
t
,
service
.
IsClaudeCodeClient
(
c
.
Request
.
Context
()))
})
}
func
TestWaitForSlotWithPingTimeout_AccountAndUserAcquire
(
t
*
testing
.
T
)
{
cache
:=
&
helperConcurrencyCacheStub
{
accountSeq
:
[]
bool
{
false
,
true
},
userSeq
:
[]
bool
{
false
,
true
},
}
concurrency
:=
service
.
NewConcurrencyService
(
cache
)
helper
:=
NewConcurrencyHelper
(
concurrency
,
SSEPingFormatNone
,
5
*
time
.
Millisecond
)
t
.
Run
(
"account_slot_acquired_after_retry"
,
func
(
t
*
testing
.
T
)
{
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
streamStarted
:=
false
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
101
,
2
,
time
.
Second
,
false
,
&
streamStarted
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
release
)
require
.
False
(
t
,
streamStarted
)
release
()
require
.
GreaterOrEqual
(
t
,
cache
.
accountAcquireCalls
,
2
)
require
.
GreaterOrEqual
(
t
,
cache
.
accountReleaseCalls
,
1
)
})
t
.
Run
(
"user_slot_acquired_after_retry"
,
func
(
t
*
testing
.
T
)
{
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
streamStarted
:=
false
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"user"
,
202
,
3
,
time
.
Second
,
false
,
&
streamStarted
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
release
)
release
()
require
.
GreaterOrEqual
(
t
,
cache
.
userAcquireCalls
,
2
)
require
.
GreaterOrEqual
(
t
,
cache
.
userReleaseCalls
,
1
)
})
}
func
TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing
(
t
*
testing
.
T
)
{
cache
:=
&
helperConcurrencyCacheStub
{
accountSeq
:
[]
bool
{
false
,
false
,
false
},
}
concurrency
:=
service
.
NewConcurrencyService
(
cache
)
t
.
Run
(
"timeout_returns_concurrency_error"
,
func
(
t
*
testing
.
T
)
{
helper
:=
NewConcurrencyHelper
(
concurrency
,
SSEPingFormatNone
,
5
*
time
.
Millisecond
)
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
streamStarted
:=
false
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
101
,
2
,
130
*
time
.
Millisecond
,
false
,
&
streamStarted
)
require
.
Nil
(
t
,
release
)
var
cErr
*
ConcurrencyError
require
.
ErrorAs
(
t
,
err
,
&
cErr
)
require
.
True
(
t
,
cErr
.
IsTimeout
)
})
t
.
Run
(
"stream_mode_sends_ping_before_timeout"
,
func
(
t
*
testing
.
T
)
{
helper
:=
NewConcurrencyHelper
(
concurrency
,
SSEPingFormatComment
,
10
*
time
.
Millisecond
)
c
,
rec
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
streamStarted
:=
false
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
101
,
2
,
70
*
time
.
Millisecond
,
true
,
&
streamStarted
)
require
.
Nil
(
t
,
release
)
var
cErr
*
ConcurrencyError
require
.
ErrorAs
(
t
,
err
,
&
cErr
)
require
.
True
(
t
,
cErr
.
IsTimeout
)
require
.
True
(
t
,
streamStarted
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
":
\n\n
"
)
})
}
func
TestWaitForSlotWithPingTimeout_AcquireError
(
t
*
testing
.
T
)
{
errCache
:=
&
helperConcurrencyCacheStubWithError
{
err
:
errors
.
New
(
"redis unavailable"
),
}
concurrency
:=
service
.
NewConcurrencyService
(
errCache
)
helper
:=
NewConcurrencyHelper
(
concurrency
,
SSEPingFormatNone
,
5
*
time
.
Millisecond
)
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
streamStarted
:=
false
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
1
,
1
,
200
*
time
.
Millisecond
,
false
,
&
streamStarted
)
require
.
Nil
(
t
,
release
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"redis unavailable"
)
}
type
helperConcurrencyCacheStubWithError
struct
{
helperConcurrencyCacheStub
err
error
}
func
(
s
*
helperConcurrencyCacheStubWithError
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
return
false
,
s
.
err
}
backend/internal/handler/gemini_v1beta_handler.go
View file @
a89477dd
...
@@ -263,6 +263,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -263,6 +263,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var
sessionBoundAccountID
int64
var
sessionBoundAccountID
int64
if
sessionKey
!=
""
{
if
sessionKey
!=
""
{
sessionBoundAccountID
,
_
=
h
.
gatewayService
.
GetCachedSessionAccountID
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
)
sessionBoundAccountID
,
_
=
h
.
gatewayService
.
GetCachedSessionAccountID
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
)
if
sessionBoundAccountID
>
0
{
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
PrefetchedStickyAccountID
,
sessionBoundAccountID
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
}
}
// === Gemini 内容摘要会话 Fallback 逻辑 ===
// === Gemini 内容摘要会话 Fallback 逻辑 ===
...
...
backend/internal/handler/ops_error_logger.go
View file @
a89477dd
...
@@ -41,9 +41,8 @@ const (
...
@@ -41,9 +41,8 @@ const (
)
)
type
opsErrorLogJob
struct
{
type
opsErrorLogJob
struct
{
ops
*
service
.
OpsService
ops
*
service
.
OpsService
entry
*
service
.
OpsInsertErrorLogInput
entry
*
service
.
OpsInsertErrorLogInput
requestBody
[]
byte
}
}
var
(
var
(
...
@@ -58,6 +57,7 @@ var (
...
@@ -58,6 +57,7 @@ var (
opsErrorLogEnqueued
atomic
.
Int64
opsErrorLogEnqueued
atomic
.
Int64
opsErrorLogDropped
atomic
.
Int64
opsErrorLogDropped
atomic
.
Int64
opsErrorLogProcessed
atomic
.
Int64
opsErrorLogProcessed
atomic
.
Int64
opsErrorLogSanitized
atomic
.
Int64
opsErrorLogLastDropLogAt
atomic
.
Int64
opsErrorLogLastDropLogAt
atomic
.
Int64
...
@@ -94,7 +94,7 @@ func startOpsErrorLogWorkers() {
...
@@ -94,7 +94,7 @@ func startOpsErrorLogWorkers() {
}
}
}()
}()
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
opsErrorLogTimeout
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
opsErrorLogTimeout
)
_
=
job
.
ops
.
RecordError
(
ctx
,
job
.
entry
,
job
.
requestBody
)
_
=
job
.
ops
.
RecordError
(
ctx
,
job
.
entry
,
nil
)
cancel
()
cancel
()
opsErrorLogProcessed
.
Add
(
1
)
opsErrorLogProcessed
.
Add
(
1
)
}()
}()
...
@@ -103,7 +103,7 @@ func startOpsErrorLogWorkers() {
...
@@ -103,7 +103,7 @@ func startOpsErrorLogWorkers() {
}
}
}
}
func
enqueueOpsErrorLog
(
ops
*
service
.
OpsService
,
entry
*
service
.
OpsInsertErrorLogInput
,
requestBody
[]
byte
)
{
func
enqueueOpsErrorLog
(
ops
*
service
.
OpsService
,
entry
*
service
.
OpsInsertErrorLogInput
)
{
if
ops
==
nil
||
entry
==
nil
{
if
ops
==
nil
||
entry
==
nil
{
return
return
}
}
...
@@ -129,7 +129,7 @@ func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLo
...
@@ -129,7 +129,7 @@ func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLo
}
}
select
{
select
{
case
opsErrorLogQueue
<-
opsErrorLogJob
{
ops
:
ops
,
entry
:
entry
,
requestBody
:
requestBody
}
:
case
opsErrorLogQueue
<-
opsErrorLogJob
{
ops
:
ops
,
entry
:
entry
}
:
opsErrorLogQueueLen
.
Add
(
1
)
opsErrorLogQueueLen
.
Add
(
1
)
opsErrorLogEnqueued
.
Add
(
1
)
opsErrorLogEnqueued
.
Add
(
1
)
default
:
default
:
...
@@ -205,6 +205,10 @@ func OpsErrorLogProcessedTotal() int64 {
...
@@ -205,6 +205,10 @@ func OpsErrorLogProcessedTotal() int64 {
return
opsErrorLogProcessed
.
Load
()
return
opsErrorLogProcessed
.
Load
()
}
}
func
OpsErrorLogSanitizedTotal
()
int64
{
return
opsErrorLogSanitized
.
Load
()
}
func
maybeLogOpsErrorLogDrop
()
{
func
maybeLogOpsErrorLogDrop
()
{
now
:=
time
.
Now
()
.
Unix
()
now
:=
time
.
Now
()
.
Unix
()
...
@@ -222,12 +226,13 @@ func maybeLogOpsErrorLogDrop() {
...
@@ -222,12 +226,13 @@ func maybeLogOpsErrorLogDrop() {
queueCap
:=
OpsErrorLogQueueCapacity
()
queueCap
:=
OpsErrorLogQueueCapacity
()
log
.
Printf
(
log
.
Printf
(
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)"
,
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d
sanitized_total=%d
)"
,
queued
,
queued
,
queueCap
,
queueCap
,
opsErrorLogEnqueued
.
Load
(),
opsErrorLogEnqueued
.
Load
(),
opsErrorLogDropped
.
Load
(),
opsErrorLogDropped
.
Load
(),
opsErrorLogProcessed
.
Load
(),
opsErrorLogProcessed
.
Load
(),
opsErrorLogSanitized
.
Load
(),
)
)
}
}
...
@@ -267,6 +272,22 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
...
@@ -267,6 +272,22 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
}
}
}
}
func
attachOpsRequestBodyToEntry
(
c
*
gin
.
Context
,
entry
*
service
.
OpsInsertErrorLogInput
)
{
if
c
==
nil
||
entry
==
nil
{
return
}
v
,
ok
:=
c
.
Get
(
opsRequestBodyKey
)
if
!
ok
{
return
}
raw
,
ok
:=
v
.
([]
byte
)
if
!
ok
||
len
(
raw
)
==
0
{
return
}
entry
.
RequestBodyJSON
,
entry
.
RequestBodyTruncated
,
entry
.
RequestBodyBytes
=
service
.
PrepareOpsRequestBodyForQueue
(
raw
)
opsErrorLogSanitized
.
Add
(
1
)
}
func
setOpsSelectedAccount
(
c
*
gin
.
Context
,
accountID
int64
,
platform
...
string
)
{
func
setOpsSelectedAccount
(
c
*
gin
.
Context
,
accountID
int64
,
platform
...
string
)
{
if
c
==
nil
||
accountID
<=
0
{
if
c
==
nil
||
accountID
<=
0
{
return
return
...
@@ -544,14 +565,9 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
...
@@ -544,14 +565,9 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
entry
.
ClientIP
=
&
clientIP
entry
.
ClientIP
=
&
clientIP
}
}
var
requestBody
[]
byte
if
v
,
ok
:=
c
.
Get
(
opsRequestBodyKey
);
ok
{
if
b
,
ok
:=
v
.
([]
byte
);
ok
&&
len
(
b
)
>
0
{
requestBody
=
b
}
}
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
entry
.
RequestHeadersJSON
=
extractOpsRetryRequestHeaders
(
c
)
entry
.
RequestHeadersJSON
=
extractOpsRetryRequestHeaders
(
c
)
attachOpsRequestBodyToEntry
(
c
,
entry
)
// Skip logging if a passthrough rule with skip_monitoring=true matched.
// Skip logging if a passthrough rule with skip_monitoring=true matched.
if
v
,
ok
:=
c
.
Get
(
service
.
OpsSkipPassthroughKey
);
ok
{
if
v
,
ok
:=
c
.
Get
(
service
.
OpsSkipPassthroughKey
);
ok
{
...
@@ -560,7 +576,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
...
@@ -560,7 +576,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
}
}
}
}
enqueueOpsErrorLog
(
ops
,
entry
,
requestBody
)
enqueueOpsErrorLog
(
ops
,
entry
)
return
return
}
}
...
@@ -724,17 +740,12 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
...
@@ -724,17 +740,12 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
entry
.
ClientIP
=
&
clientIP
entry
.
ClientIP
=
&
clientIP
}
}
var
requestBody
[]
byte
if
v
,
ok
:=
c
.
Get
(
opsRequestBodyKey
);
ok
{
if
b
,
ok
:=
v
.
([]
byte
);
ok
&&
len
(
b
)
>
0
{
requestBody
=
b
}
}
// Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
// Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
// Do NOT store Authorization/Cookie/etc.
// Do NOT store Authorization/Cookie/etc.
entry
.
RequestHeadersJSON
=
extractOpsRetryRequestHeaders
(
c
)
entry
.
RequestHeadersJSON
=
extractOpsRetryRequestHeaders
(
c
)
attachOpsRequestBodyToEntry
(
c
,
entry
)
enqueueOpsErrorLog
(
ops
,
entry
,
requestBody
)
enqueueOpsErrorLog
(
ops
,
entry
)
}
}
}
}
...
...
backend/internal/handler/ops_error_logger_test.go
0 → 100644
View file @
a89477dd
package
handler
import
(
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
resetOpsErrorLoggerStateForTest
(
t
*
testing
.
T
)
{
t
.
Helper
()
opsErrorLogMu
.
Lock
()
ch
:=
opsErrorLogQueue
opsErrorLogQueue
=
nil
opsErrorLogStopping
=
true
opsErrorLogMu
.
Unlock
()
if
ch
!=
nil
{
close
(
ch
)
}
opsErrorLogWorkersWg
.
Wait
()
opsErrorLogOnce
=
sync
.
Once
{}
opsErrorLogStopOnce
=
sync
.
Once
{}
opsErrorLogWorkersWg
=
sync
.
WaitGroup
{}
opsErrorLogMu
=
sync
.
RWMutex
{}
opsErrorLogStopping
=
false
opsErrorLogQueueLen
.
Store
(
0
)
opsErrorLogEnqueued
.
Store
(
0
)
opsErrorLogDropped
.
Store
(
0
)
opsErrorLogProcessed
.
Store
(
0
)
opsErrorLogSanitized
.
Store
(
0
)
opsErrorLogLastDropLogAt
.
Store
(
0
)
opsErrorLogShutdownCh
=
make
(
chan
struct
{})
opsErrorLogShutdownOnce
=
sync
.
Once
{}
opsErrorLogDrained
.
Store
(
false
)
}
func
TestAttachOpsRequestBodyToEntry_SanitizeAndTrim
(
t
*
testing
.
T
)
{
resetOpsErrorLoggerStateForTest
(
t
)
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
raw
:=
[]
byte
(
`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`
)
setOpsRequestContext
(
c
,
"claude-3"
,
false
,
raw
)
entry
:=
&
service
.
OpsInsertErrorLogInput
{}
attachOpsRequestBodyToEntry
(
c
,
entry
)
require
.
NotNil
(
t
,
entry
.
RequestBodyBytes
)
require
.
Equal
(
t
,
len
(
raw
),
*
entry
.
RequestBodyBytes
)
require
.
NotNil
(
t
,
entry
.
RequestBodyJSON
)
require
.
NotContains
(
t
,
*
entry
.
RequestBodyJSON
,
"secret-token"
)
require
.
Contains
(
t
,
*
entry
.
RequestBodyJSON
,
"[REDACTED]"
)
require
.
Equal
(
t
,
int64
(
1
),
OpsErrorLogSanitizedTotal
())
}
func
TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize
(
t
*
testing
.
T
)
{
resetOpsErrorLoggerStateForTest
(
t
)
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
raw
:=
[]
byte
(
"not-json"
)
setOpsRequestContext
(
c
,
"claude-3"
,
false
,
raw
)
entry
:=
&
service
.
OpsInsertErrorLogInput
{}
attachOpsRequestBodyToEntry
(
c
,
entry
)
require
.
Nil
(
t
,
entry
.
RequestBodyJSON
)
require
.
NotNil
(
t
,
entry
.
RequestBodyBytes
)
require
.
Equal
(
t
,
len
(
raw
),
*
entry
.
RequestBodyBytes
)
require
.
False
(
t
,
entry
.
RequestBodyTruncated
)
require
.
Equal
(
t
,
int64
(
1
),
OpsErrorLogSanitizedTotal
())
}
func
TestEnqueueOpsErrorLog_QueueFullDrop
(
t
*
testing
.
T
)
{
resetOpsErrorLoggerStateForTest
(
t
)
// 禁止 enqueueOpsErrorLog 触发 workers,使用测试队列验证满队列降级。
opsErrorLogOnce
.
Do
(
func
()
{})
opsErrorLogMu
.
Lock
()
opsErrorLogQueue
=
make
(
chan
opsErrorLogJob
,
1
)
opsErrorLogMu
.
Unlock
()
ops
:=
service
.
NewOpsService
(
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
entry
:=
&
service
.
OpsInsertErrorLogInput
{
ErrorPhase
:
"upstream"
,
ErrorType
:
"upstream_error"
}
enqueueOpsErrorLog
(
ops
,
entry
)
enqueueOpsErrorLog
(
ops
,
entry
)
require
.
Equal
(
t
,
int64
(
1
),
OpsErrorLogEnqueuedTotal
())
require
.
Equal
(
t
,
int64
(
1
),
OpsErrorLogDroppedTotal
())
require
.
Equal
(
t
,
int64
(
1
),
OpsErrorLogQueueLength
())
}
func
TestAttachOpsRequestBodyToEntry_EarlyReturnBranches
(
t
*
testing
.
T
)
{
resetOpsErrorLoggerStateForTest
(
t
)
gin
.
SetMode
(
gin
.
TestMode
)
entry
:=
&
service
.
OpsInsertErrorLogInput
{}
attachOpsRequestBodyToEntry
(
nil
,
entry
)
attachOpsRequestBodyToEntry
(
&
gin
.
Context
{},
nil
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
// 无请求体 key
attachOpsRequestBodyToEntry
(
c
,
entry
)
require
.
Nil
(
t
,
entry
.
RequestBodyJSON
)
require
.
Nil
(
t
,
entry
.
RequestBodyBytes
)
require
.
False
(
t
,
entry
.
RequestBodyTruncated
)
// 错误类型
c
.
Set
(
opsRequestBodyKey
,
"not-bytes"
)
attachOpsRequestBodyToEntry
(
c
,
entry
)
require
.
Nil
(
t
,
entry
.
RequestBodyJSON
)
require
.
Nil
(
t
,
entry
.
RequestBodyBytes
)
// 空 bytes
c
.
Set
(
opsRequestBodyKey
,
[]
byte
{})
attachOpsRequestBodyToEntry
(
c
,
entry
)
require
.
Nil
(
t
,
entry
.
RequestBodyJSON
)
require
.
Nil
(
t
,
entry
.
RequestBodyBytes
)
require
.
Equal
(
t
,
int64
(
0
),
OpsErrorLogSanitizedTotal
())
}
func
TestEnqueueOpsErrorLog_EarlyReturnBranches
(
t
*
testing
.
T
)
{
resetOpsErrorLoggerStateForTest
(
t
)
ops
:=
service
.
NewOpsService
(
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
entry
:=
&
service
.
OpsInsertErrorLogInput
{
ErrorPhase
:
"upstream"
,
ErrorType
:
"upstream_error"
}
// nil 入参分支
enqueueOpsErrorLog
(
nil
,
entry
)
enqueueOpsErrorLog
(
ops
,
nil
)
require
.
Equal
(
t
,
int64
(
0
),
OpsErrorLogEnqueuedTotal
())
// shutdown 分支
close
(
opsErrorLogShutdownCh
)
enqueueOpsErrorLog
(
ops
,
entry
)
require
.
Equal
(
t
,
int64
(
0
),
OpsErrorLogEnqueuedTotal
())
// stopping 分支
resetOpsErrorLoggerStateForTest
(
t
)
opsErrorLogMu
.
Lock
()
opsErrorLogStopping
=
true
opsErrorLogMu
.
Unlock
()
enqueueOpsErrorLog
(
ops
,
entry
)
require
.
Equal
(
t
,
int64
(
0
),
OpsErrorLogEnqueuedTotal
())
// queue nil 分支(防止启动 worker 干扰)
resetOpsErrorLoggerStateForTest
(
t
)
opsErrorLogOnce
.
Do
(
func
()
{})
opsErrorLogMu
.
Lock
()
opsErrorLogQueue
=
nil
opsErrorLogMu
.
Unlock
()
enqueueOpsErrorLog
(
ops
,
entry
)
require
.
Equal
(
t
,
int64
(
0
),
OpsErrorLogEnqueuedTotal
())
}
backend/internal/pkg/ctxkey/ctxkey.go
View file @
a89477dd
...
@@ -44,4 +44,8 @@ const (
...
@@ -44,4 +44,8 @@ const (
// SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。
// SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。
// 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。
// 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。
SingleAccountRetry
Key
=
"ctx_single_account_retry"
SingleAccountRetry
Key
=
"ctx_single_account_retry"
// PrefetchedStickyAccountID 标识上游(通常 handler)预取到的 sticky session 账号 ID。
// Service 层可复用该值,避免同请求链路重复读取 Redis。
PrefetchedStickyAccountID
Key
=
"ctx_prefetched_sticky_account_id"
)
)
backend/internal/repository/usage_log_repo.go
View file @
a89477dd
...
@@ -915,6 +915,59 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
...
@@ -915,6 +915,59 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
return
stats
,
nil
return
stats
,
nil
}
}
// GetAccountWindowStatsBatch 批量获取同一窗口起点下多个账号的统计数据。
// 返回 map[accountID]*AccountStats,未命中的账号会返回零值统计,便于上层直接复用。
func
(
r
*
usageLogRepository
)
GetAccountWindowStatsBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
startTime
time
.
Time
)
(
map
[
int64
]
*
usagestats
.
AccountStats
,
error
)
{
result
:=
make
(
map
[
int64
]
*
usagestats
.
AccountStats
,
len
(
accountIDs
))
if
len
(
accountIDs
)
==
0
{
return
result
,
nil
}
query
:=
`
SELECT
account_id,
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = ANY($1) AND created_at >= $2
GROUP BY account_id
`
rows
,
err
:=
r
.
sql
.
QueryContext
(
ctx
,
query
,
pq
.
Array
(
accountIDs
),
startTime
)
if
err
!=
nil
{
return
nil
,
err
}
defer
func
()
{
_
=
rows
.
Close
()
}()
for
rows
.
Next
()
{
var
accountID
int64
stats
:=
&
usagestats
.
AccountStats
{}
if
err
:=
rows
.
Scan
(
&
accountID
,
&
stats
.
Requests
,
&
stats
.
Tokens
,
&
stats
.
Cost
,
&
stats
.
StandardCost
,
&
stats
.
UserCost
,
);
err
!=
nil
{
return
nil
,
err
}
result
[
accountID
]
=
stats
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
for
_
,
accountID
:=
range
accountIDs
{
if
_
,
ok
:=
result
[
accountID
];
!
ok
{
result
[
accountID
]
=
&
usagestats
.
AccountStats
{}
}
}
return
result
,
nil
}
// TrendDataPoint represents a single point in trend data
// TrendDataPoint represents a single point in trend data
type
TrendDataPoint
=
usagestats
.
TrendDataPoint
type
TrendDataPoint
=
usagestats
.
TrendDataPoint
...
...
backend/internal/service/gateway_hotpath_optimization_test.go
0 → 100644
View file @
a89477dd
This diff is collapsed.
Click to expand it.
backend/internal/service/gateway_service.go
View file @
a89477dd
...
@@ -24,12 +24,15 @@ import (
...
@@ -24,12 +24,15 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/cespare/xxhash/v2"
"github.com/cespare/xxhash/v2"
"github.com/google/uuid"
"github.com/google/uuid"
gocache
"github.com/patrickmn/go-cache"
"github.com/tidwall/gjson"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/tidwall/sjson"
"golang.org/x/sync/singleflight"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
)
)
...
@@ -44,6 +47,9 @@ const (
...
@@ -44,6 +47,9 @@ const (
// separator between system blocks, we add "\n\n" at concatenation time.
// separator between system blocks, we add "\n\n" at concatenation time.
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
defaultUserGroupRateCacheTTL
=
30
*
time
.
Second
defaultModelsListCacheTTL
=
15
*
time
.
Second
)
)
const
(
const
(
...
@@ -62,6 +68,53 @@ type accountWithLoad struct {
...
@@ -62,6 +68,53 @@ type accountWithLoad struct {
var
ForceCacheBillingContextKey
=
forceCacheBillingKeyType
{}
var
ForceCacheBillingContextKey
=
forceCacheBillingKeyType
{}
// Process-wide counters for the gateway hot-path caches. They are only ever
// incremented; the Gateway*Stats functions expose their current values.
var (
	// Window-cost prefetch (see withWindowCostPrefetch).
	windowCostPrefetchCacheHitTotal  atomic.Int64 // entries returned by the batch cache read
	windowCostPrefetchCacheMissTotal atomic.Int64 // candidate accounts not covered by the batch cache read
	windowCostPrefetchBatchSQLTotal  atomic.Int64 // batch database queries attempted
	windowCostPrefetchFallbackTotal  atomic.Int64 // accounts handled by the per-account fallback path
	windowCostPrefetchErrorTotal     atomic.Int64 // cache read, batch query, or per-account query failures

	// User/group rate multiplier cache (see getUserGroupRateMultiplier).
	userGroupRateCacheHitTotal      atomic.Int64 // lookups answered from the in-memory cache
	userGroupRateCacheMissTotal     atomic.Int64 // lookups that proceeded to the singleflight loader
	userGroupRateCacheLoadTotal     atomic.Int64 // repository loads actually executed
	userGroupRateCacheSFSharedTotal atomic.Int64 // singleflight results reported as shared
	userGroupRateCacheFallbackTotal atomic.Int64 // lookups that fell back to the group default

	// /v1/models list cache (see GetAvailableModels).
	modelsListCacheHitTotal   atomic.Int64 // lookups answered from the cache
	modelsListCacheMissTotal  atomic.Int64 // lookups that had to rebuild the list
	modelsListCacheStoreTotal atomic.Int64 // results written into the cache
)
func
GatewayWindowCostPrefetchStats
()
(
cacheHit
,
cacheMiss
,
batchSQL
,
fallback
,
errCount
int64
)
{
return
windowCostPrefetchCacheHitTotal
.
Load
(),
windowCostPrefetchCacheMissTotal
.
Load
(),
windowCostPrefetchBatchSQLTotal
.
Load
(),
windowCostPrefetchFallbackTotal
.
Load
(),
windowCostPrefetchErrorTotal
.
Load
()
}
func
GatewayUserGroupRateCacheStats
()
(
cacheHit
,
cacheMiss
,
load
,
singleflightShared
,
fallback
int64
)
{
return
userGroupRateCacheHitTotal
.
Load
(),
userGroupRateCacheMissTotal
.
Load
(),
userGroupRateCacheLoadTotal
.
Load
(),
userGroupRateCacheSFSharedTotal
.
Load
(),
userGroupRateCacheFallbackTotal
.
Load
()
}
func
GatewayModelsListCacheStats
()
(
cacheHit
,
cacheMiss
,
store
int64
)
{
return
modelsListCacheHitTotal
.
Load
(),
modelsListCacheMissTotal
.
Load
(),
modelsListCacheStoreTotal
.
Load
()
}
// cloneStringSlice returns an independent copy of src so callers may mutate the
// result without affecting the original backing array. A nil or empty src
// yields nil.
func cloneStringSlice(src []string) []string {
	if len(src) == 0 {
		return nil
	}
	// Capacity is fixed to len(src) so the copy is exactly as large as needed.
	return append(make([]string, 0, len(src)), src...)
}
// IsForceCacheBilling 检查是否启用强制缓存计费
// IsForceCacheBilling 检查是否启用强制缓存计费
func
IsForceCacheBilling
(
ctx
context
.
Context
)
bool
{
func
IsForceCacheBilling
(
ctx
context
.
Context
)
bool
{
v
,
_
:=
ctx
.
Value
(
ForceCacheBillingContextKey
)
.
(
bool
)
v
,
_
:=
ctx
.
Value
(
ForceCacheBillingContextKey
)
.
(
bool
)
...
@@ -302,6 +355,42 @@ func derefGroupID(groupID *int64) int64 {
...
@@ -302,6 +355,42 @@ func derefGroupID(groupID *int64) int64 {
return
*
groupID
return
*
groupID
}
}
func
resolveUserGroupRateCacheTTL
(
cfg
*
config
.
Config
)
time
.
Duration
{
if
cfg
==
nil
||
cfg
.
Gateway
.
UserGroupRateCacheTTLSeconds
<=
0
{
return
defaultUserGroupRateCacheTTL
}
return
time
.
Duration
(
cfg
.
Gateway
.
UserGroupRateCacheTTLSeconds
)
*
time
.
Second
}
func
resolveModelsListCacheTTL
(
cfg
*
config
.
Config
)
time
.
Duration
{
if
cfg
==
nil
||
cfg
.
Gateway
.
ModelsListCacheTTLSeconds
<=
0
{
return
defaultModelsListCacheTTL
}
return
time
.
Duration
(
cfg
.
Gateway
.
ModelsListCacheTTLSeconds
)
*
time
.
Second
}
func
modelsListCacheKey
(
groupID
*
int64
,
platform
string
)
string
{
return
fmt
.
Sprintf
(
"%d|%s"
,
derefGroupID
(
groupID
),
strings
.
TrimSpace
(
platform
))
}
func
prefetchedStickyAccountIDFromContext
(
ctx
context
.
Context
)
int64
{
if
ctx
==
nil
{
return
0
}
v
:=
ctx
.
Value
(
ctxkey
.
PrefetchedStickyAccountID
)
switch
t
:=
v
.
(
type
)
{
case
int64
:
if
t
>
0
{
return
t
}
case
int
:
if
t
>
0
{
return
int64
(
t
)
}
}
return
0
}
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
// 或请求的模型处于限流状态时,返回 true。
// 或请求的模型处于限流状态时,返回 true。
...
@@ -421,6 +510,10 @@ type GatewayService struct {
...
@@ -421,6 +510,10 @@ type GatewayService struct {
concurrencyService
*
ConcurrencyService
concurrencyService
*
ConcurrencyService
claudeTokenProvider
*
ClaudeTokenProvider
claudeTokenProvider
*
ClaudeTokenProvider
sessionLimitCache
SessionLimitCache
// 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
sessionLimitCache
SessionLimitCache
// 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
userGroupRateCache
*
gocache
.
Cache
userGroupRateSF
singleflight
.
Group
modelsListCache
*
gocache
.
Cache
modelsListCacheTTL
time
.
Duration
}
}
// NewGatewayService creates a new GatewayService
// NewGatewayService creates a new GatewayService
...
@@ -445,6 +538,9 @@ func NewGatewayService(
...
@@ -445,6 +538,9 @@ func NewGatewayService(
sessionLimitCache
SessionLimitCache
,
sessionLimitCache
SessionLimitCache
,
digestStore
*
DigestSessionStore
,
digestStore
*
DigestSessionStore
,
)
*
GatewayService
{
)
*
GatewayService
{
userGroupRateTTL
:=
resolveUserGroupRateCacheTTL
(
cfg
)
modelsListTTL
:=
resolveModelsListCacheTTL
(
cfg
)
return
&
GatewayService
{
return
&
GatewayService
{
accountRepo
:
accountRepo
,
accountRepo
:
accountRepo
,
groupRepo
:
groupRepo
,
groupRepo
:
groupRepo
,
...
@@ -465,6 +561,9 @@ func NewGatewayService(
...
@@ -465,6 +561,9 @@ func NewGatewayService(
deferredService
:
deferredService
,
deferredService
:
deferredService
,
claudeTokenProvider
:
claudeTokenProvider
,
claudeTokenProvider
:
claudeTokenProvider
,
sessionLimitCache
:
sessionLimitCache
,
sessionLimitCache
:
sessionLimitCache
,
userGroupRateCache
:
gocache
.
New
(
userGroupRateTTL
,
time
.
Minute
),
modelsListCache
:
gocache
.
New
(
modelsListTTL
,
time
.
Minute
),
modelsListCacheTTL
:
modelsListTTL
,
}
}
}
}
...
@@ -937,7 +1036,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -937,7 +1036,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
cfg
:=
s
.
schedulingConfig
()
cfg
:=
s
.
schedulingConfig
()
var
stickyAccountID
int64
var
stickyAccountID
int64
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
prefetch
:=
prefetchedStickyAccountIDFromContext
(
ctx
);
prefetch
>
0
{
stickyAccountID
=
prefetch
}
else
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
);
err
==
nil
{
if
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
);
err
==
nil
{
stickyAccountID
=
accountID
stickyAccountID
=
accountID
}
}
...
@@ -1035,6 +1136,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1035,6 +1136,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
len
(
accounts
)
==
0
{
if
len
(
accounts
)
==
0
{
return
nil
,
errors
.
New
(
"no available accounts"
)
return
nil
,
errors
.
New
(
"no available accounts"
)
}
}
ctx
=
s
.
withWindowCostPrefetch
(
ctx
,
accounts
)
isExcluded
:=
func
(
accountID
int64
)
bool
{
isExcluded
:=
func
(
accountID
int64
)
bool
{
if
excludedIDs
==
nil
{
if
excludedIDs
==
nil
{
...
@@ -1125,9 +1227,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1125,9 +1227,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
len
(
routingCandidates
)
>
0
{
if
len
(
routingCandidates
)
>
0
{
// 1.5. 在路由账号范围内检查粘性会话
// 1.5. 在路由账号范围内检查粘性会话
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
sessionHash
!=
""
&&
stickyAccountID
>
0
{
stickyAccountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
containsInt64
(
routingAccountIDs
,
stickyAccountID
)
&&
!
isExcluded
(
stickyAccountID
)
{
if
err
==
nil
&&
stickyAccountID
>
0
&&
containsInt64
(
routingAccountIDs
,
stickyAccountID
)
&&
!
isExcluded
(
stickyAccountID
)
{
// 粘性账号在路由列表中,优先使用
// 粘性账号在路由列表中,优先使用
if
stickyAccount
,
ok
:=
accountByID
[
stickyAccountID
];
ok
{
if
stickyAccount
,
ok
:=
accountByID
[
stickyAccountID
];
ok
{
if
stickyAccount
.
IsSchedulable
()
&&
if
stickyAccount
.
IsSchedulable
()
&&
...
@@ -1273,9 +1374,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1273,9 +1374,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
if
len
(
routingAccountIDs
)
==
0
&&
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
len
(
routingAccountIDs
)
==
0
&&
sessionHash
!=
""
&&
s
tickyAccountID
>
0
&&
!
isExcluded
(
stickyAccountID
)
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
accountID
:=
stickyAccountID
if
err
==
nil
&&
accountID
>
0
&&
!
isExcluded
(
accountID
)
{
if
accountID
>
0
&&
!
isExcluded
(
accountID
)
{
account
,
ok
:=
accountByID
[
accountID
]
account
,
ok
:=
accountByID
[
accountID
]
if
ok
{
if
ok
{
// 检查账户是否需要清理粘性会话绑定
// 检查账户是否需要清理粘性会话绑定
...
@@ -1760,6 +1861,129 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
...
@@ -1760,6 +1861,129 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
return
s
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
accountID
,
maxConcurrency
)
return
s
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
accountID
,
maxConcurrency
)
}
}
// usageLogWindowStatsBatchProvider is an optional capability of the usage-log
// repository: fetching window statistics for many accounts in one call.
// withWindowCostPrefetch detects it with a type assertion and falls back to
// per-account queries when the repository does not implement it.
type usageLogWindowStatsBatchProvider interface {
	// GetAccountWindowStatsBatch returns per-account statistics, keyed by
	// account ID, for usage recorded at or after startTime.
	GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error)
}
// windowCostPrefetchContextKeyType is an unexported struct type used as a
// context key so the prefetched window-cost map cannot collide with keys
// defined by other packages.
type windowCostPrefetchContextKeyType struct{}

// windowCostPrefetchContextKey is the context key under which
// withWindowCostPrefetch stores a map[int64]float64 of account ID to window cost.
var windowCostPrefetchContextKey = windowCostPrefetchContextKeyType{}
func
windowCostFromPrefetchContext
(
ctx
context
.
Context
,
accountID
int64
)
(
float64
,
bool
)
{
if
ctx
==
nil
||
accountID
<=
0
{
return
0
,
false
}
m
,
ok
:=
ctx
.
Value
(
windowCostPrefetchContextKey
)
.
(
map
[
int64
]
float64
)
if
!
ok
||
len
(
m
)
==
0
{
return
0
,
false
}
v
,
exists
:=
m
[
accountID
]
return
v
,
exists
}
func
(
s
*
GatewayService
)
withWindowCostPrefetch
(
ctx
context
.
Context
,
accounts
[]
Account
)
context
.
Context
{
if
ctx
==
nil
||
len
(
accounts
)
==
0
||
s
.
sessionLimitCache
==
nil
||
s
.
usageLogRepo
==
nil
{
return
ctx
}
accountByID
:=
make
(
map
[
int64
]
*
Account
)
accountIDs
:=
make
([]
int64
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
account
:=
&
accounts
[
i
]
if
account
==
nil
||
!
account
.
IsAnthropicOAuthOrSetupToken
()
{
continue
}
if
account
.
GetWindowCostLimit
()
<=
0
{
continue
}
accountByID
[
account
.
ID
]
=
account
accountIDs
=
append
(
accountIDs
,
account
.
ID
)
}
if
len
(
accountIDs
)
==
0
{
return
ctx
}
costs
:=
make
(
map
[
int64
]
float64
,
len
(
accountIDs
))
cacheValues
,
err
:=
s
.
sessionLimitCache
.
GetWindowCostBatch
(
ctx
,
accountIDs
)
if
err
==
nil
{
for
accountID
,
cost
:=
range
cacheValues
{
costs
[
accountID
]
=
cost
}
windowCostPrefetchCacheHitTotal
.
Add
(
int64
(
len
(
cacheValues
)))
}
else
{
windowCostPrefetchErrorTotal
.
Add
(
1
)
logger
.
LegacyPrintf
(
"service.gateway"
,
"window_cost batch cache read failed: %v"
,
err
)
}
cacheMissCount
:=
len
(
accountIDs
)
-
len
(
costs
)
if
cacheMissCount
<
0
{
cacheMissCount
=
0
}
windowCostPrefetchCacheMissTotal
.
Add
(
int64
(
cacheMissCount
))
missingByStart
:=
make
(
map
[
int64
][]
int64
)
startTimes
:=
make
(
map
[
int64
]
time
.
Time
)
for
_
,
accountID
:=
range
accountIDs
{
if
_
,
ok
:=
costs
[
accountID
];
ok
{
continue
}
account
:=
accountByID
[
accountID
]
if
account
==
nil
{
continue
}
startTime
:=
account
.
GetCurrentWindowStartTime
()
startKey
:=
startTime
.
Unix
()
missingByStart
[
startKey
]
=
append
(
missingByStart
[
startKey
],
accountID
)
startTimes
[
startKey
]
=
startTime
}
if
len
(
missingByStart
)
==
0
{
return
context
.
WithValue
(
ctx
,
windowCostPrefetchContextKey
,
costs
)
}
batchReader
,
hasBatch
:=
s
.
usageLogRepo
.
(
usageLogWindowStatsBatchProvider
)
for
startKey
,
ids
:=
range
missingByStart
{
startTime
:=
startTimes
[
startKey
]
if
hasBatch
{
windowCostPrefetchBatchSQLTotal
.
Add
(
1
)
queryStart
:=
time
.
Now
()
statsByAccount
,
err
:=
batchReader
.
GetAccountWindowStatsBatch
(
ctx
,
ids
,
startTime
)
if
err
==
nil
{
slog
.
Debug
(
"window_cost_batch_query_ok"
,
"accounts"
,
len
(
ids
),
"window_start"
,
startTime
.
Format
(
time
.
RFC3339
),
"duration_ms"
,
time
.
Since
(
queryStart
)
.
Milliseconds
())
for
_
,
accountID
:=
range
ids
{
stats
:=
statsByAccount
[
accountID
]
cost
:=
0.0
if
stats
!=
nil
{
cost
=
stats
.
StandardCost
}
costs
[
accountID
]
=
cost
_
=
s
.
sessionLimitCache
.
SetWindowCost
(
ctx
,
accountID
,
cost
)
}
continue
}
windowCostPrefetchErrorTotal
.
Add
(
1
)
logger
.
LegacyPrintf
(
"service.gateway"
,
"window_cost batch db query failed: start=%s err=%v"
,
startTime
.
Format
(
time
.
RFC3339
),
err
)
}
// 回退路径:缺少批量仓储能力或批量查询失败时,按账号单查(失败开放)。
windowCostPrefetchFallbackTotal
.
Add
(
int64
(
len
(
ids
)))
for
_
,
accountID
:=
range
ids
{
stats
,
err
:=
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
accountID
,
startTime
)
if
err
!=
nil
{
windowCostPrefetchErrorTotal
.
Add
(
1
)
continue
}
cost
:=
stats
.
StandardCost
costs
[
accountID
]
=
cost
_
=
s
.
sessionLimitCache
.
SetWindowCost
(
ctx
,
accountID
,
cost
)
}
}
return
context
.
WithValue
(
ctx
,
windowCostPrefetchContextKey
,
costs
)
}
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示可调度,false 表示不可调度
// 返回 true 表示可调度,false 表示不可调度
...
@@ -1776,6 +2000,10 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context,
...
@@ -1776,6 +2000,10 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context,
// 尝试从缓存获取窗口费用
// 尝试从缓存获取窗口费用
var
currentCost
float64
var
currentCost
float64
if
cost
,
ok
:=
windowCostFromPrefetchContext
(
ctx
,
account
.
ID
);
ok
{
currentCost
=
cost
goto
checkSchedulability
}
if
s
.
sessionLimitCache
!=
nil
{
if
s
.
sessionLimitCache
!=
nil
{
if
cost
,
hit
,
err
:=
s
.
sessionLimitCache
.
GetWindowCost
(
ctx
,
account
.
ID
);
err
==
nil
&&
hit
{
if
cost
,
hit
,
err
:=
s
.
sessionLimitCache
.
GetWindowCost
(
ctx
,
account
.
ID
);
err
==
nil
&&
hit
{
currentCost
=
cost
currentCost
=
cost
...
@@ -5264,6 +5492,66 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
...
@@ -5264,6 +5492,66 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
return
body
return
body
}
}
func
(
s
*
GatewayService
)
getUserGroupRateMultiplier
(
ctx
context
.
Context
,
userID
,
groupID
int64
,
groupDefaultMultiplier
float64
)
float64
{
if
s
==
nil
||
userID
<=
0
||
groupID
<=
0
{
return
groupDefaultMultiplier
}
key
:=
fmt
.
Sprintf
(
"%d:%d"
,
userID
,
groupID
)
if
s
.
userGroupRateCache
!=
nil
{
if
cached
,
ok
:=
s
.
userGroupRateCache
.
Get
(
key
);
ok
{
if
multiplier
,
castOK
:=
cached
.
(
float64
);
castOK
{
userGroupRateCacheHitTotal
.
Add
(
1
)
return
multiplier
}
}
}
if
s
.
userGroupRateRepo
==
nil
{
return
groupDefaultMultiplier
}
userGroupRateCacheMissTotal
.
Add
(
1
)
value
,
err
,
shared
:=
s
.
userGroupRateSF
.
Do
(
key
,
func
()
(
any
,
error
)
{
if
s
.
userGroupRateCache
!=
nil
{
if
cached
,
ok
:=
s
.
userGroupRateCache
.
Get
(
key
);
ok
{
if
multiplier
,
castOK
:=
cached
.
(
float64
);
castOK
{
userGroupRateCacheHitTotal
.
Add
(
1
)
return
multiplier
,
nil
}
}
}
userGroupRateCacheLoadTotal
.
Add
(
1
)
userRate
,
repoErr
:=
s
.
userGroupRateRepo
.
GetByUserAndGroup
(
ctx
,
userID
,
groupID
)
if
repoErr
!=
nil
{
return
nil
,
repoErr
}
multiplier
:=
groupDefaultMultiplier
if
userRate
!=
nil
{
multiplier
=
*
userRate
}
if
s
.
userGroupRateCache
!=
nil
{
s
.
userGroupRateCache
.
Set
(
key
,
multiplier
,
resolveUserGroupRateCacheTTL
(
s
.
cfg
))
}
return
multiplier
,
nil
})
if
shared
{
userGroupRateCacheSFSharedTotal
.
Add
(
1
)
}
if
err
!=
nil
{
userGroupRateCacheFallbackTotal
.
Add
(
1
)
logger
.
LegacyPrintf
(
"service.gateway"
,
"get user group rate failed, fallback to group default: user=%d group=%d err=%v"
,
userID
,
groupID
,
err
)
return
groupDefaultMultiplier
}
multiplier
,
ok
:=
value
.
(
float64
)
if
!
ok
{
userGroupRateCacheFallbackTotal
.
Add
(
1
)
return
groupDefaultMultiplier
}
return
multiplier
}
// RecordUsageInput 记录使用量的输入参数
// RecordUsageInput 记录使用量的输入参数
type
RecordUsageInput
struct
{
type
RecordUsageInput
struct
{
Result
*
ForwardResult
Result
*
ForwardResult
...
@@ -5307,16 +5595,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
...
@@ -5307,16 +5595,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier
:=
s
.
cfg
.
Default
.
RateMultiplier
multiplier
:=
1.0
if
s
.
cfg
!=
nil
{
multiplier
=
s
.
cfg
.
Default
.
RateMultiplier
}
if
apiKey
.
GroupID
!=
nil
&&
apiKey
.
Group
!=
nil
{
if
apiKey
.
GroupID
!=
nil
&&
apiKey
.
Group
!=
nil
{
multiplier
=
apiKey
.
Group
.
RateMultiplier
groupDefault
:=
apiKey
.
Group
.
RateMultiplier
multiplier
=
s
.
getUserGroupRateMultiplier
(
ctx
,
user
.
ID
,
*
apiKey
.
GroupID
,
groupDefault
)
// 检查用户专属倍率
if
s
.
userGroupRateRepo
!=
nil
{
if
userRate
,
err
:=
s
.
userGroupRateRepo
.
GetByUserAndGroup
(
ctx
,
user
.
ID
,
*
apiKey
.
GroupID
);
err
==
nil
&&
userRate
!=
nil
{
multiplier
=
*
userRate
}
}
}
}
var
cost
*
CostBreakdown
var
cost
*
CostBreakdown
...
@@ -5522,16 +5807,13 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
...
@@ -5522,16 +5807,13 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
}
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier
:=
s
.
cfg
.
Default
.
RateMultiplier
multiplier
:=
1.0
if
s
.
cfg
!=
nil
{
multiplier
=
s
.
cfg
.
Default
.
RateMultiplier
}
if
apiKey
.
GroupID
!=
nil
&&
apiKey
.
Group
!=
nil
{
if
apiKey
.
GroupID
!=
nil
&&
apiKey
.
Group
!=
nil
{
multiplier
=
apiKey
.
Group
.
RateMultiplier
groupDefault
:=
apiKey
.
Group
.
RateMultiplier
multiplier
=
s
.
getUserGroupRateMultiplier
(
ctx
,
user
.
ID
,
*
apiKey
.
GroupID
,
groupDefault
)
// 检查用户专属倍率
if
s
.
userGroupRateRepo
!=
nil
{
if
userRate
,
err
:=
s
.
userGroupRateRepo
.
GetByUserAndGroup
(
ctx
,
user
.
ID
,
*
apiKey
.
GroupID
);
err
==
nil
&&
userRate
!=
nil
{
multiplier
=
*
userRate
}
}
}
}
var
cost
*
CostBreakdown
var
cost
*
CostBreakdown
...
@@ -6145,6 +6427,17 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
...
@@ -6145,6 +6427,17 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
// GetAvailableModels returns the list of models available for a group
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
// It aggregates model_mapping keys from all schedulable accounts in the group
func
(
s
*
GatewayService
)
GetAvailableModels
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
)
[]
string
{
func
(
s
*
GatewayService
)
GetAvailableModels
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
)
[]
string
{
cacheKey
:=
modelsListCacheKey
(
groupID
,
platform
)
if
s
.
modelsListCache
!=
nil
{
if
cached
,
found
:=
s
.
modelsListCache
.
Get
(
cacheKey
);
found
{
if
models
,
ok
:=
cached
.
([]
string
);
ok
{
modelsListCacheHitTotal
.
Add
(
1
)
return
cloneStringSlice
(
models
)
}
}
}
modelsListCacheMissTotal
.
Add
(
1
)
var
accounts
[]
Account
var
accounts
[]
Account
var
err
error
var
err
error
...
@@ -6185,6 +6478,10 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64,
...
@@ -6185,6 +6478,10 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64,
// If no account has model_mapping, return nil (use default)
// If no account has model_mapping, return nil (use default)
if
!
hasAnyMapping
{
if
!
hasAnyMapping
{
if
s
.
modelsListCache
!=
nil
{
s
.
modelsListCache
.
Set
(
cacheKey
,
[]
string
(
nil
),
s
.
modelsListCacheTTL
)
modelsListCacheStoreTotal
.
Add
(
1
)
}
return
nil
return
nil
}
}
...
@@ -6193,8 +6490,45 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64,
...
@@ -6193,8 +6490,45 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64,
for
model
:=
range
modelSet
{
for
model
:=
range
modelSet
{
models
=
append
(
models
,
model
)
models
=
append
(
models
,
model
)
}
}
sort
.
Strings
(
models
)
return
models
if
s
.
modelsListCache
!=
nil
{
s
.
modelsListCache
.
Set
(
cacheKey
,
cloneStringSlice
(
models
),
s
.
modelsListCacheTTL
)
modelsListCacheStoreTotal
.
Add
(
1
)
}
return
cloneStringSlice
(
models
)
}
func
(
s
*
GatewayService
)
InvalidateAvailableModelsCache
(
groupID
*
int64
,
platform
string
)
{
if
s
==
nil
||
s
.
modelsListCache
==
nil
{
return
}
normalizedPlatform
:=
strings
.
TrimSpace
(
platform
)
// 完整匹配时精准失效;否则按维度批量失效。
if
groupID
!=
nil
&&
normalizedPlatform
!=
""
{
s
.
modelsListCache
.
Delete
(
modelsListCacheKey
(
groupID
,
normalizedPlatform
))
return
}
targetGroup
:=
derefGroupID
(
groupID
)
for
key
:=
range
s
.
modelsListCache
.
Items
()
{
parts
:=
strings
.
SplitN
(
key
,
"|"
,
2
)
if
len
(
parts
)
!=
2
{
continue
}
groupPart
,
parseErr
:=
strconv
.
ParseInt
(
parts
[
0
],
10
,
64
)
if
parseErr
!=
nil
{
continue
}
if
groupID
!=
nil
&&
groupPart
!=
targetGroup
{
continue
}
if
normalizedPlatform
!=
""
&&
parts
[
1
]
!=
normalizedPlatform
{
continue
}
s
.
modelsListCache
.
Delete
(
key
)
}
}
}
// reconcileCachedTokens 兼容 Kimi 等上游:
// reconcileCachedTokens 兼容 Kimi 等上游:
...
...
backend/internal/service/ops_service.go
View file @
a89477dd
...
@@ -20,6 +20,22 @@ const (
...
@@ -20,6 +20,22 @@ const (
opsMaxStoredErrorBodyBytes
=
20
*
1024
opsMaxStoredErrorBodyBytes
=
20
*
1024
)
)
// PrepareOpsRequestBodyForQueue 在入队前对请求体执行脱敏与裁剪,返回可直接写入 OpsInsertErrorLogInput 的字段。
// 该方法用于避免异步队列持有大块原始请求体,减少错误风暴下的内存放大风险。
func
PrepareOpsRequestBodyForQueue
(
raw
[]
byte
)
(
requestBodyJSON
*
string
,
truncated
bool
,
requestBodyBytes
*
int
)
{
if
len
(
raw
)
==
0
{
return
nil
,
false
,
nil
}
sanitized
,
truncated
,
bytesLen
:=
sanitizeAndTrimRequestBody
(
raw
,
opsMaxStoredRequestBodyBytes
)
if
sanitized
!=
""
{
out
:=
sanitized
requestBodyJSON
=
&
out
}
n
:=
bytesLen
requestBodyBytes
=
&
n
return
requestBodyJSON
,
truncated
,
requestBodyBytes
}
// OpsService provides ingestion and query APIs for the Ops monitoring module.
// OpsService provides ingestion and query APIs for the Ops monitoring module.
type
OpsService
struct
{
type
OpsService
struct
{
opsRepo
OpsRepository
opsRepo
OpsRepository
...
@@ -132,12 +148,7 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn
...
@@ -132,12 +148,7 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn
// Sanitize + trim request body (errors only).
// Sanitize + trim request body (errors only).
if
len
(
rawRequestBody
)
>
0
{
if
len
(
rawRequestBody
)
>
0
{
sanitized
,
truncated
,
bytesLen
:=
sanitizeAndTrimRequestBody
(
rawRequestBody
,
opsMaxStoredRequestBodyBytes
)
entry
.
RequestBodyJSON
,
entry
.
RequestBodyTruncated
,
entry
.
RequestBodyBytes
=
PrepareOpsRequestBodyForQueue
(
rawRequestBody
)
if
sanitized
!=
""
{
entry
.
RequestBodyJSON
=
&
sanitized
}
entry
.
RequestBodyTruncated
=
truncated
entry
.
RequestBodyBytes
=
&
bytesLen
}
}
// Sanitize + truncate error_body to avoid storing sensitive data.
// Sanitize + truncate error_body to avoid storing sensitive data.
...
...
backend/internal/service/ops_service_prepare_queue_test.go
0 → 100644
View file @
a89477dd
package
service
import
(
"encoding/json"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func
TestPrepareOpsRequestBodyForQueue_EmptyBody
(
t
*
testing
.
T
)
{
requestBodyJSON
,
truncated
,
requestBodyBytes
:=
PrepareOpsRequestBodyForQueue
(
nil
)
require
.
Nil
(
t
,
requestBodyJSON
)
require
.
False
(
t
,
truncated
)
require
.
Nil
(
t
,
requestBodyBytes
)
}
func
TestPrepareOpsRequestBodyForQueue_InvalidJSON
(
t
*
testing
.
T
)
{
raw
:=
[]
byte
(
"{invalid-json"
)
requestBodyJSON
,
truncated
,
requestBodyBytes
:=
PrepareOpsRequestBodyForQueue
(
raw
)
require
.
Nil
(
t
,
requestBodyJSON
)
require
.
False
(
t
,
truncated
)
require
.
NotNil
(
t
,
requestBodyBytes
)
require
.
Equal
(
t
,
len
(
raw
),
*
requestBodyBytes
)
}
func
TestPrepareOpsRequestBodyForQueue_RedactSensitiveFields
(
t
*
testing
.
T
)
{
raw
:=
[]
byte
(
`{
"model":"claude-3-5-sonnet-20241022",
"api_key":"sk-test-123",
"headers":{"authorization":"Bearer secret-token"},
"messages":[{"role":"user","content":"hello"}]
}`
)
requestBodyJSON
,
truncated
,
requestBodyBytes
:=
PrepareOpsRequestBodyForQueue
(
raw
)
require
.
NotNil
(
t
,
requestBodyJSON
)
require
.
NotNil
(
t
,
requestBodyBytes
)
require
.
False
(
t
,
truncated
)
require
.
Equal
(
t
,
len
(
raw
),
*
requestBodyBytes
)
var
body
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
*
requestBodyJSON
),
&
body
))
require
.
Equal
(
t
,
"[REDACTED]"
,
body
[
"api_key"
])
headers
,
ok
:=
body
[
"headers"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"[REDACTED]"
,
headers
[
"authorization"
])
}
func
TestPrepareOpsRequestBodyForQueue_LargeBodyTruncated
(
t
*
testing
.
T
)
{
largeMsg
:=
strings
.
Repeat
(
"x"
,
opsMaxStoredRequestBodyBytes
*
2
)
raw
:=
[]
byte
(
`{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"`
+
largeMsg
+
`"}]}`
)
requestBodyJSON
,
truncated
,
requestBodyBytes
:=
PrepareOpsRequestBodyForQueue
(
raw
)
require
.
NotNil
(
t
,
requestBodyJSON
)
require
.
NotNil
(
t
,
requestBodyBytes
)
require
.
True
(
t
,
truncated
)
require
.
Equal
(
t
,
len
(
raw
),
*
requestBodyBytes
)
require
.
LessOrEqual
(
t
,
len
(
*
requestBodyJSON
),
opsMaxStoredRequestBodyBytes
)
require
.
Contains
(
t
,
*
requestBodyJSON
,
"request_body_truncated"
)
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment