Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
2597fe78
Commit
2597fe78
authored
Jan 10, 2026
by
yangjianbo
Browse files
fix(分组): 防止降级环并校验上下文分组
- 增加降级链路环检测并拦截配置 - 仅复用合法分组上下文并必要时回退查询 - 标注 GetByIDLite 轻量语义并补充测试
parent
67554324
Changes
8
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/group_repo.go
View file @
2597fe78
...
@@ -70,6 +70,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
...
@@ -70,6 +70,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
}
}
func
(
r
*
groupRepository
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
func
(
r
*
groupRepository
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
// AccountCount is intentionally not loaded here; use GetByID when needed.
m
,
err
:=
r
.
client
.
Group
.
Query
()
.
m
,
err
:=
r
.
client
.
Group
.
Query
()
.
Where
(
group
.
IDEQ
(
id
))
.
Where
(
group
.
IDEQ
(
id
))
.
Only
(
ctx
)
Only
(
ctx
)
...
...
backend/internal/server/middleware/api_key_auth.go
View file @
2597fe78
...
@@ -179,7 +179,7 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
...
@@ -179,7 +179,7 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
}
}
func
setGroupContext
(
c
*
gin
.
Context
,
group
*
service
.
Group
)
{
func
setGroupContext
(
c
*
gin
.
Context
,
group
*
service
.
Group
)
{
if
group
==
nil
{
if
!
service
.
IsGroupContextValid
(
group
)
{
return
return
}
}
if
existing
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
);
ok
&&
existing
!=
nil
&&
existing
.
ID
==
group
.
ID
{
if
existing
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
);
ok
&&
existing
!=
nil
&&
existing
.
ID
==
group
.
ID
{
...
...
backend/internal/service/admin_service.go
View file @
2597fe78
...
@@ -575,18 +575,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
...
@@ -575,18 +575,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
return
fmt
.
Errorf
(
"cannot set self as fallback group"
)
return
fmt
.
Errorf
(
"cannot set self as fallback group"
)
}
}
// 检查降级分组是否存在
visited
:=
map
[
int64
]
struct
{}{}
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
fallbackGroupID
)
nextID
:=
fallbackGroupID
if
err
!=
nil
{
for
{
return
fmt
.
Errorf
(
"fallback group not found: %w"
,
err
)
if
_
,
seen
:=
visited
[
nextID
];
seen
{
}
return
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
visited
[
nextID
]
=
struct
{}{}
if
currentGroupID
>
0
&&
nextID
==
currentGroupID
{
return
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
// 降级分组不能启用 claude_code_only,否则会造成死循环
// 检查降级分组是否存在
if
fallbackGroup
.
ClaudeCodeOnly
{
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
nextID
)
return
fmt
.
Errorf
(
"fallback group cannot have claude_code_only enabled"
)
if
err
!=
nil
{
}
return
fmt
.
Errorf
(
"fallback group not found: %w"
,
err
)
}
return
nil
// 降级分组不能启用 claude_code_only,否则会造成死循环
if
nextID
==
fallbackGroupID
&&
fallbackGroup
.
ClaudeCodeOnly
{
return
fmt
.
Errorf
(
"fallback group cannot have claude_code_only enabled"
)
}
if
fallbackGroup
.
FallbackGroupID
==
nil
{
return
nil
}
nextID
=
*
fallbackGroup
.
FallbackGroupID
}
}
}
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
...
...
backend/internal/service/admin_service_group_test.go
View file @
2597fe78
...
@@ -202,3 +202,84 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
...
@@ -202,3 +202,84 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require
.
InDelta
(
t
,
0.15
,
*
repo
.
updated
.
ImagePrice2K
,
0.0001
)
// 原值保持
require
.
InDelta
(
t
,
0.15
,
*
repo
.
updated
.
ImagePrice2K
,
0.0001
)
// 原值保持
require
.
Nil
(
t
,
repo
.
updated
.
ImagePrice4K
)
require
.
Nil
(
t
,
repo
.
updated
.
ImagePrice4K
)
}
}
func
TestAdminService_ValidateFallbackGroup_DetectsCycle
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
fallbackID
:=
int64
(
2
)
repo
:=
&
groupRepoStubForFallbackCycle
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
FallbackGroupID
:
&
fallbackID
,
},
fallbackID
:
{
ID
:
fallbackID
,
FallbackGroupID
:
&
groupID
,
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
err
:=
svc
.
validateFallbackGroup
(
context
.
Background
(),
groupID
,
fallbackID
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group cycle"
)
}
type
groupRepoStubForFallbackCycle
struct
{
groups
map
[
int64
]
*
Group
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Create
(
_
context
.
Context
,
_
*
Group
)
error
{
panic
(
"unexpected Create call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Update
(
_
context
.
Context
,
_
*
Group
)
error
{
panic
(
"unexpected Update call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
s
.
GetByIDLite
(
ctx
,
id
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetByIDLite
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
if
g
,
ok
:=
s
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
DeleteCascade
(
_
context
.
Context
,
_
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected DeleteCascade call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
List
(
_
context
.
Context
,
_
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected List call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListWithFilters
(
_
context
.
Context
,
_
pagination
.
PaginationParams
,
_
,
_
string
,
_
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListActive
(
_
context
.
Context
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActive call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListActiveByPlatform
(
_
context
.
Context
,
_
string
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActiveByPlatform call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ExistsByName
(
_
context
.
Context
,
_
string
)
(
bool
,
error
)
{
panic
(
"unexpected ExistsByName call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetAccountCount
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected GetAccountCount call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
DeleteAccountGroupsByGroupID
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
backend/internal/service/gateway_multiplatform_test.go
View file @
2597fe78
...
@@ -1102,6 +1102,47 @@ func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
...
@@ -1102,6 +1102,47 @@ func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDLiteCalls
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDLiteCalls
)
}
}
func
TestGatewayService_GroupResolution_IgnoresInvalidContextGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
42
)
ctxGroup
:=
&
Group
{
ID
:
groupID
,
Status
:
StatusActive
,
}
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
ctxGroup
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
group
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cfg
:
testConfig
(),
}
account
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
account
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
1
,
groupRepo
.
getByIDLiteCalls
)
}
func
TestGatewayService_GroupResolution_FallbackUsesLiteOnce
(
t
*
testing
.
T
)
{
func
TestGatewayService_GroupResolution_FallbackUsesLiteOnce
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10
)
groupID
:=
int64
(
10
)
...
@@ -1146,3 +1187,41 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
...
@@ -1146,3 +1187,41 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
1
,
groupRepo
.
getByIDLiteCalls
)
require
.
Equal
(
t
,
1
,
groupRepo
.
getByIDLiteCalls
)
}
}
func
TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10
)
fallbackID
:=
int64
(
11
)
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
ClaudeCodeOnly
:
true
,
FallbackGroupID
:
&
fallbackID
,
}
fallbackGroup
:=
&
Group
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
ClaudeCodeOnly
:
true
,
FallbackGroupID
:
&
groupID
,
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
group
,
fallbackID
:
fallbackGroup
,
},
}
svc
:=
&
GatewayService
{
groupRepo
:
groupRepo
,
}
gotGroup
,
gotID
,
err
:=
svc
.
resolveGatewayGroup
(
ctx
,
&
groupID
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
gotGroup
)
require
.
Nil
(
t
,
gotID
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group cycle"
)
}
backend/internal/service/gateway_service.go
View file @
2597fe78
...
@@ -640,7 +640,7 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
...
@@ -640,7 +640,7 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
}
}
func
(
s
*
GatewayService
)
withGroupContext
(
ctx
context
.
Context
,
group
*
Group
)
context
.
Context
{
func
(
s
*
GatewayService
)
withGroupContext
(
ctx
context
.
Context
,
group
*
Group
)
context
.
Context
{
if
group
==
nil
{
if
!
IsGroupContextValid
(
group
)
{
return
ctx
return
ctx
}
}
if
existing
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
existing
!=
nil
&&
existing
.
ID
==
group
.
ID
{
if
existing
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
existing
!=
nil
&&
existing
.
ID
==
group
.
ID
{
...
@@ -650,7 +650,7 @@ func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) con
...
@@ -650,7 +650,7 @@ func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) con
}
}
func
(
s
*
GatewayService
)
groupFromContext
(
ctx
context
.
Context
,
groupID
int64
)
*
Group
{
func
(
s
*
GatewayService
)
groupFromContext
(
ctx
context
.
Context
,
groupID
int64
)
*
Group
{
if
group
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
g
roup
!=
nil
&&
group
.
ID
==
groupID
{
if
group
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
IsG
roup
ContextValid
(
group
)
&&
group
.
ID
==
groupID
{
return
group
return
group
}
}
return
nil
return
nil
...
@@ -673,7 +673,13 @@ func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64
...
@@ -673,7 +673,13 @@ func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64
}
}
currentID
:=
*
groupID
currentID
:=
*
groupID
visited
:=
map
[
int64
]
struct
{}{}
for
{
for
{
if
_
,
seen
:=
visited
[
currentID
];
seen
{
return
nil
,
nil
,
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
visited
[
currentID
]
=
struct
{}{}
group
,
err
:=
s
.
resolveGroupByID
(
ctx
,
currentID
)
group
,
err
:=
s
.
resolveGroupByID
(
ctx
,
currentID
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
nil
,
err
return
nil
,
nil
,
err
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
2597fe78
...
@@ -87,7 +87,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
...
@@ -87,7 +87,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
}
else
if
groupID
!=
nil
{
}
else
if
groupID
!=
nil
{
// 根据分组 platform 决定查询哪种账号
// 根据分组 platform 决定查询哪种账号
var
group
*
Group
var
group
*
Group
if
ctxGroup
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
ctxGroup
!=
nil
&&
ctxGroup
.
ID
==
*
groupID
{
if
ctxGroup
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
IsGroupContextValid
(
ctxGroup
)
&&
ctxGroup
.
ID
==
*
groupID
{
group
=
ctxGroup
group
=
ctxGroup
}
else
{
}
else
{
var
err
error
var
err
error
...
...
backend/internal/service/group.go
View file @
2597fe78
...
@@ -72,3 +72,17 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
...
@@ -72,3 +72,17 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
return
g
.
ImagePrice2K
return
g
.
ImagePrice2K
}
}
}
}
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
func
IsGroupContextValid
(
group
*
Group
)
bool
{
if
group
==
nil
{
return
false
}
if
group
.
ID
<=
0
{
return
false
}
if
group
.
Platform
==
""
||
group
.
Status
==
""
{
return
false
}
return
true
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment