Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
2ee6c266
Commit
2ee6c266
authored
Feb 22, 2026
by
yangjianbo
Browse files
fix(gateway): 修复粘性会话预取分组错配并优化并发等待热路径
parent
a89477dd
Changes
7
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/gateway_handler.go
View file @
2ee6c266
...
@@ -244,7 +244,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -244,7 +244,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
sessionKey
!=
""
{
if
sessionKey
!=
""
{
sessionBoundAccountID
,
_
=
h
.
gatewayService
.
GetCachedSessionAccountID
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
)
sessionBoundAccountID
,
_
=
h
.
gatewayService
.
GetCachedSessionAccountID
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
)
if
sessionBoundAccountID
>
0
{
if
sessionBoundAccountID
>
0
{
prefetchedGroupID
:=
int64
(
0
)
if
apiKey
.
GroupID
!=
nil
{
prefetchedGroupID
=
*
apiKey
.
GroupID
}
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
PrefetchedStickyAccountID
,
sessionBoundAccountID
)
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
PrefetchedStickyAccountID
,
sessionBoundAccountID
)
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
PrefetchedStickyGroupID
,
prefetchedGroupID
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
}
}
}
...
...
backend/internal/handler/gateway_helper.go
View file @
2ee6c266
...
@@ -230,14 +230,31 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
...
@@ -230,14 +230,31 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func
(
h
*
ConcurrencyHelper
)
waitForSlotWithPing
(
c
*
gin
.
Context
,
slotType
string
,
id
int64
,
maxConcurrency
int
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
func
(
h
*
ConcurrencyHelper
)
waitForSlotWithPing
(
c
*
gin
.
Context
,
slotType
string
,
id
int64
,
maxConcurrency
int
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
return
h
.
waitForSlotWithPingTimeout
(
c
,
slotType
,
id
,
maxConcurrency
,
maxConcurrencyWait
,
isStream
,
streamStarted
)
return
h
.
waitForSlotWithPingTimeout
(
c
,
slotType
,
id
,
maxConcurrency
,
maxConcurrencyWait
,
isStream
,
streamStarted
,
false
)
}
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func
(
h
*
ConcurrencyHelper
)
waitForSlotWithPingTimeout
(
c
*
gin
.
Context
,
slotType
string
,
id
int64
,
maxConcurrency
int
,
timeout
time
.
Duration
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
func
(
h
*
ConcurrencyHelper
)
waitForSlotWithPingTimeout
(
c
*
gin
.
Context
,
slotType
string
,
id
int64
,
maxConcurrency
int
,
timeout
time
.
Duration
,
isStream
bool
,
streamStarted
*
bool
,
tryImmediate
bool
)
(
func
(),
error
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
c
.
Request
.
Context
(),
timeout
)
ctx
,
cancel
:=
context
.
WithTimeout
(
c
.
Request
.
Context
(),
timeout
)
defer
cancel
()
defer
cancel
()
acquireSlot
:=
func
()
(
*
service
.
AcquireResult
,
error
)
{
if
slotType
==
"user"
{
return
h
.
concurrencyService
.
AcquireUserSlot
(
ctx
,
id
,
maxConcurrency
)
}
return
h
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
id
,
maxConcurrency
)
}
if
tryImmediate
{
result
,
err
:=
acquireSlot
()
if
err
!=
nil
{
return
nil
,
err
}
if
result
.
Acquired
{
return
result
.
ReleaseFunc
,
nil
}
}
// Determine if ping is needed (streaming + ping format defined)
// Determine if ping is needed (streaming + ping format defined)
needPing
:=
isStream
&&
h
.
pingFormat
!=
""
needPing
:=
isStream
&&
h
.
pingFormat
!=
""
...
@@ -286,15 +303,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
...
@@ -286,15 +303,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
case
<-
timer
.
C
:
case
<-
timer
.
C
:
// Try to acquire slot
// Try to acquire slot
var
result
*
service
.
AcquireResult
result
,
err
:=
acquireSlot
()
var
err
error
if
slotType
==
"user"
{
result
,
err
=
h
.
concurrencyService
.
AcquireUserSlot
(
ctx
,
id
,
maxConcurrency
)
}
else
{
result
,
err
=
h
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
id
,
maxConcurrency
)
}
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
@@ -310,7 +319,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
...
@@ -310,7 +319,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func
(
h
*
ConcurrencyHelper
)
AcquireAccountSlotWithWaitTimeout
(
c
*
gin
.
Context
,
accountID
int64
,
maxConcurrency
int
,
timeout
time
.
Duration
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
func
(
h
*
ConcurrencyHelper
)
AcquireAccountSlotWithWaitTimeout
(
c
*
gin
.
Context
,
accountID
int64
,
maxConcurrency
int
,
timeout
time
.
Duration
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
return
h
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
accountID
,
maxConcurrency
,
timeout
,
isStream
,
streamStarted
)
return
h
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
accountID
,
maxConcurrency
,
timeout
,
isStream
,
streamStarted
,
true
)
}
}
// nextBackoff 计算下一次退避时间
// nextBackoff 计算下一次退避时间
...
...
backend/internal/handler/gateway_helper_hotpath_test.go
View file @
2ee6c266
...
@@ -176,7 +176,7 @@ func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
...
@@ -176,7 +176,7 @@ func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
t
.
Run
(
"account_slot_acquired_after_retry"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"account_slot_acquired_after_retry"
,
func
(
t
*
testing
.
T
)
{
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
streamStarted
:=
false
streamStarted
:=
false
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
101
,
2
,
time
.
Second
,
false
,
&
streamStarted
)
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
101
,
2
,
time
.
Second
,
false
,
&
streamStarted
,
true
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
release
)
require
.
NotNil
(
t
,
release
)
require
.
False
(
t
,
streamStarted
)
require
.
False
(
t
,
streamStarted
)
...
@@ -188,7 +188,7 @@ func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
...
@@ -188,7 +188,7 @@ func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
t
.
Run
(
"user_slot_acquired_after_retry"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"user_slot_acquired_after_retry"
,
func
(
t
*
testing
.
T
)
{
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
streamStarted
:=
false
streamStarted
:=
false
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"user"
,
202
,
3
,
time
.
Second
,
false
,
&
streamStarted
)
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"user"
,
202
,
3
,
time
.
Second
,
false
,
&
streamStarted
,
true
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
release
)
require
.
NotNil
(
t
,
release
)
release
()
release
()
...
@@ -207,7 +207,7 @@ func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
...
@@ -207,7 +207,7 @@ func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
helper
:=
NewConcurrencyHelper
(
concurrency
,
SSEPingFormatNone
,
5
*
time
.
Millisecond
)
helper
:=
NewConcurrencyHelper
(
concurrency
,
SSEPingFormatNone
,
5
*
time
.
Millisecond
)
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
streamStarted
:=
false
streamStarted
:=
false
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
101
,
2
,
130
*
time
.
Millisecond
,
false
,
&
streamStarted
)
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
101
,
2
,
130
*
time
.
Millisecond
,
false
,
&
streamStarted
,
true
)
require
.
Nil
(
t
,
release
)
require
.
Nil
(
t
,
release
)
var
cErr
*
ConcurrencyError
var
cErr
*
ConcurrencyError
require
.
ErrorAs
(
t
,
err
,
&
cErr
)
require
.
ErrorAs
(
t
,
err
,
&
cErr
)
...
@@ -218,7 +218,7 @@ func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
...
@@ -218,7 +218,7 @@ func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
helper
:=
NewConcurrencyHelper
(
concurrency
,
SSEPingFormatComment
,
10
*
time
.
Millisecond
)
helper
:=
NewConcurrencyHelper
(
concurrency
,
SSEPingFormatComment
,
10
*
time
.
Millisecond
)
c
,
rec
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
c
,
rec
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
streamStarted
:=
false
streamStarted
:=
false
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
101
,
2
,
70
*
time
.
Millisecond
,
true
,
&
streamStarted
)
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
101
,
2
,
70
*
time
.
Millisecond
,
true
,
&
streamStarted
,
true
)
require
.
Nil
(
t
,
release
)
require
.
Nil
(
t
,
release
)
var
cErr
*
ConcurrencyError
var
cErr
*
ConcurrencyError
require
.
ErrorAs
(
t
,
err
,
&
cErr
)
require
.
ErrorAs
(
t
,
err
,
&
cErr
)
...
@@ -236,12 +236,29 @@ func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
...
@@ -236,12 +236,29 @@ func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
helper
:=
NewConcurrencyHelper
(
concurrency
,
SSEPingFormatNone
,
5
*
time
.
Millisecond
)
helper
:=
NewConcurrencyHelper
(
concurrency
,
SSEPingFormatNone
,
5
*
time
.
Millisecond
)
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
streamStarted
:=
false
streamStarted
:=
false
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
1
,
1
,
200
*
time
.
Millisecond
,
false
,
&
streamStarted
)
release
,
err
:=
helper
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
1
,
1
,
200
*
time
.
Millisecond
,
false
,
&
streamStarted
,
true
)
require
.
Nil
(
t
,
release
)
require
.
Nil
(
t
,
release
)
require
.
Error
(
t
,
err
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"redis unavailable"
)
require
.
Contains
(
t
,
err
.
Error
(),
"redis unavailable"
)
}
}
func
TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff
(
t
*
testing
.
T
)
{
cache
:=
&
helperConcurrencyCacheStub
{
accountSeq
:
[]
bool
{
false
},
}
concurrency
:=
service
.
NewConcurrencyService
(
cache
)
helper
:=
NewConcurrencyHelper
(
concurrency
,
SSEPingFormatNone
,
5
*
time
.
Millisecond
)
c
,
_
:=
newHelperTestContext
(
http
.
MethodPost
,
"/v1/messages"
)
streamStarted
:=
false
release
,
err
:=
helper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
301
,
1
,
30
*
time
.
Millisecond
,
false
,
&
streamStarted
)
require
.
Nil
(
t
,
release
)
var
cErr
*
ConcurrencyError
require
.
ErrorAs
(
t
,
err
,
&
cErr
)
require
.
True
(
t
,
cErr
.
IsTimeout
)
require
.
GreaterOrEqual
(
t
,
cache
.
accountAcquireCalls
,
1
)
}
type
helperConcurrencyCacheStubWithError
struct
{
type
helperConcurrencyCacheStubWithError
struct
{
helperConcurrencyCacheStub
helperConcurrencyCacheStub
err
error
err
error
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
2ee6c266
...
@@ -264,7 +264,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -264,7 +264,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if
sessionKey
!=
""
{
if
sessionKey
!=
""
{
sessionBoundAccountID
,
_
=
h
.
gatewayService
.
GetCachedSessionAccountID
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
)
sessionBoundAccountID
,
_
=
h
.
gatewayService
.
GetCachedSessionAccountID
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
)
if
sessionBoundAccountID
>
0
{
if
sessionBoundAccountID
>
0
{
prefetchedGroupID
:=
int64
(
0
)
if
apiKey
.
GroupID
!=
nil
{
prefetchedGroupID
=
*
apiKey
.
GroupID
}
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
PrefetchedStickyAccountID
,
sessionBoundAccountID
)
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
PrefetchedStickyAccountID
,
sessionBoundAccountID
)
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
PrefetchedStickyGroupID
,
prefetchedGroupID
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
}
}
}
...
...
backend/internal/pkg/ctxkey/ctxkey.go
View file @
2ee6c266
...
@@ -48,4 +48,8 @@ const (
...
@@ -48,4 +48,8 @@ const (
// PrefetchedStickyAccountID 标识上游(通常 handler)预取到的 sticky session 账号 ID。
// PrefetchedStickyAccountID 标识上游(通常 handler)预取到的 sticky session 账号 ID。
// Service 层可复用该值,避免同请求链路重复读取 Redis。
// Service 层可复用该值,避免同请求链路重复读取 Redis。
PrefetchedStickyAccountID
Key
=
"ctx_prefetched_sticky_account_id"
PrefetchedStickyAccountID
Key
=
"ctx_prefetched_sticky_account_id"
// PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。
// Service 层仅在分组匹配时复用 PrefetchedStickyAccountID,避免分组切换重试误用旧 sticky。
PrefetchedStickyGroupID
Key
=
"ctx_prefetched_sticky_group_id"
)
)
backend/internal/service/gateway_hotpath_optimization_test.go
View file @
2ee6c266
...
@@ -604,17 +604,25 @@ func TestGatewayHotpathHelpers_CacheTTLAndStickyContext(t *testing.T) {
...
@@ -604,17 +604,25 @@ func TestGatewayHotpathHelpers_CacheTTLAndStickyContext(t *testing.T) {
})
})
t
.
Run
(
"prefetched_sticky_account_id_from_context"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"prefetched_sticky_account_id_from_context"
,
func
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
int64
(
0
),
prefetchedStickyAccountIDFromContext
(
context
.
TODO
()))
require
.
Equal
(
t
,
int64
(
0
),
prefetchedStickyAccountIDFromContext
(
context
.
TODO
()
,
nil
))
require
.
Equal
(
t
,
int64
(
0
),
prefetchedStickyAccountIDFromContext
(
context
.
Background
()))
require
.
Equal
(
t
,
int64
(
0
),
prefetchedStickyAccountIDFromContext
(
context
.
Background
()
,
nil
))
ctx
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
PrefetchedStickyAccountID
,
int64
(
123
))
ctx
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
PrefetchedStickyAccountID
,
int64
(
123
))
require
.
Equal
(
t
,
int64
(
123
),
prefetchedStickyAccountIDFromContext
(
ctx
))
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
PrefetchedStickyGroupID
,
int64
(
0
))
require
.
Equal
(
t
,
int64
(
123
),
prefetchedStickyAccountIDFromContext
(
ctx
,
nil
))
groupID
:=
int64
(
9
)
ctx2
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
PrefetchedStickyAccountID
,
456
)
ctx2
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
PrefetchedStickyAccountID
,
456
)
require
.
Equal
(
t
,
int64
(
456
),
prefetchedStickyAccountIDFromContext
(
ctx2
))
ctx2
=
context
.
WithValue
(
ctx2
,
ctxkey
.
PrefetchedStickyGroupID
,
groupID
)
require
.
Equal
(
t
,
int64
(
456
),
prefetchedStickyAccountIDFromContext
(
ctx2
,
&
groupID
))
ctx3
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
PrefetchedStickyAccountID
,
"invalid"
)
ctx3
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
PrefetchedStickyAccountID
,
"invalid"
)
require
.
Equal
(
t
,
int64
(
0
),
prefetchedStickyAccountIDFromContext
(
ctx3
))
ctx3
=
context
.
WithValue
(
ctx3
,
ctxkey
.
PrefetchedStickyGroupID
,
groupID
)
require
.
Equal
(
t
,
int64
(
0
),
prefetchedStickyAccountIDFromContext
(
ctx3
,
&
groupID
))
ctx4
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
PrefetchedStickyAccountID
,
int64
(
789
))
ctx4
=
context
.
WithValue
(
ctx4
,
ctxkey
.
PrefetchedStickyGroupID
,
int64
(
10
))
require
.
Equal
(
t
,
int64
(
0
),
prefetchedStickyAccountIDFromContext
(
ctx4
,
&
groupID
))
})
})
t
.
Run
(
"window_cost_from_prefetch_context"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"window_cost_from_prefetch_context"
,
func
(
t
*
testing
.
T
)
{
...
@@ -745,6 +753,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
...
@@ -745,6 +753,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
}
}
ctx
:=
context
.
WithValue
(
baseCtx
,
ctxkey
.
PrefetchedStickyAccountID
,
account
.
ID
)
ctx
:=
context
.
WithValue
(
baseCtx
,
ctxkey
.
PrefetchedStickyAccountID
,
account
.
ID
)
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
PrefetchedStickyGroupID
,
int64
(
0
))
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sess-hash"
,
""
,
nil
,
""
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sess-hash"
,
""
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
)
...
@@ -752,4 +761,26 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
...
@@ -752,4 +761,26 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
require
.
Equal
(
t
,
account
.
ID
,
result
.
Account
.
ID
)
require
.
Equal
(
t
,
account
.
ID
,
result
.
Account
.
ID
)
require
.
Equal
(
t
,
int64
(
0
),
cache
.
getCalls
.
Load
())
require
.
Equal
(
t
,
int64
(
0
),
cache
.
getCalls
.
Load
())
})
})
t
.
Run
(
"with_prefetch_group_mismatch_reads_cache"
,
func
(
t
*
testing
.
T
)
{
cache
:=
&
stickyGatewayCacheHotpathStub
{
stickyID
:
account
.
ID
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
concurrency
,
userGroupRateCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
modelsListCache
:
gocache
.
New
(
time
.
Minute
,
time
.
Minute
),
modelsListCacheTTL
:
time
.
Minute
,
}
ctx
:=
context
.
WithValue
(
baseCtx
,
ctxkey
.
PrefetchedStickyAccountID
,
int64
(
999
))
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
PrefetchedStickyGroupID
,
int64
(
77
))
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sess-hash"
,
""
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
account
.
ID
,
result
.
Account
.
ID
)
require
.
Equal
(
t
,
int64
(
1
),
cache
.
getCalls
.
Load
())
})
}
}
backend/internal/service/gateway_service.go
View file @
2ee6c266
...
@@ -373,8 +373,26 @@ func modelsListCacheKey(groupID *int64, platform string) string {
...
@@ -373,8 +373,26 @@ func modelsListCacheKey(groupID *int64, platform string) string {
return
fmt
.
Sprintf
(
"%d|%s"
,
derefGroupID
(
groupID
),
strings
.
TrimSpace
(
platform
))
return
fmt
.
Sprintf
(
"%d|%s"
,
derefGroupID
(
groupID
),
strings
.
TrimSpace
(
platform
))
}
}
func
prefetchedSticky
Account
IDFromContext
(
ctx
context
.
Context
)
int64
{
func
prefetchedSticky
Group
IDFromContext
(
ctx
context
.
Context
)
(
int64
,
bool
)
{
if
ctx
==
nil
{
if
ctx
==
nil
{
return
0
,
false
}
v
:=
ctx
.
Value
(
ctxkey
.
PrefetchedStickyGroupID
)
switch
t
:=
v
.
(
type
)
{
case
int64
:
return
t
,
true
case
int
:
return
int64
(
t
),
true
}
return
0
,
false
}
func
prefetchedStickyAccountIDFromContext
(
ctx
context
.
Context
,
groupID
*
int64
)
int64
{
if
ctx
==
nil
{
return
0
}
prefetchedGroupID
,
ok
:=
prefetchedStickyGroupIDFromContext
(
ctx
)
if
!
ok
||
prefetchedGroupID
!=
derefGroupID
(
groupID
)
{
return
0
return
0
}
}
v
:=
ctx
.
Value
(
ctxkey
.
PrefetchedStickyAccountID
)
v
:=
ctx
.
Value
(
ctxkey
.
PrefetchedStickyAccountID
)
...
@@ -1035,8 +1053,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1035,8 +1053,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
cfg
:=
s
.
schedulingConfig
()
cfg
:=
s
.
schedulingConfig
()
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
group
,
groupID
,
err
:=
s
.
checkClaudeCodeRestriction
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
err
}
ctx
=
s
.
withGroupContext
(
ctx
,
group
)
var
stickyAccountID
int64
var
stickyAccountID
int64
if
prefetch
:=
prefetchedStickyAccountIDFromContext
(
ctx
);
prefetch
>
0
{
if
prefetch
:=
prefetchedStickyAccountIDFromContext
(
ctx
,
groupID
);
prefetch
>
0
{
stickyAccountID
=
prefetch
stickyAccountID
=
prefetch
}
else
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
}
else
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
);
err
==
nil
{
if
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
);
err
==
nil
{
...
@@ -1044,13 +1069,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1044,13 +1069,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
}
}
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
group
,
groupID
,
err
:=
s
.
checkClaudeCodeRestriction
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
err
}
ctx
=
s
.
withGroupContext
(
ctx
,
group
)
if
s
.
debugModelRoutingEnabled
()
&&
requestedModel
!=
""
{
if
s
.
debugModelRoutingEnabled
()
&&
requestedModel
!=
""
{
groupPlatform
:=
""
groupPlatform
:=
""
if
group
!=
nil
{
if
group
!=
nil
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment