Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
d1a6d6b1
Commit
d1a6d6b1
authored
Jan 10, 2026
by
shaw
Browse files
Merge branch 'mt21625457/main'
parents
15884f36
37308198
Changes
18
Show whitespace changes
Inline
Side-by-side
backend/internal/pkg/ctxkey/ctxkey.go
View file @
d1a6d6b1
...
...
@@ -9,4 +9,6 @@ const (
ForcePlatform
Key
=
"ctx_force_platform"
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
IsClaudeCodeClient
Key
=
"ctx_is_claude_code_client"
// Group 认证后的分组信息,由 API Key 认证中间件设置
Group
Key
=
"ctx_group"
)
backend/internal/repository/api_key_repo.go
View file @
d1a6d6b1
...
...
@@ -339,6 +339,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RateMultiplier
:
g
.
RateMultiplier
,
IsExclusive
:
g
.
IsExclusive
,
Status
:
g
.
Status
,
Hydrated
:
true
,
SubscriptionType
:
g
.
SubscriptionType
,
DailyLimitUSD
:
g
.
DailyLimitUsd
,
WeeklyLimitUSD
:
g
.
WeeklyLimitUsd
,
...
...
backend/internal/repository/group_repo.go
View file @
d1a6d6b1
...
...
@@ -60,6 +60,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
}
func
(
r
*
groupRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
out
,
err
:=
r
.
GetByIDLite
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
out
.
ID
)
out
.
AccountCount
=
count
return
out
,
nil
}
func
(
r
*
groupRepository
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
// AccountCount is intentionally not loaded here; use GetByID when needed.
m
,
err
:=
r
.
client
.
Group
.
Query
()
.
Where
(
group
.
IDEQ
(
id
))
.
Only
(
ctx
)
...
...
@@ -67,10 +78,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
nil
)
}
out
:=
groupEntityToService
(
m
)
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
out
.
ID
)
out
.
AccountCount
=
count
return
out
,
nil
return
groupEntityToService
(
m
),
nil
}
func
(
r
*
groupRepository
)
Update
(
ctx
context
.
Context
,
groupIn
*
service
.
Group
)
error
{
...
...
backend/internal/repository/group_repo_integration_test.go
View file @
d1a6d6b1
...
...
@@ -4,6 +4,8 @@ package repository
import
(
"context"
"database/sql"
"errors"
"testing"
dbent
"github.com/Wei-Shaw/sub2api/ent"
...
...
@@ -19,6 +21,20 @@ type GroupRepoSuite struct {
repo
*
groupRepository
}
type
forbidSQLExecutor
struct
{
called
bool
}
func
(
s
*
forbidSQLExecutor
)
ExecContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
(
sql
.
Result
,
error
)
{
s
.
called
=
true
return
nil
,
errors
.
New
(
"unexpected sql exec"
)
}
func
(
s
*
forbidSQLExecutor
)
QueryContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
(
*
sql
.
Rows
,
error
)
{
s
.
called
=
true
return
nil
,
errors
.
New
(
"unexpected sql query"
)
}
func
(
s
*
GroupRepoSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
tx
:=
testEntTx
(
s
.
T
())
...
...
@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrGroupNotFound
)
}
func
(
s
*
GroupRepoSuite
)
TestGetByIDLite_DoesNotUseAccountCount
()
{
group
:=
&
service
.
Group
{
Name
:
"lite-group"
,
Platform
:
service
.
PlatformAnthropic
,
RateMultiplier
:
1.0
,
IsExclusive
:
false
,
Status
:
service
.
StatusActive
,
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
}
s
.
Require
()
.
NoError
(
s
.
repo
.
Create
(
s
.
ctx
,
group
))
spy
:=
&
forbidSQLExecutor
{}
repo
:=
newGroupRepositoryWithSQL
(
s
.
tx
.
Client
(),
spy
)
got
,
err
:=
repo
.
GetByIDLite
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Equal
(
group
.
ID
,
got
.
ID
)
s
.
Require
()
.
False
(
spy
.
called
,
"expected no direct sql executor usage"
)
}
func
(
s
*
GroupRepoSuite
)
TestUpdate
()
{
group
:=
&
service
.
Group
{
Name
:
"original"
,
...
...
backend/internal/server/api_contract_test.go
View file @
d1a6d6b1
...
...
@@ -575,6 +575,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err
return
nil
,
service
.
ErrGroupNotFound
}
func
(
stubGroupRepo
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
return
nil
,
service
.
ErrGroupNotFound
}
func
(
stubGroupRepo
)
Update
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
return
errors
.
New
(
"not implemented"
)
}
...
...
backend/internal/server/middleware/api_key_auth.go
View file @
d1a6d6b1
package
middleware
import
(
"context"
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -103,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
return
}
...
...
@@ -161,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
}
...
...
@@ -185,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
subscription
,
ok
:=
value
.
(
*
service
.
UserSubscription
)
return
subscription
,
ok
}
func
setGroupContext
(
c
*
gin
.
Context
,
group
*
service
.
Group
)
{
if
!
service
.
IsGroupContextValid
(
group
)
{
return
}
if
existing
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
);
ok
&&
existing
!=
nil
&&
existing
.
ID
==
group
.
ID
&&
service
.
IsGroupContextValid
(
existing
)
{
return
}
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
Group
,
group
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
backend/internal/server/middleware/api_key_auth_google.go
View file @
d1a6d6b1
...
...
@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
return
}
...
...
@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
}
}
...
...
backend/internal/server/middleware/api_key_auth_google_test.go
View file @
d1a6d6b1
...
...
@@ -9,6 +9,7 @@ import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -133,6 +134,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
require
.
Equal
(
t
,
"INVALID_ARGUMENT"
,
resp
.
Error
.
Status
)
}
func
TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
99
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformGemini
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyService
:=
service
.
NewAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
},
nil
,
nil
,
nil
,
nil
,
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
},
)
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
r
:=
gin
.
New
()
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
cfg
))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
func
TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
...
...
backend/internal/server/middleware/api_key_auth_test.go
View file @
d1a6d6b1
...
...
@@ -11,6 +11,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
...
...
@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
ID
:
42
,
Name
:
"sub"
,
Status
:
service
.
StatusActive
,
Hydrated
:
true
,
SubscriptionType
:
service
.
SubscriptionTypeSubscription
,
DailyLimitUSD
:
&
limit
,
}
...
...
@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
})
}
func
TestAPIKeyAuthSetsGroupContext
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
101
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformAnthropic
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
nil
,
cfg
)))
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestAPIKeyAuthOverwritesInvalidContextGroup
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
101
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformAnthropic
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
nil
,
cfg
)))
invalidGroup
:=
&
service
.
Group
{
ID
:
group
.
ID
,
Platform
:
group
.
Platform
,
Status
:
group
.
Status
,
}
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
||
!
groupFromCtx
.
Hydrated
||
groupFromCtx
==
invalidGroup
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
req
=
req
.
WithContext
(
context
.
WithValue
(
req
.
Context
(),
ctxkey
.
Group
,
invalidGroup
))
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
newAuthTestRouter
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
*
gin
.
Engine
{
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
,
cfg
)))
...
...
backend/internal/service/admin_service.go
View file @
d1a6d6b1
...
...
@@ -576,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
return
fmt
.
Errorf
(
"cannot set self as fallback group"
)
}
visited
:=
map
[
int64
]
struct
{}{}
nextID
:=
fallbackGroupID
for
{
if
_
,
seen
:=
visited
[
nextID
];
seen
{
return
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
visited
[
nextID
]
=
struct
{}{}
if
currentGroupID
>
0
&&
nextID
==
currentGroupID
{
return
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
// 检查降级分组是否存在
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
fallbackGroup
ID
)
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByID
Lite
(
ctx
,
next
ID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"fallback group not found: %w"
,
err
)
}
// 降级分组不能启用 claude_code_only,否则会造成死循环
if
fallbackGroup
.
ClaudeCodeOnly
{
if
nextID
==
fallbackGroupID
&&
fallbackGroup
.
ClaudeCodeOnly
{
return
fmt
.
Errorf
(
"fallback group cannot have claude_code_only enabled"
)
}
if
fallbackGroup
.
FallbackGroupID
==
nil
{
return
nil
}
nextID
=
*
fallbackGroup
.
FallbackGroupID
}
}
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
...
...
backend/internal/service/admin_service_delete_test.go
View file @
d1a6d6b1
...
...
@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
panic
(
"unexpected GetByID call"
)
}
func
(
s
*
groupRepoStub
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
panic
(
"unexpected GetByIDLite call"
)
}
func
(
s
*
groupRepoStub
)
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
{
panic
(
"unexpected Update call"
)
}
...
...
backend/internal/service/admin_service_group_test.go
View file @
d1a6d6b1
...
...
@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
return
s
.
getByID
,
nil
}
func
(
s
*
groupRepoStubForAdmin
)
GetByIDLite
(
_
context
.
Context
,
_
int64
)
(
*
Group
,
error
)
{
if
s
.
getErr
!=
nil
{
return
nil
,
s
.
getErr
}
return
s
.
getByID
,
nil
}
func
(
s
*
groupRepoStubForAdmin
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
...
...
@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
require
.
True
(
t
,
*
repo
.
listWithFiltersIsExclusive
)
})
}
func
TestAdminService_ValidateFallbackGroup_DetectsCycle
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
fallbackID
:=
int64
(
2
)
repo
:=
&
groupRepoStubForFallbackCycle
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
FallbackGroupID
:
&
fallbackID
,
},
fallbackID
:
{
ID
:
fallbackID
,
FallbackGroupID
:
&
groupID
,
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
err
:=
svc
.
validateFallbackGroup
(
context
.
Background
(),
groupID
,
fallbackID
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group cycle"
)
}
type
groupRepoStubForFallbackCycle
struct
{
groups
map
[
int64
]
*
Group
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Create
(
_
context
.
Context
,
_
*
Group
)
error
{
panic
(
"unexpected Create call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Update
(
_
context
.
Context
,
_
*
Group
)
error
{
panic
(
"unexpected Update call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
s
.
GetByIDLite
(
ctx
,
id
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetByIDLite
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
if
g
,
ok
:=
s
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
DeleteCascade
(
_
context
.
Context
,
_
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected DeleteCascade call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
List
(
_
context
.
Context
,
_
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected List call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListWithFilters
(
_
context
.
Context
,
_
pagination
.
PaginationParams
,
_
,
_
,
_
string
,
_
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListActive
(
_
context
.
Context
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActive call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListActiveByPlatform
(
_
context
.
Context
,
_
string
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActiveByPlatform call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ExistsByName
(
_
context
.
Context
,
_
string
)
(
bool
,
error
)
{
panic
(
"unexpected ExistsByName call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetAccountCount
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected GetAccountCount call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
DeleteAccountGroupsByGroupID
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
backend/internal/service/gateway_multiplatform_test.go
View file @
d1a6d6b1
...
...
@@ -9,6 +9,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
...
...
@@ -191,6 +192,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
return
nil
}
type
mockGroupRepoForGateway
struct
{
groups
map
[
int64
]
*
Group
getByIDCalls
int
getByIDLiteCalls
int
}
func
(
m
*
mockGroupRepoForGateway
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
m
.
getByIDCalls
++
if
g
,
ok
:=
m
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
m
*
mockGroupRepoForGateway
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
m
.
getByIDLiteCalls
++
if
g
,
ok
:=
m
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
m
*
mockGroupRepoForGateway
)
Create
(
ctx
context
.
Context
,
group
*
Group
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGateway
)
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGateway
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGateway
)
DeleteCascade
(
ctx
context
.
Context
,
id
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
ListActive
(
ctx
context
.
Context
)
([]
Group
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Group
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
ptr
[
T
any
](
v
T
)
*
T
{
return
&
v
}
...
...
@@ -1019,3 +1070,190 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
})
}
func
TestGatewayService_GroupResolution_ReusesContextGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
42
)
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
}
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
group
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
group
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cfg
:
testConfig
(),
}
account
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
account
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDLiteCalls
)
}
func
TestGatewayService_GroupResolution_IgnoresInvalidContextGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
42
)
ctxGroup
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
}
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
ctxGroup
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
group
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cfg
:
testConfig
(),
}
account
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
account
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
1
,
groupRepo
.
getByIDLiteCalls
)
}
func
TestGatewayService_GroupContext_OverwritesInvalidContextGroup
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
42
)
invalidGroup
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
}
hydratedGroup
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
}
ctx
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
Group
,
invalidGroup
)
svc
:=
&
GatewayService
{}
ctx
=
svc
.
withGroupContext
(
ctx
,
hydratedGroup
)
got
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
)
require
.
True
(
t
,
ok
)
require
.
Same
(
t
,
hydratedGroup
,
got
)
}
func
TestGatewayService_GroupResolution_FallbackUsesLiteOnce
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10
)
fallbackID
:=
int64
(
11
)
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
ClaudeCodeOnly
:
true
,
FallbackGroupID
:
&
fallbackID
,
Hydrated
:
true
,
}
fallbackGroup
:=
&
Group
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
}
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
group
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
fallbackGroup
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cfg
:
testConfig
(),
}
account
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
account
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
1
,
groupRepo
.
getByIDLiteCalls
)
}
func
TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10
)
fallbackID
:=
int64
(
11
)
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
ClaudeCodeOnly
:
true
,
FallbackGroupID
:
&
fallbackID
,
}
fallbackGroup
:=
&
Group
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
ClaudeCodeOnly
:
true
,
FallbackGroupID
:
&
groupID
,
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
group
,
fallbackID
:
fallbackGroup
,
},
}
svc
:=
&
GatewayService
{
groupRepo
:
groupRepo
,
}
gotGroup
,
gotID
,
err
:=
svc
.
resolveGatewayGroup
(
ctx
,
&
groupID
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
gotGroup
)
require
.
Nil
(
t
,
gotID
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group cycle"
)
}
backend/internal/service/gateway_service.go
View file @
d1a6d6b1
...
...
@@ -361,27 +361,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if
hasForcePlatform
&&
forcePlatform
!=
""
{
platform
=
forcePlatform
}
else
if
groupID
!=
nil
{
// 根据分组 platform 决定查询哪种账号
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
groupID
)
group
,
resolvedGroupID
,
err
:=
s
.
resolveGatewayGroup
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
return
nil
,
err
}
groupID
=
resolvedGroupID
ctx
=
s
.
withGroupContext
(
ctx
,
group
)
platform
=
group
.
Platform
// 检查 Claude Code 客户端限制
if
group
.
ClaudeCodeOnly
{
isClaudeCode
:=
IsClaudeCodeClient
(
ctx
)
if
!
isClaudeCode
{
// 非 Claude Code 客户端,检查是否有降级分组
if
group
.
FallbackGroupID
!=
nil
{
// 使用降级分组重新调度
fallbackGroupID
:=
*
group
.
FallbackGroupID
return
s
.
SelectAccountForModelWithExclusions
(
ctx
,
&
fallbackGroupID
,
sessionHash
,
requestedModel
,
excludedIDs
)
}
// 无降级分组,拒绝访问
return
nil
,
ErrClaudeCodeOnly
}
}
}
else
{
// 无分组时只使用原生 anthropic 平台
platform
=
PlatformAnthropic
...
...
@@ -409,10 +395,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
groupID
,
err
:=
s
.
checkClaudeCodeRestriction
(
ctx
,
groupID
)
group
,
groupID
,
err
:=
s
.
checkClaudeCodeRestriction
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
err
}
ctx
=
s
.
withGroupContext
(
ctx
,
group
)
if
s
.
concurrencyService
==
nil
||
!
cfg
.
LoadBatchEnabled
{
account
,
err
:=
s
.
SelectAccountForModelWithExclusions
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
)
...
...
@@ -452,7 +439,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
},
nil
}
platform
,
hasForcePlatform
,
err
:=
s
.
resolvePlatform
(
ctx
,
groupID
)
platform
,
hasForcePlatform
,
err
:=
s
.
resolvePlatform
(
ctx
,
groupID
,
group
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -655,51 +642,97 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
}
}
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
// - 有降级分组:返回降级分组的 ID
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func
(
s
*
GatewayService
)
checkClaudeCodeRestriction
(
ctx
context
.
Context
,
groupID
*
int64
)
(
*
int64
,
error
)
{
if
groupID
==
nil
{
return
groupID
,
nil
func
(
s
*
GatewayService
)
withGroupContext
(
ctx
context
.
Context
,
group
*
Group
)
context
.
Context
{
if
!
IsGroupContextValid
(
group
)
{
return
ctx
}
if
existing
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
existing
!=
nil
&&
existing
.
ID
==
group
.
ID
&&
IsGroupContextValid
(
existing
)
{
return
ctx
}
return
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
group
)
}
// 强制平台模式不检查 Claude Code 限制
if
_
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
);
hasForcePlatform
{
return
group
ID
,
nil
func
(
s
*
GatewayService
)
groupFromContext
(
ctx
context
.
Context
,
groupID
int64
)
*
Group
{
if
group
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
IsGroupContextValid
(
group
)
&&
group
.
ID
==
groupID
{
return
group
}
return
nil
}
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
groupID
)
func
(
s
*
GatewayService
)
resolveGroupByID
(
ctx
context
.
Context
,
groupID
int64
)
(
*
Group
,
error
)
{
if
group
:=
s
.
groupFromContext
(
ctx
,
groupID
);
group
!=
nil
{
return
group
,
nil
}
group
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
}
return
group
,
nil
}
func
(
s
*
GatewayService
)
resolveGatewayGroup
(
ctx
context
.
Context
,
groupID
*
int64
)
(
*
Group
,
*
int64
,
error
)
{
if
groupID
==
nil
{
return
nil
,
nil
,
nil
}
if
!
group
.
ClaudeCodeOnly
{
return
groupID
,
nil
currentID
:=
*
groupID
visited
:=
map
[
int64
]
struct
{}{}
for
{
if
_
,
seen
:=
visited
[
currentID
];
seen
{
return
nil
,
nil
,
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
visited
[
currentID
]
=
struct
{}{}
group
,
err
:=
s
.
resolveGroupByID
(
ctx
,
currentID
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
if
!
group
.
ClaudeCodeOnly
||
IsClaudeCodeClient
(
ctx
)
{
return
group
,
&
currentID
,
nil
}
if
group
.
FallbackGroupID
==
nil
{
return
nil
,
nil
,
ErrClaudeCodeOnly
}
currentID
=
*
group
.
FallbackGroupID
}
}
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
// - 有降级分组:返回降级分组的 ID
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func
(
s
*
GatewayService
)
checkClaudeCodeRestriction
(
ctx
context
.
Context
,
groupID
*
int64
)
(
*
Group
,
*
int64
,
error
)
{
if
groupID
==
nil
{
return
nil
,
groupID
,
nil
}
//
分组启用了
Claude Code 限制
if
IsClaudeCodeClient
(
ctx
)
{
return
groupID
,
nil
//
强制平台模式不检查
Claude Code 限制
if
_
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
);
hasForcePlatform
{
return
nil
,
groupID
,
nil
}
// 非 Claude Code 客户端,检查降级分组
if
group
.
FallbackGroupID
!=
nil
{
return
group
.
FallbackGroupID
,
nil
group
,
resolvedID
,
err
:=
s
.
resolveGatewayGroup
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
return
nil
,
ErrClaudeCodeOnly
return
group
,
resolvedID
,
nil
}
func
(
s
*
GatewayService
)
resolvePlatform
(
ctx
context
.
Context
,
groupID
*
int64
)
(
string
,
bool
,
error
)
{
func
(
s
*
GatewayService
)
resolvePlatform
(
ctx
context
.
Context
,
groupID
*
int64
,
group
*
Group
)
(
string
,
bool
,
error
)
{
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
if
hasForcePlatform
&&
forcePlatform
!=
""
{
return
forcePlatform
,
true
,
nil
}
if
group
!=
nil
{
return
group
.
Platform
,
false
,
nil
}
if
groupID
!=
nil
{
group
,
err
:=
s
.
groupRepo
.
Get
ByID
(
ctx
,
*
groupID
)
group
,
err
:=
s
.
resolveGroup
ByID
(
ctx
,
*
groupID
)
if
err
!=
nil
{
return
""
,
false
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
return
""
,
false
,
err
}
return
group
.
Platform
,
false
,
nil
}
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
d1a6d6b1
...
...
@@ -86,10 +86,16 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
platform
=
forcePlatform
}
else
if
groupID
!=
nil
{
// 根据分组 platform 决定查询哪种账号
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
groupID
)
var
group
*
Group
if
ctxGroup
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
IsGroupContextValid
(
ctxGroup
)
&&
ctxGroup
.
ID
==
*
groupID
{
group
=
ctxGroup
}
else
{
var
err
error
group
,
err
=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
*
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
}
}
platform
=
group
.
Platform
}
else
{
// 无分组时只使用原生 gemini 平台
...
...
backend/internal/service/gemini_multiplatform_test.go
View file @
d1a6d6b1
...
...
@@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
...
...
@@ -153,9 +154,20 @@ var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
type
mockGroupRepoForGemini
struct
{
groups
map
[
int64
]
*
Group
getByIDCalls
int
getByIDLiteCalls
int
}
func
(
m
*
mockGroupRepoForGemini
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
m
.
getByIDCalls
++
if
g
,
ok
:=
m
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
errors
.
New
(
"group not found"
)
}
func
(
m
*
mockGroupRepoForGemini
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
m
.
getByIDLiteCalls
++
if
g
,
ok
:=
m
.
groups
[
id
];
ok
{
return
g
,
nil
}
...
...
@@ -248,6 +260,77 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiP
require
.
Equal
(
t
,
PlatformGemini
,
acc
.
Platform
,
"无分组时应只返回 gemini 平台账户"
)
}
func
TestGeminiMessagesCompatService_GroupResolution_ReusesContextGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
7
)
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformGemini
,
Status
:
StatusActive
,
Hydrated
:
true
,
}
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
group
)
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDLiteCalls
)
}
func
TestGeminiMessagesCompatService_GroupResolution_UsesLiteFetch
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
7
)
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformGemini
},
},
}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
1
,
groupRepo
.
getByIDLiteCalls
)
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
...
...
backend/internal/service/group.go
View file @
d1a6d6b1
...
...
@@ -10,6 +10,7 @@ type Group struct {
RateMultiplier
float64
IsExclusive
bool
Status
string
Hydrated
bool
// indicates the group was loaded from a trusted repository source
SubscriptionType
string
DailyLimitUSD
*
float64
...
...
@@ -72,3 +73,20 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
return
g
.
ImagePrice2K
}
}
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
func
IsGroupContextValid
(
group
*
Group
)
bool
{
if
group
==
nil
{
return
false
}
if
group
.
ID
<=
0
{
return
false
}
if
!
group
.
Hydrated
{
return
false
}
if
group
.
Platform
==
""
||
group
.
Status
==
""
{
return
false
}
return
true
}
backend/internal/service/group_service.go
View file @
d1a6d6b1
...
...
@@ -16,6 +16,7 @@ var (
type
GroupRepository
interface
{
Create
(
ctx
context
.
Context
,
group
*
Group
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
DeleteCascade
(
ctx
context
.
Context
,
id
int64
)
([]
int64
,
error
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment