Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
1cd033e5
Commit
1cd033e5
authored
Mar 09, 2026
by
erio
Browse files
style: apply gofmt formatting
Co-Authored-By:
Claude Opus 4.6
<
noreply@anthropic.com
>
parent
e534e9ba
Changes
3
Show whitespace changes
Inline
Side-by-side
backend/internal/repository/gateway_cache.go
View file @
1cd033e5
...
@@ -2,14 +2,42 @@ package repository
...
@@ -2,14 +2,42 @@ package repository
import
(
import
(
"context"
"context"
_
"embed"
"fmt"
"fmt"
"strconv"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9"
)
)
const
stickySessionPrefix
=
"sticky_session:"
const
(
stickySessionPrefix
=
"sticky_session:"
clientAffinityPrefix
=
"client_affinity:"
clientAffinityReversePrefix
=
"client_affinity_rev:"
)
var
(
//go:embed lua/get_affinity.lua
getAffinityLua
string
//go:embed lua/update_affinity.lua
updateAffinityLua
string
//go:embed lua/get_affinity_count.lua
getAffinityCountLua
string
//go:embed lua/get_affinity_clients.lua
getAffinityClientsLua
string
//go:embed lua/get_affinity_clients_with_scores.lua
getAffinityClientsWithScoresLua
string
//go:embed lua/clear_account_affinity.lua
clearAccountAffinityLua
string
getAffinityScript
=
redis
.
NewScript
(
getAffinityLua
)
updateAffinityScript
=
redis
.
NewScript
(
updateAffinityLua
)
getAffinityCountScript
=
redis
.
NewScript
(
getAffinityCountLua
)
getAffinityClientsScript
=
redis
.
NewScript
(
getAffinityClientsLua
)
getAffinityClientsWithScoresScript
=
redis
.
NewScript
(
getAffinityClientsWithScoresLua
)
clearAccountAffinityScript
=
redis
.
NewScript
(
clearAccountAffinityLua
)
)
type
gatewayCache
struct
{
type
gatewayCache
struct
{
rdb
*
redis
.
Client
rdb
*
redis
.
Client
...
@@ -19,6 +47,16 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
...
@@ -19,6 +47,16 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
return
&
gatewayCache
{
rdb
:
rdb
}
return
&
gatewayCache
{
rdb
:
rdb
}
}
}
// ensureScriptLoaded 确保 Lua 脚本已加载到 Redis 服务器的脚本缓存中。
// Pipeline 中的 Script.Run 只发送 EVALSHA,如果 Redis 重启过导致脚本缓存丢失,
// EVALSHA 会返回 NOSCRIPT 错误。此方法提前加载脚本以避免该问题。
func
ensureScriptLoaded
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
script
*
redis
.
Script
)
{
exists
,
err
:=
script
.
Exists
(
ctx
,
rdb
)
.
Result
()
if
err
!=
nil
||
len
(
exists
)
==
0
||
!
exists
[
0
]
{
_
=
script
.
Load
(
ctx
,
rdb
)
.
Err
()
}
}
// buildSessionKey 构建 session key,包含 groupID 实现分组隔离
// buildSessionKey 构建 session key,包含 groupID 实现分组隔离
// 格式: sticky_session:{groupID}:{sessionHash}
// 格式: sticky_session:{groupID}:{sessionHash}
func
buildSessionKey
(
groupID
int64
,
sessionHash
string
)
string
{
func
buildSessionKey
(
groupID
int64
,
sessionHash
string
)
string
{
...
@@ -41,13 +79,218 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
...
@@ -41,13 +79,218 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
}
}
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
// 以便下次请求能够重新选择可用账号。
//
// DeleteSessionAccountID removes the sticky session binding for the given session.
// Called when the bound account becomes unavailable (e.g., error status, disabled,
// or unschedulable), allowing subsequent requests to select a new available account.
func
(
c
*
gatewayCache
)
DeleteSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
error
{
func
(
c
*
gatewayCache
)
DeleteSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
error
{
key
:=
buildSessionKey
(
groupID
,
sessionHash
)
key
:=
buildSessionKey
(
groupID
,
sessionHash
)
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
return
c
.
rdb
.
Del
(
ctx
,
key
)
.
Err
()
}
}
// buildAffinityKey 构建正向亲和 key(client → accounts)
// 格式: client_affinity:{groupID}:{clientID}
func
buildAffinityKey
(
groupID
int64
,
clientID
string
)
string
{
return
fmt
.
Sprintf
(
"%s%d:%s"
,
clientAffinityPrefix
,
groupID
,
clientID
)
}
// buildAffinityReverseKey 构建反向亲和 key(account → clients)
// 格式: client_affinity_rev:{groupID}:{accountID}
func
buildAffinityReverseKey
(
groupID
int64
,
accountID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d:%d"
,
clientAffinityReversePrefix
,
groupID
,
accountID
)
}
func
(
c
*
gatewayCache
)
GetClientAffinityAccounts
(
ctx
context
.
Context
,
groupID
int64
,
clientID
string
,
ttl
time
.
Duration
)
([]
int64
,
error
)
{
key
:=
buildAffinityKey
(
groupID
,
clientID
)
now
:=
time
.
Now
()
.
Unix
()
expireThreshold
:=
now
-
int64
(
ttl
.
Seconds
())
result
,
err
:=
getAffinityScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
expireThreshold
)
.
StringSlice
()
if
err
!=
nil
{
if
err
==
redis
.
Nil
{
return
nil
,
nil
}
return
nil
,
err
}
accountIDs
:=
make
([]
int64
,
0
,
len
(
result
))
for
_
,
s
:=
range
result
{
id
,
err
:=
strconv
.
ParseInt
(
s
,
10
,
64
)
if
err
!=
nil
{
continue
}
accountIDs
=
append
(
accountIDs
,
id
)
}
return
accountIDs
,
nil
}
func
(
c
*
gatewayCache
)
UpdateClientAffinity
(
ctx
context
.
Context
,
groupID
int64
,
clientID
string
,
accountID
int64
,
ttl
time
.
Duration
)
error
{
fwdKey
:=
buildAffinityKey
(
groupID
,
clientID
)
revKey
:=
buildAffinityReverseKey
(
groupID
,
accountID
)
now
:=
time
.
Now
()
.
Unix
()
ttlSeconds
:=
int64
(
ttl
.
Seconds
())
expireThreshold
:=
now
-
ttlSeconds
return
updateAffinityScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
fwdKey
,
revKey
},
now
,
ttlSeconds
,
accountID
,
expireThreshold
,
clientID
,
)
.
Err
()
}
// GetAccountAffinityCountBatch 批量获取账号的亲和客户端数量(惰性清理过期成员)
func
(
c
*
gatewayCache
)
GetAccountAffinityCountBatch
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
,
ttl
time
.
Duration
)
(
map
[
int64
]
int64
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
map
[
int64
]
int64
{},
nil
}
now
:=
time
.
Now
()
.
Unix
()
expireThreshold
:=
now
-
int64
(
ttl
.
Seconds
())
ensureScriptLoaded
(
ctx
,
c
.
rdb
,
getAffinityCountScript
)
pipe
:=
c
.
rdb
.
Pipeline
()
cmds
:=
make
([]
*
redis
.
Cmd
,
len
(
accountIDs
))
for
i
,
accID
:=
range
accountIDs
{
key
:=
buildAffinityReverseKey
(
groupID
,
accID
)
cmds
[
i
]
=
getAffinityCountScript
.
Run
(
ctx
,
pipe
,
[]
string
{
key
},
expireThreshold
)
}
_
,
err
:=
pipe
.
Exec
(
ctx
)
if
err
!=
nil
&&
err
!=
redis
.
Nil
{
return
nil
,
err
}
result
:=
make
(
map
[
int64
]
int64
,
len
(
accountIDs
))
for
i
,
accID
:=
range
accountIDs
{
count
,
_
:=
cmds
[
i
]
.
Int64
()
result
[
accID
]
=
count
}
return
result
,
nil
}
// GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和客户端列表(去重)。
// accountGroups: map[accountID][]groupID,对每个 (groupID, accountID) 组合查询反向索引。
func
(
c
*
gatewayCache
)
GetAccountAffinityClientsBatch
(
ctx
context
.
Context
,
accountGroups
map
[
int64
][]
int64
,
ttl
time
.
Duration
)
(
map
[
int64
][]
string
,
error
)
{
if
len
(
accountGroups
)
==
0
{
return
map
[
int64
][]
string
{},
nil
}
now
:=
time
.
Now
()
.
Unix
()
expireThreshold
:=
now
-
int64
(
ttl
.
Seconds
())
// 构建所有 (accountID, groupID) 组合的查询
type
queryItem
struct
{
accountID
int64
groupID
int64
}
var
queries
[]
queryItem
for
accID
,
groupIDs
:=
range
accountGroups
{
for
_
,
gID
:=
range
groupIDs
{
queries
=
append
(
queries
,
queryItem
{
accountID
:
accID
,
groupID
:
gID
})
}
}
ensureScriptLoaded
(
ctx
,
c
.
rdb
,
getAffinityClientsScript
)
pipe
:=
c
.
rdb
.
Pipeline
()
cmds
:=
make
([]
*
redis
.
Cmd
,
len
(
queries
))
for
i
,
q
:=
range
queries
{
key
:=
buildAffinityReverseKey
(
q
.
groupID
,
q
.
accountID
)
cmds
[
i
]
=
getAffinityClientsScript
.
Run
(
ctx
,
pipe
,
[]
string
{
key
},
expireThreshold
)
}
_
,
err
:=
pipe
.
Exec
(
ctx
)
if
err
!=
nil
&&
err
!=
redis
.
Nil
{
return
nil
,
err
}
// 合并结果:同一个 accountID 跨多个 group 的 clientID 去重
result
:=
make
(
map
[
int64
][]
string
,
len
(
accountGroups
))
seen
:=
make
(
map
[
int64
]
map
[
string
]
struct
{},
len
(
accountGroups
))
for
i
,
q
:=
range
queries
{
clients
,
_
:=
cmds
[
i
]
.
StringSlice
()
if
len
(
clients
)
==
0
{
continue
}
if
seen
[
q
.
accountID
]
==
nil
{
seen
[
q
.
accountID
]
=
make
(
map
[
string
]
struct
{})
}
for
_
,
clientID
:=
range
clients
{
if
_
,
exists
:=
seen
[
q
.
accountID
][
clientID
];
!
exists
{
seen
[
q
.
accountID
][
clientID
]
=
struct
{}{}
result
[
q
.
accountID
]
=
append
(
result
[
q
.
accountID
],
clientID
)
}
}
}
return
result
,
nil
}
// GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间戳,去重取最近)。
func
(
c
*
gatewayCache
)
GetAccountAffinityClientsWithScores
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
,
ttl
time
.
Duration
,
)
([]
service
.
AffinityClient
,
error
)
{
if
len
(
groupIDs
)
==
0
{
return
nil
,
nil
}
now
:=
time
.
Now
()
.
Unix
()
expireThreshold
:=
now
-
int64
(
ttl
.
Seconds
())
ensureScriptLoaded
(
ctx
,
c
.
rdb
,
getAffinityClientsWithScoresScript
)
pipe
:=
c
.
rdb
.
Pipeline
()
cmds
:=
make
([]
*
redis
.
Cmd
,
len
(
groupIDs
))
for
i
,
gID
:=
range
groupIDs
{
key
:=
buildAffinityReverseKey
(
gID
,
accountID
)
cmds
[
i
]
=
getAffinityClientsWithScoresScript
.
Run
(
ctx
,
pipe
,
[]
string
{
key
},
expireThreshold
)
}
_
,
err
:=
pipe
.
Exec
(
ctx
)
if
err
!=
nil
&&
err
!=
redis
.
Nil
{
return
nil
,
err
}
// 合并跨组结果,同一 clientID 取最近的 lastActive
seen
:=
make
(
map
[
string
]
int64
)
// clientID → max timestamp
for
_
,
cmd
:=
range
cmds
{
vals
,
_
:=
cmd
.
StringSlice
()
// vals 格式: [clientID1, score1, clientID2, score2, ...]
for
j
:=
0
;
j
+
1
<
len
(
vals
);
j
+=
2
{
clientID
:=
vals
[
j
]
ts
,
_
:=
strconv
.
ParseInt
(
vals
[
j
+
1
],
10
,
64
)
if
existing
,
ok
:=
seen
[
clientID
];
!
ok
||
ts
>
existing
{
seen
[
clientID
]
=
ts
}
}
}
result
:=
make
([]
service
.
AffinityClient
,
0
,
len
(
seen
))
for
clientID
,
ts
:=
range
seen
{
result
=
append
(
result
,
service
.
AffinityClient
{
ClientID
:
clientID
,
LastActive
:
time
.
Unix
(
ts
,
0
),
})
}
// 按最后活跃时间降序排序
service
.
SortAffinityClients
(
result
)
return
result
,
nil
}
// ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引)。
// 对每个 groupID 执行 Lua 脚本:读取反向索引获取所有客户端,
// 从每个客户端的正向索引中移除该账号,然后删除反向索引。
func
(
c
*
gatewayCache
)
ClearAccountAffinity
(
ctx
context
.
Context
,
accountID
int64
,
groupIDs
[]
int64
)
error
{
if
len
(
groupIDs
)
==
0
{
return
nil
}
ensureScriptLoaded
(
ctx
,
c
.
rdb
,
clearAccountAffinityScript
)
pipe
:=
c
.
rdb
.
Pipeline
()
for
_
,
gID
:=
range
groupIDs
{
revKey
:=
buildAffinityReverseKey
(
gID
,
accountID
)
clearAccountAffinityScript
.
Run
(
ctx
,
pipe
,
[]
string
{
revKey
},
gID
,
accountID
)
}
_
,
err
:=
pipe
.
Exec
(
ctx
)
if
err
!=
nil
&&
err
!=
redis
.
Nil
{
return
err
}
return
nil
}
backend/internal/service/admin_service_apikey_test.go
View file @
1cd033e5
...
@@ -65,9 +65,6 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo
...
@@ -65,9 +65,6 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo
func
(
s
*
userRepoStubForGroupUpdate
)
RemoveGroupFromAllowedGroups
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
func
(
s
*
userRepoStubForGroupUpdate
)
RemoveGroupFromAllowedGroups
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
panic
(
"unexpected"
)
}
}
func
(
s
*
userRepoStubForGroupUpdate
)
RemoveGroupFromUserAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
panic
(
"unexpected"
)
}
func
(
s
*
userRepoStubForGroupUpdate
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
func
(
s
*
userRepoStubForGroupUpdate
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
panic
(
"unexpected"
)
panic
(
"unexpected"
)
}
}
...
@@ -131,9 +128,6 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str
...
@@ -131,9 +128,6 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
ClearGroupIDByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
ClearGroupIDByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
panic
(
"unexpected"
)
}
}
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
UpdateGroupIDByUserAndGroup
(
context
.
Context
,
int64
,
int64
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
}
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
CountByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
func
(
s
*
apiKeyRepoStubForGroupUpdate
)
CountByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
panic
(
"unexpected"
)
}
}
...
@@ -200,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri
...
@@ -200,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri
func
(
s
*
groupRepoStubForGroupUpdate
)
ExistsByName
(
context
.
Context
,
string
)
(
bool
,
error
)
{
func
(
s
*
groupRepoStubForGroupUpdate
)
ExistsByName
(
context
.
Context
,
string
)
(
bool
,
error
)
{
panic
(
"unexpected"
)
panic
(
"unexpected"
)
}
}
func
(
s
*
groupRepoStubForGroupUpdate
)
GetAccountCount
(
context
.
Context
,
int64
)
(
int64
,
int64
,
error
)
{
func
(
s
*
groupRepoStubForGroupUpdate
)
GetAccountCount
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
panic
(
"unexpected"
)
panic
(
"unexpected"
)
}
}
func
(
s
*
groupRepoStubForGroupUpdate
)
DeleteAccountGroupsByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
func
(
s
*
groupRepoStubForGroupUpdate
)
DeleteAccountGroupsByGroupID
(
context
.
Context
,
int64
)
(
int64
,
error
)
{
...
@@ -216,29 +210,6 @@ func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupS
...
@@ -216,29 +210,6 @@ func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupS
panic
(
"unexpected"
)
panic
(
"unexpected"
)
}
}
type
userSubRepoStubForGroupUpdate
struct
{
userSubRepoNoop
getActiveSub
*
UserSubscription
getActiveErr
error
called
bool
calledUserID
int64
calledGroupID
int64
}
func
(
s
*
userSubRepoStubForGroupUpdate
)
GetActiveByUserIDAndGroupID
(
_
context
.
Context
,
userID
,
groupID
int64
)
(
*
UserSubscription
,
error
)
{
s
.
called
=
true
s
.
calledUserID
=
userID
s
.
calledGroupID
=
groupID
if
s
.
getActiveErr
!=
nil
{
return
nil
,
s
.
getActiveErr
}
if
s
.
getActiveSub
==
nil
{
return
nil
,
ErrSubscriptionNotFound
}
clone
:=
*
s
.
getActiveSub
return
&
clone
,
nil
}
// ---------------------------------------------------------------------------
// ---------------------------------------------------------------------------
// Tests
// Tests
// ---------------------------------------------------------------------------
// ---------------------------------------------------------------------------
...
@@ -431,49 +402,14 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupU
...
@@ -431,49 +402,14 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupU
func
TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked
(
t
*
testing
.
T
)
{
func
TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked
(
t
*
testing
.
T
)
{
existing
:=
&
APIKey
{
ID
:
1
,
UserID
:
42
,
Key
:
"sk-test"
,
GroupID
:
nil
}
existing
:=
&
APIKey
{
ID
:
1
,
UserID
:
42
,
Key
:
"sk-test"
,
GroupID
:
nil
}
apiKeyRepo
:=
&
apiKeyRepoStubForGroupUpdate
{
key
:
existing
}
apiKeyRepo
:=
&
apiKeyRepoStubForGroupUpdate
{
key
:
existing
}
groupRepo
:=
&
groupRepoStubForGroupUpdate
{
group
:
&
Group
{
ID
:
10
,
Name
:
"Sub"
,
Status
:
StatusActive
,
IsExclusive
:
false
,
SubscriptionType
:
SubscriptionTypeSubscription
}}
groupRepo
:=
&
groupRepoStubForGroupUpdate
{
group
:
&
Group
{
ID
:
10
,
Name
:
"Sub"
,
Status
:
StatusActive
,
IsExclusive
:
true
,
SubscriptionType
:
SubscriptionTypeSubscription
}}
userRepo
:=
&
userRepoStubForGroupUpdate
{}
userSubRepo
:=
&
userSubRepoStubForGroupUpdate
{
getActiveErr
:
ErrSubscriptionNotFound
}
svc
:=
&
adminServiceImpl
{
apiKeyRepo
:
apiKeyRepo
,
groupRepo
:
groupRepo
,
userRepo
:
userRepo
,
userSubRepo
:
userSubRepo
}
// 无有效订阅时应拒绝绑定
_
,
err
:=
svc
.
AdminUpdateAPIKeyGroupID
(
context
.
Background
(),
1
,
int64Ptr
(
10
))
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
"SUBSCRIPTION_REQUIRED"
,
infraerrors
.
Reason
(
err
))
require
.
True
(
t
,
userSubRepo
.
called
)
require
.
Equal
(
t
,
int64
(
42
),
userSubRepo
.
calledUserID
)
require
.
Equal
(
t
,
int64
(
10
),
userSubRepo
.
calledGroupID
)
require
.
False
(
t
,
userRepo
.
addGroupCalled
)
}
func
TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo
(
t
*
testing
.
T
)
{
existing
:=
&
APIKey
{
ID
:
1
,
UserID
:
42
,
Key
:
"sk-test"
,
GroupID
:
nil
}
apiKeyRepo
:=
&
apiKeyRepoStubForGroupUpdate
{
key
:
existing
}
groupRepo
:=
&
groupRepoStubForGroupUpdate
{
group
:
&
Group
{
ID
:
10
,
Name
:
"Sub"
,
Status
:
StatusActive
,
IsExclusive
:
false
,
SubscriptionType
:
SubscriptionTypeSubscription
}}
userRepo
:=
&
userRepoStubForGroupUpdate
{}
userRepo
:=
&
userRepoStubForGroupUpdate
{}
svc
:=
&
adminServiceImpl
{
apiKeyRepo
:
apiKeyRepo
,
groupRepo
:
groupRepo
,
userRepo
:
userRepo
}
svc
:=
&
adminServiceImpl
{
apiKeyRepo
:
apiKeyRepo
,
groupRepo
:
groupRepo
,
userRepo
:
userRepo
}
// 订阅类型分组应被阻止绑定
_
,
err
:=
svc
.
AdminUpdateAPIKeyGroupID
(
context
.
Background
(),
1
,
int64Ptr
(
10
))
_
,
err
:=
svc
.
AdminUpdateAPIKeyGroupID
(
context
.
Background
(),
1
,
int64Ptr
(
10
))
require
.
Error
(
t
,
err
)
require
.
Error
(
t
,
err
)
require
.
Equal
(
t
,
"SUBSCRIPTION_REPOSITORY_UNAVAILABLE"
,
infraerrors
.
Reason
(
err
))
require
.
Equal
(
t
,
"SUBSCRIPTION_GROUP_NOT_ALLOWED"
,
infraerrors
.
Reason
(
err
))
require
.
False
(
t
,
userRepo
.
addGroupCalled
)
}
func
TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription
(
t
*
testing
.
T
)
{
existing
:=
&
APIKey
{
ID
:
1
,
UserID
:
42
,
Key
:
"sk-test"
,
GroupID
:
nil
}
apiKeyRepo
:=
&
apiKeyRepoStubForGroupUpdate
{
key
:
existing
}
groupRepo
:=
&
groupRepoStubForGroupUpdate
{
group
:
&
Group
{
ID
:
10
,
Name
:
"Sub"
,
Status
:
StatusActive
,
IsExclusive
:
true
,
SubscriptionType
:
SubscriptionTypeSubscription
}}
userRepo
:=
&
userRepoStubForGroupUpdate
{}
userSubRepo
:=
&
userSubRepoStubForGroupUpdate
{
getActiveSub
:
&
UserSubscription
{
ID
:
99
,
UserID
:
42
,
GroupID
:
10
},
}
svc
:=
&
adminServiceImpl
{
apiKeyRepo
:
apiKeyRepo
,
groupRepo
:
groupRepo
,
userRepo
:
userRepo
,
userSubRepo
:
userSubRepo
}
got
,
err
:=
svc
.
AdminUpdateAPIKeyGroupID
(
context
.
Background
(),
1
,
int64Ptr
(
10
))
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
userSubRepo
.
called
)
require
.
NotNil
(
t
,
got
.
APIKey
.
GroupID
)
require
.
Equal
(
t
,
int64
(
10
),
*
got
.
APIKey
.
GroupID
)
require
.
False
(
t
,
userRepo
.
addGroupCalled
)
require
.
False
(
t
,
userRepo
.
addGroupCalled
)
}
}
...
...
backend/internal/service/user_service_test.go
View file @
1cd033e5
...
@@ -46,9 +46,6 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int
...
@@ -46,9 +46,6 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int
return
0
,
nil
return
0
,
nil
}
}
func
(
m
*
mockUserRepo
)
AddGroupToAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
AddGroupToAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
RemoveGroupFromUserAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
EnableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
EnableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
DisableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
DisableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment