Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
297f08c6
Commit
297f08c6
authored
Jan 10, 2026
by
yangjianbo
Browse files
Merge branch 'test' into dev
parents
61f55674
435f6938
Changes
18
Hide whitespace changes
Inline
Side-by-side
backend/internal/pkg/ctxkey/ctxkey.go
View file @
297f08c6
...
@@ -9,4 +9,6 @@ const (
...
@@ -9,4 +9,6 @@ const (
ForcePlatform
Key
=
"ctx_force_platform"
ForcePlatform
Key
=
"ctx_force_platform"
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
IsClaudeCodeClient
Key
=
"ctx_is_claude_code_client"
IsClaudeCodeClient
Key
=
"ctx_is_claude_code_client"
// Group 认证后的分组信息,由 API Key 认证中间件设置
Group
Key
=
"ctx_group"
)
)
backend/internal/repository/api_key_repo.go
View file @
297f08c6
...
@@ -339,6 +339,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
...
@@ -339,6 +339,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RateMultiplier
:
g
.
RateMultiplier
,
RateMultiplier
:
g
.
RateMultiplier
,
IsExclusive
:
g
.
IsExclusive
,
IsExclusive
:
g
.
IsExclusive
,
Status
:
g
.
Status
,
Status
:
g
.
Status
,
Hydrated
:
true
,
SubscriptionType
:
g
.
SubscriptionType
,
SubscriptionType
:
g
.
SubscriptionType
,
DailyLimitUSD
:
g
.
DailyLimitUsd
,
DailyLimitUSD
:
g
.
DailyLimitUsd
,
WeeklyLimitUSD
:
g
.
WeeklyLimitUsd
,
WeeklyLimitUSD
:
g
.
WeeklyLimitUsd
,
...
...
backend/internal/repository/group_repo.go
View file @
297f08c6
...
@@ -60,6 +60,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
...
@@ -60,6 +60,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
}
}
func
(
r
*
groupRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
func
(
r
*
groupRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
out
,
err
:=
r
.
GetByIDLite
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
out
.
ID
)
out
.
AccountCount
=
count
return
out
,
nil
}
func
(
r
*
groupRepository
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
// AccountCount is intentionally not loaded here; use GetByID when needed.
m
,
err
:=
r
.
client
.
Group
.
Query
()
.
m
,
err
:=
r
.
client
.
Group
.
Query
()
.
Where
(
group
.
IDEQ
(
id
))
.
Where
(
group
.
IDEQ
(
id
))
.
Only
(
ctx
)
Only
(
ctx
)
...
@@ -67,10 +78,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
...
@@ -67,10 +78,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
nil
)
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
nil
)
}
}
out
:=
groupEntityToService
(
m
)
return
groupEntityToService
(
m
),
nil
count
,
_
:=
r
.
GetAccountCount
(
ctx
,
out
.
ID
)
out
.
AccountCount
=
count
return
out
,
nil
}
}
func
(
r
*
groupRepository
)
Update
(
ctx
context
.
Context
,
groupIn
*
service
.
Group
)
error
{
func
(
r
*
groupRepository
)
Update
(
ctx
context
.
Context
,
groupIn
*
service
.
Group
)
error
{
...
...
backend/internal/repository/group_repo_integration_test.go
View file @
297f08c6
...
@@ -4,6 +4,8 @@ package repository
...
@@ -4,6 +4,8 @@ package repository
import
(
import
(
"context"
"context"
"database/sql"
"errors"
"testing"
"testing"
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbent
"github.com/Wei-Shaw/sub2api/ent"
...
@@ -19,6 +21,20 @@ type GroupRepoSuite struct {
...
@@ -19,6 +21,20 @@ type GroupRepoSuite struct {
repo
*
groupRepository
repo
*
groupRepository
}
}
type
forbidSQLExecutor
struct
{
called
bool
}
func
(
s
*
forbidSQLExecutor
)
ExecContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
(
sql
.
Result
,
error
)
{
s
.
called
=
true
return
nil
,
errors
.
New
(
"unexpected sql exec"
)
}
func
(
s
*
forbidSQLExecutor
)
QueryContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
(
*
sql
.
Rows
,
error
)
{
s
.
called
=
true
return
nil
,
errors
.
New
(
"unexpected sql query"
)
}
func
(
s
*
GroupRepoSuite
)
SetupTest
()
{
func
(
s
*
GroupRepoSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
ctx
=
context
.
Background
()
tx
:=
testEntTx
(
s
.
T
())
tx
:=
testEntTx
(
s
.
T
())
...
@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
...
@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrGroupNotFound
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrGroupNotFound
)
}
}
func
(
s
*
GroupRepoSuite
)
TestGetByIDLite_DoesNotUseAccountCount
()
{
group
:=
&
service
.
Group
{
Name
:
"lite-group"
,
Platform
:
service
.
PlatformAnthropic
,
RateMultiplier
:
1.0
,
IsExclusive
:
false
,
Status
:
service
.
StatusActive
,
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
}
s
.
Require
()
.
NoError
(
s
.
repo
.
Create
(
s
.
ctx
,
group
))
spy
:=
&
forbidSQLExecutor
{}
repo
:=
newGroupRepositoryWithSQL
(
s
.
tx
.
Client
(),
spy
)
got
,
err
:=
repo
.
GetByIDLite
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Equal
(
group
.
ID
,
got
.
ID
)
s
.
Require
()
.
False
(
spy
.
called
,
"expected no direct sql executor usage"
)
}
func
(
s
*
GroupRepoSuite
)
TestUpdate
()
{
func
(
s
*
GroupRepoSuite
)
TestUpdate
()
{
group
:=
&
service
.
Group
{
group
:=
&
service
.
Group
{
Name
:
"original"
,
Name
:
"original"
,
...
...
backend/internal/server/api_contract_test.go
View file @
297f08c6
...
@@ -575,6 +575,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err
...
@@ -575,6 +575,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err
return
nil
,
service
.
ErrGroupNotFound
return
nil
,
service
.
ErrGroupNotFound
}
}
func
(
stubGroupRepo
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Group
,
error
)
{
return
nil
,
service
.
ErrGroupNotFound
}
func
(
stubGroupRepo
)
Update
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
func
(
stubGroupRepo
)
Update
(
ctx
context
.
Context
,
group
*
service
.
Group
)
error
{
return
errors
.
New
(
"not implemented"
)
return
errors
.
New
(
"not implemented"
)
}
}
...
...
backend/internal/server/middleware/api_key_auth.go
View file @
297f08c6
package
middleware
package
middleware
import
(
import
(
"context"
"errors"
"errors"
"log"
"log"
"strings"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
...
@@ -103,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
...
@@ -103,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency
:
apiKey
.
User
.
Concurrency
,
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
c
.
Next
()
return
return
}
}
...
@@ -161,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
...
@@ -161,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency
:
apiKey
.
User
.
Concurrency
,
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
c
.
Next
()
}
}
...
@@ -185,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
...
@@ -185,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
subscription
,
ok
:=
value
.
(
*
service
.
UserSubscription
)
subscription
,
ok
:=
value
.
(
*
service
.
UserSubscription
)
return
subscription
,
ok
return
subscription
,
ok
}
}
func
setGroupContext
(
c
*
gin
.
Context
,
group
*
service
.
Group
)
{
if
!
service
.
IsGroupContextValid
(
group
)
{
return
}
if
existing
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
);
ok
&&
existing
!=
nil
&&
existing
.
ID
==
group
.
ID
&&
service
.
IsGroupContextValid
(
existing
)
{
return
}
ctx
:=
context
.
WithValue
(
c
.
Request
.
Context
(),
ctxkey
.
Group
,
group
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
backend/internal/server/middleware/api_key_auth_google.go
View file @
297f08c6
...
@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
...
@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency
:
apiKey
.
User
.
Concurrency
,
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
c
.
Next
()
return
return
}
}
...
@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
...
@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency
:
apiKey
.
User
.
Concurrency
,
Concurrency
:
apiKey
.
User
.
Concurrency
,
})
})
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
c
.
Set
(
string
(
ContextKeyUserRole
),
apiKey
.
User
.
Role
)
setGroupContext
(
c
,
apiKey
.
Group
)
c
.
Next
()
c
.
Next
()
}
}
}
}
...
...
backend/internal/server/middleware/api_key_auth_google_test.go
View file @
297f08c6
...
@@ -9,6 +9,7 @@ import (
...
@@ -9,6 +9,7 @@ import (
"testing"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
...
@@ -133,6 +134,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
...
@@ -133,6 +134,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
require
.
Equal
(
t
,
"INVALID_ARGUMENT"
,
resp
.
Error
.
Status
)
require
.
Equal
(
t
,
"INVALID_ARGUMENT"
,
resp
.
Error
.
Status
)
}
}
func
TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
99
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformGemini
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyService
:=
service
.
NewAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
},
nil
,
nil
,
nil
,
nil
,
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
},
)
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
r
:=
gin
.
New
()
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
cfg
))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
func
TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta
(
t
*
testing
.
T
)
{
func
TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
gin
.
SetMode
(
gin
.
TestMode
)
...
...
backend/internal/server/middleware/api_key_auth_test.go
View file @
297f08c6
...
@@ -11,6 +11,7 @@ import (
...
@@ -11,6 +11,7 @@ import (
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
...
@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
...
@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
ID
:
42
,
ID
:
42
,
Name
:
"sub"
,
Name
:
"sub"
,
Status
:
service
.
StatusActive
,
Status
:
service
.
StatusActive
,
Hydrated
:
true
,
SubscriptionType
:
service
.
SubscriptionTypeSubscription
,
SubscriptionType
:
service
.
SubscriptionTypeSubscription
,
DailyLimitUSD
:
&
limit
,
DailyLimitUSD
:
&
limit
,
}
}
...
@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
...
@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
})
})
}
}
func
TestAPIKeyAuthSetsGroupContext
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
101
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformAnthropic
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
nil
,
cfg
)))
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
TestAPIKeyAuthOverwritesInvalidContextGroup
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
group
:=
&
service
.
Group
{
ID
:
101
,
Name
:
"g1"
,
Status
:
service
.
StatusActive
,
Platform
:
service
.
PlatformAnthropic
,
Hydrated
:
true
,
}
user
:=
&
service
.
User
{
ID
:
7
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
Balance
:
10
,
Concurrency
:
3
,
}
apiKey
:=
&
service
.
APIKey
{
ID
:
100
,
UserID
:
user
.
ID
,
Key
:
"test-key"
,
Status
:
service
.
StatusActive
,
User
:
user
,
Group
:
group
,
}
apiKey
.
GroupID
=
&
group
.
ID
apiKeyRepo
:=
&
stubApiKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
if
key
!=
apiKey
.
Key
{
return
nil
,
service
.
ErrAPIKeyNotFound
}
clone
:=
*
apiKey
return
&
clone
,
nil
},
}
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
apiKeyService
:=
service
.
NewAPIKeyService
(
apiKeyRepo
,
nil
,
nil
,
nil
,
nil
,
cfg
)
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
nil
,
cfg
)))
invalidGroup
:=
&
service
.
Group
{
ID
:
group
.
ID
,
Platform
:
group
.
Platform
,
Status
:
group
.
Status
,
}
router
.
GET
(
"/t"
,
func
(
c
*
gin
.
Context
)
{
groupFromCtx
,
ok
:=
c
.
Request
.
Context
()
.
Value
(
ctxkey
.
Group
)
.
(
*
service
.
Group
)
if
!
ok
||
groupFromCtx
==
nil
||
groupFromCtx
.
ID
!=
group
.
ID
||
!
groupFromCtx
.
Hydrated
||
groupFromCtx
==
invalidGroup
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"ok"
:
false
})
return
}
c
.
JSON
(
http
.
StatusOK
,
gin
.
H
{
"ok"
:
true
})
})
w
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/t"
,
nil
)
req
.
Header
.
Set
(
"x-api-key"
,
apiKey
.
Key
)
req
=
req
.
WithContext
(
context
.
WithValue
(
req
.
Context
(),
ctxkey
.
Group
,
invalidGroup
))
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
)
}
func
newAuthTestRouter
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
*
gin
.
Engine
{
func
newAuthTestRouter
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
*
gin
.
Engine
{
router
:=
gin
.
New
()
router
:=
gin
.
New
()
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
,
cfg
)))
router
.
Use
(
gin
.
HandlerFunc
(
NewAPIKeyAuthMiddleware
(
apiKeyService
,
subscriptionService
,
cfg
)))
...
...
backend/internal/service/admin_service.go
View file @
297f08c6
...
@@ -576,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
...
@@ -576,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
return
fmt
.
Errorf
(
"cannot set self as fallback group"
)
return
fmt
.
Errorf
(
"cannot set self as fallback group"
)
}
}
// 检查降级分组是否存在
visited
:=
map
[
int64
]
struct
{}{}
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
fallbackGroupID
)
nextID
:=
fallbackGroupID
if
err
!=
nil
{
for
{
return
fmt
.
Errorf
(
"fallback group not found: %w"
,
err
)
if
_
,
seen
:=
visited
[
nextID
];
seen
{
}
return
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
visited
[
nextID
]
=
struct
{}{}
if
currentGroupID
>
0
&&
nextID
==
currentGroupID
{
return
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
// 降级分组不能启用 claude_code_only,否则会造成死循环
// 检查降级分组是否存在
if
fallbackGroup
.
ClaudeCodeOnly
{
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
nextID
)
return
fmt
.
Errorf
(
"fallback group cannot have claude_code_only enabled"
)
if
err
!=
nil
{
}
return
fmt
.
Errorf
(
"fallback group not found: %w"
,
err
)
}
return
nil
// 降级分组不能启用 claude_code_only,否则会造成死循环
if
nextID
==
fallbackGroupID
&&
fallbackGroup
.
ClaudeCodeOnly
{
return
fmt
.
Errorf
(
"fallback group cannot have claude_code_only enabled"
)
}
if
fallbackGroup
.
FallbackGroupID
==
nil
{
return
nil
}
nextID
=
*
fallbackGroup
.
FallbackGroupID
}
}
}
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
...
...
backend/internal/service/admin_service_delete_test.go
View file @
297f08c6
...
@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
...
@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
panic
(
"unexpected GetByID call"
)
panic
(
"unexpected GetByID call"
)
}
}
func
(
s
*
groupRepoStub
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
panic
(
"unexpected GetByIDLite call"
)
}
func
(
s
*
groupRepoStub
)
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
{
func
(
s
*
groupRepoStub
)
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
{
panic
(
"unexpected Update call"
)
panic
(
"unexpected Update call"
)
}
}
...
...
backend/internal/service/admin_service_group_test.go
View file @
297f08c6
...
@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
...
@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
return
s
.
getByID
,
nil
return
s
.
getByID
,
nil
}
}
func
(
s
*
groupRepoStubForAdmin
)
GetByIDLite
(
_
context
.
Context
,
_
int64
)
(
*
Group
,
error
)
{
if
s
.
getErr
!=
nil
{
return
nil
,
s
.
getErr
}
return
s
.
getByID
,
nil
}
func
(
s
*
groupRepoStubForAdmin
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
func
(
s
*
groupRepoStubForAdmin
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected Delete call"
)
panic
(
"unexpected Delete call"
)
}
}
...
@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
...
@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
require
.
True
(
t
,
*
repo
.
listWithFiltersIsExclusive
)
require
.
True
(
t
,
*
repo
.
listWithFiltersIsExclusive
)
})
})
}
}
func
TestAdminService_ValidateFallbackGroup_DetectsCycle
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
fallbackID
:=
int64
(
2
)
repo
:=
&
groupRepoStubForFallbackCycle
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
FallbackGroupID
:
&
fallbackID
,
},
fallbackID
:
{
ID
:
fallbackID
,
FallbackGroupID
:
&
groupID
,
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
err
:=
svc
.
validateFallbackGroup
(
context
.
Background
(),
groupID
,
fallbackID
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group cycle"
)
}
type
groupRepoStubForFallbackCycle
struct
{
groups
map
[
int64
]
*
Group
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Create
(
_
context
.
Context
,
_
*
Group
)
error
{
panic
(
"unexpected Create call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Update
(
_
context
.
Context
,
_
*
Group
)
error
{
panic
(
"unexpected Update call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
s
.
GetByIDLite
(
ctx
,
id
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetByIDLite
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
if
g
,
ok
:=
s
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
DeleteCascade
(
_
context
.
Context
,
_
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected DeleteCascade call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
List
(
_
context
.
Context
,
_
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected List call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListWithFilters
(
_
context
.
Context
,
_
pagination
.
PaginationParams
,
_
,
_
string
,
_
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListActive
(
_
context
.
Context
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActive call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListActiveByPlatform
(
_
context
.
Context
,
_
string
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActiveByPlatform call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ExistsByName
(
_
context
.
Context
,
_
string
)
(
bool
,
error
)
{
panic
(
"unexpected ExistsByName call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetAccountCount
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected GetAccountCount call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
DeleteAccountGroupsByGroupID
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
backend/internal/service/gateway_multiplatform_test.go
View file @
297f08c6
...
@@ -9,6 +9,7 @@ import (
...
@@ -9,6 +9,7 @@ import (
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
)
)
...
@@ -191,6 +192,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
...
@@ -191,6 +192,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
return
nil
return
nil
}
}
type
mockGroupRepoForGateway
struct
{
groups
map
[
int64
]
*
Group
getByIDCalls
int
getByIDLiteCalls
int
}
func
(
m
*
mockGroupRepoForGateway
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
m
.
getByIDCalls
++
if
g
,
ok
:=
m
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
m
*
mockGroupRepoForGateway
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
m
.
getByIDLiteCalls
++
if
g
,
ok
:=
m
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
m
*
mockGroupRepoForGateway
)
Create
(
ctx
context
.
Context
,
group
*
Group
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGateway
)
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGateway
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockGroupRepoForGateway
)
DeleteCascade
(
ctx
context
.
Context
,
id
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
string
,
isExclusive
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
ListActive
(
ctx
context
.
Context
)
([]
Group
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Group
,
error
)
{
return
nil
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
m
*
mockGroupRepoForGateway
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
ptr
[
T
any
](
v
T
)
*
T
{
func
ptr
[
T
any
](
v
T
)
*
T
{
return
&
v
return
&
v
}
}
...
@@ -1019,3 +1070,190 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
...
@@ -1019,3 +1070,190 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
})
})
}
}
func
TestGatewayService_GroupResolution_ReusesContextGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
42
)
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
}
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
group
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
group
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cfg
:
testConfig
(),
}
account
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
account
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDLiteCalls
)
}
func
TestGatewayService_GroupResolution_IgnoresInvalidContextGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
42
)
ctxGroup
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
}
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
ctxGroup
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
group
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cfg
:
testConfig
(),
}
account
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
account
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
1
,
groupRepo
.
getByIDLiteCalls
)
}
func
TestGatewayService_GroupContext_OverwritesInvalidContextGroup
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
42
)
invalidGroup
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
}
hydratedGroup
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
}
ctx
:=
context
.
WithValue
(
context
.
Background
(),
ctxkey
.
Group
,
invalidGroup
)
svc
:=
&
GatewayService
{}
ctx
=
svc
.
withGroupContext
(
ctx
,
hydratedGroup
)
got
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
)
require
.
True
(
t
,
ok
)
require
.
Same
(
t
,
hydratedGroup
,
got
)
}
func
TestGatewayService_GroupResolution_FallbackUsesLiteOnce
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10
)
fallbackID
:=
int64
(
11
)
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
ClaudeCodeOnly
:
true
,
FallbackGroupID
:
&
fallbackID
,
Hydrated
:
true
,
}
fallbackGroup
:=
&
Group
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
}
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
group
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
fallbackID
:
fallbackGroup
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cfg
:
testConfig
(),
}
account
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
account
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
1
,
groupRepo
.
getByIDLiteCalls
)
}
func
TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10
)
fallbackID
:=
int64
(
11
)
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
ClaudeCodeOnly
:
true
,
FallbackGroupID
:
&
fallbackID
,
}
fallbackGroup
:=
&
Group
{
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
ClaudeCodeOnly
:
true
,
FallbackGroupID
:
&
groupID
,
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
group
,
fallbackID
:
fallbackGroup
,
},
}
svc
:=
&
GatewayService
{
groupRepo
:
groupRepo
,
}
gotGroup
,
gotID
,
err
:=
svc
.
resolveGatewayGroup
(
ctx
,
&
groupID
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
gotGroup
)
require
.
Nil
(
t
,
gotID
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group cycle"
)
}
backend/internal/service/gateway_service.go
View file @
297f08c6
...
@@ -361,27 +361,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
...
@@ -361,27 +361,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if
hasForcePlatform
&&
forcePlatform
!=
""
{
if
hasForcePlatform
&&
forcePlatform
!=
""
{
platform
=
forcePlatform
platform
=
forcePlatform
}
else
if
groupID
!=
nil
{
}
else
if
groupID
!=
nil
{
// 根据分组 platform 决定查询哪种账号
group
,
resolvedGroupID
,
err
:=
s
.
resolveGatewayGroup
(
ctx
,
groupID
)
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
groupID
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
return
nil
,
err
}
}
groupID
=
resolvedGroupID
ctx
=
s
.
withGroupContext
(
ctx
,
group
)
platform
=
group
.
Platform
platform
=
group
.
Platform
// 检查 Claude Code 客户端限制
if
group
.
ClaudeCodeOnly
{
isClaudeCode
:=
IsClaudeCodeClient
(
ctx
)
if
!
isClaudeCode
{
// 非 Claude Code 客户端,检查是否有降级分组
if
group
.
FallbackGroupID
!=
nil
{
// 使用降级分组重新调度
fallbackGroupID
:=
*
group
.
FallbackGroupID
return
s
.
SelectAccountForModelWithExclusions
(
ctx
,
&
fallbackGroupID
,
sessionHash
,
requestedModel
,
excludedIDs
)
}
// 无降级分组,拒绝访问
return
nil
,
ErrClaudeCodeOnly
}
}
}
else
{
}
else
{
// 无分组时只使用原生 anthropic 平台
// 无分组时只使用原生 anthropic 平台
platform
=
PlatformAnthropic
platform
=
PlatformAnthropic
...
@@ -409,10 +395,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -409,10 +395,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
groupID
,
err
:=
s
.
checkClaudeCodeRestriction
(
ctx
,
groupID
)
group
,
groupID
,
err
:=
s
.
checkClaudeCodeRestriction
(
ctx
,
groupID
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
ctx
=
s
.
withGroupContext
(
ctx
,
group
)
if
s
.
concurrencyService
==
nil
||
!
cfg
.
LoadBatchEnabled
{
if
s
.
concurrencyService
==
nil
||
!
cfg
.
LoadBatchEnabled
{
account
,
err
:=
s
.
SelectAccountForModelWithExclusions
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
)
account
,
err
:=
s
.
SelectAccountForModelWithExclusions
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
)
...
@@ -452,7 +439,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -452,7 +439,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
},
nil
},
nil
}
}
platform
,
hasForcePlatform
,
err
:=
s
.
resolvePlatform
(
ctx
,
groupID
)
platform
,
hasForcePlatform
,
err
:=
s
.
resolvePlatform
(
ctx
,
groupID
,
group
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
@@ -655,51 +642,97 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
...
@@ -655,51 +642,97 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
}
}
}
}
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
func
(
s
*
GatewayService
)
withGroupContext
(
ctx
context
.
Context
,
group
*
Group
)
context
.
Context
{
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
if
!
IsGroupContextValid
(
group
)
{
// - 有降级分组:返回降级分组的 ID
return
ctx
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
}
func
(
s
*
GatewayService
)
checkClaudeCodeRestriction
(
ctx
context
.
Context
,
groupID
*
int64
)
(
*
int64
,
error
)
{
if
existing
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
existing
!=
nil
&&
existing
.
ID
==
group
.
ID
&&
IsGroupContextValid
(
existing
)
{
if
groupID
==
nil
{
return
ctx
return
groupID
,
nil
}
}
return
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
group
)
}
// 强制平台模式不检查 Claude Code 限制
func
(
s
*
GatewayService
)
groupFromContext
(
ctx
context
.
Context
,
groupID
int64
)
*
Group
{
if
_
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
);
hasForcePlatform
{
if
group
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
IsGroupContextValid
(
group
)
&&
group
.
ID
==
groupID
{
return
group
ID
,
nil
return
group
}
}
return
nil
}
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
groupID
)
func
(
s
*
GatewayService
)
resolveGroupByID
(
ctx
context
.
Context
,
groupID
int64
)
(
*
Group
,
error
)
{
if
group
:=
s
.
groupFromContext
(
ctx
,
groupID
);
group
!=
nil
{
return
group
,
nil
}
group
,
err
:=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
groupID
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
}
}
return
group
,
nil
}
if
!
group
.
ClaudeCodeOnly
{
func
(
s
*
GatewayService
)
resolveGatewayGroup
(
ctx
context
.
Context
,
groupID
*
int64
)
(
*
Group
,
*
int64
,
error
)
{
return
groupID
,
nil
if
groupID
==
nil
{
return
nil
,
nil
,
nil
}
}
// 分组启用了 Claude Code 限制
currentID
:=
*
groupID
if
IsClaudeCodeClient
(
ctx
)
{
visited
:=
map
[
int64
]
struct
{}{}
return
groupID
,
nil
for
{
if
_
,
seen
:=
visited
[
currentID
];
seen
{
return
nil
,
nil
,
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
visited
[
currentID
]
=
struct
{}{}
group
,
err
:=
s
.
resolveGroupByID
(
ctx
,
currentID
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
if
!
group
.
ClaudeCodeOnly
||
IsClaudeCodeClient
(
ctx
)
{
return
group
,
&
currentID
,
nil
}
if
group
.
FallbackGroupID
==
nil
{
return
nil
,
nil
,
ErrClaudeCodeOnly
}
currentID
=
*
group
.
FallbackGroupID
}
}
}
// 非 Claude Code 客户端,检查降级分组
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
if
group
.
FallbackGroupID
!=
nil
{
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
return
group
.
FallbackGroupID
,
nil
// - 有降级分组:返回降级分组的 ID
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func
(
s
*
GatewayService
)
checkClaudeCodeRestriction
(
ctx
context
.
Context
,
groupID
*
int64
)
(
*
Group
,
*
int64
,
error
)
{
if
groupID
==
nil
{
return
nil
,
groupID
,
nil
}
// 强制平台模式不检查 Claude Code 限制
if
_
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
);
hasForcePlatform
{
return
nil
,
groupID
,
nil
}
}
return
nil
,
ErrClaudeCodeOnly
group
,
resolvedID
,
err
:=
s
.
resolveGatewayGroup
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
nil
,
err
}
return
group
,
resolvedID
,
nil
}
}
func
(
s
*
GatewayService
)
resolvePlatform
(
ctx
context
.
Context
,
groupID
*
int64
)
(
string
,
bool
,
error
)
{
func
(
s
*
GatewayService
)
resolvePlatform
(
ctx
context
.
Context
,
groupID
*
int64
,
group
*
Group
)
(
string
,
bool
,
error
)
{
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
if
hasForcePlatform
&&
forcePlatform
!=
""
{
if
hasForcePlatform
&&
forcePlatform
!=
""
{
return
forcePlatform
,
true
,
nil
return
forcePlatform
,
true
,
nil
}
}
if
group
!=
nil
{
return
group
.
Platform
,
false
,
nil
}
if
groupID
!=
nil
{
if
groupID
!=
nil
{
group
,
err
:=
s
.
groupRepo
.
Get
ByID
(
ctx
,
*
groupID
)
group
,
err
:=
s
.
resolveGroup
ByID
(
ctx
,
*
groupID
)
if
err
!=
nil
{
if
err
!=
nil
{
return
""
,
false
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
return
""
,
false
,
err
}
}
return
group
.
Platform
,
false
,
nil
return
group
.
Platform
,
false
,
nil
}
}
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
297f08c6
...
@@ -86,9 +86,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
...
@@ -86,9 +86,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
platform
=
forcePlatform
platform
=
forcePlatform
}
else
if
groupID
!=
nil
{
}
else
if
groupID
!=
nil
{
// 根据分组 platform 决定查询哪种账号
// 根据分组 platform 决定查询哪种账号
group
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
*
groupID
)
var
group
*
Group
if
err
!=
nil
{
if
ctxGroup
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
IsGroupContextValid
(
ctxGroup
)
&&
ctxGroup
.
ID
==
*
groupID
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
group
=
ctxGroup
}
else
{
var
err
error
group
,
err
=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
*
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
}
}
}
platform
=
group
.
Platform
platform
=
group
.
Platform
}
else
{
}
else
{
...
...
backend/internal/service/gemini_multiplatform_test.go
View file @
297f08c6
...
@@ -8,6 +8,7 @@ import (
...
@@ -8,6 +8,7 @@ import (
"testing"
"testing"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
)
)
...
@@ -152,10 +153,21 @@ var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
...
@@ -152,10 +153,21 @@ var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
type
mockGroupRepoForGemini
struct
{
type
mockGroupRepoForGemini
struct
{
groups
map
[
int64
]
*
Group
groups
map
[
int64
]
*
Group
getByIDCalls
int
getByIDLiteCalls
int
}
}
func
(
m
*
mockGroupRepoForGemini
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
func
(
m
*
mockGroupRepoForGemini
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
m
.
getByIDCalls
++
if
g
,
ok
:=
m
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
errors
.
New
(
"group not found"
)
}
func
(
m
*
mockGroupRepoForGemini
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
m
.
getByIDLiteCalls
++
if
g
,
ok
:=
m
.
groups
[
id
];
ok
{
if
g
,
ok
:=
m
.
groups
[
id
];
ok
{
return
g
,
nil
return
g
,
nil
}
}
...
@@ -248,6 +260,77 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiP
...
@@ -248,6 +260,77 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiP
require
.
Equal
(
t
,
PlatformGemini
,
acc
.
Platform
,
"无分组时应只返回 gemini 平台账户"
)
require
.
Equal
(
t
,
PlatformGemini
,
acc
.
Platform
,
"无分组时应只返回 gemini 平台账户"
)
}
}
func
TestGeminiMessagesCompatService_GroupResolution_ReusesContextGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
7
)
group
:=
&
Group
{
ID
:
groupID
,
Platform
:
PlatformGemini
,
Status
:
StatusActive
,
Hydrated
:
true
,
}
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
group
)
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDLiteCalls
)
}
func
TestGeminiMessagesCompatService_GroupResolution_UsesLiteFetch
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
7
)
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformGemini
},
},
}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
0
,
groupRepo
.
getByIDCalls
)
require
.
Equal
(
t
,
1
,
groupRepo
.
getByIDLiteCalls
)
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup
(
t
*
testing
.
T
)
{
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
...
...
backend/internal/service/group.go
View file @
297f08c6
...
@@ -10,6 +10,7 @@ type Group struct {
...
@@ -10,6 +10,7 @@ type Group struct {
RateMultiplier
float64
RateMultiplier
float64
IsExclusive
bool
IsExclusive
bool
Status
string
Status
string
Hydrated
bool
// indicates the group was loaded from a trusted repository source
SubscriptionType
string
SubscriptionType
string
DailyLimitUSD
*
float64
DailyLimitUSD
*
float64
...
@@ -72,3 +73,20 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
...
@@ -72,3 +73,20 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
return
g
.
ImagePrice2K
return
g
.
ImagePrice2K
}
}
}
}
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
func
IsGroupContextValid
(
group
*
Group
)
bool
{
if
group
==
nil
{
return
false
}
if
group
.
ID
<=
0
{
return
false
}
if
!
group
.
Hydrated
{
return
false
}
if
group
.
Platform
==
""
||
group
.
Status
==
""
{
return
false
}
return
true
}
backend/internal/service/group_service.go
View file @
297f08c6
...
@@ -16,6 +16,7 @@ var (
...
@@ -16,6 +16,7 @@ var (
type
GroupRepository
interface
{
type
GroupRepository
interface
{
Create
(
ctx
context
.
Context
,
group
*
Group
)
error
Create
(
ctx
context
.
Context
,
group
*
Group
)
error
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
Delete
(
ctx
context
.
Context
,
id
int64
)
error
DeleteCascade
(
ctx
context
.
Context
,
id
int64
)
([]
int64
,
error
)
DeleteCascade
(
ctx
context
.
Context
,
id
int64
)
([]
int64
,
error
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment