"frontend/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "ef2c35dbb1e4af119478d75638e3b9b12b3ebc2a"
Commit c86d445c authored by IanShaw027's avatar IanShaw027
Browse files

fix(frontend): sync with main and finalize i18n & component optimizations

parents 6c036d7b e78c8646
...@@ -135,12 +135,12 @@ func (s *AccountRepoSuite) TestListWithFilters() { ...@@ -135,12 +135,12 @@ func (s *AccountRepoSuite) TestListWithFilters() {
name: "filter_by_type", name: "filter_by_type",
setup: func(client *dbent.Client) { setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth}) mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeApiKey}) mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeAPIKey})
}, },
accType: service.AccountTypeApiKey, accType: service.AccountTypeAPIKey,
wantCount: 1, wantCount: 1,
validate: func(accounts []service.Account) { validate: func(accounts []service.Account) {
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type) s.Require().Equal(service.AccountTypeAPIKey, accounts[0].Type)
}, },
}, },
{ {
......
...@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t ...@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
userRepo := newUserRepositoryWithSQL(entClient, tx) userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, tx) groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := NewApiKeyRepository(entClient) apiKeyRepo := NewAPIKeyRepository(entClient)
u := &service.User{ u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com", Email: uniqueTestValue(t, "cascade-user") + "@example.com",
...@@ -110,7 +110,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t ...@@ -110,7 +110,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
} }
require.NoError(t, userRepo.Create(ctx, u)) require.NoError(t, userRepo.Create(ctx, u))
key := &service.ApiKey{ key := &service.APIKey{
UserID: u.ID, UserID: u.ID,
Key: uniqueTestValue(t, "sk-test-delete-cascade"), Key: uniqueTestValue(t, "sk-test-delete-cascade"),
Name: "test key", Name: "test key",
......
...@@ -24,7 +24,7 @@ type apiKeyCache struct { ...@@ -24,7 +24,7 @@ type apiKeyCache struct {
rdb *redis.Client rdb *redis.Client
} }
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache { func NewAPIKeyCache(rdb *redis.Client) service.APIKeyCache {
return &apiKeyCache{rdb: rdb} return &apiKeyCache{rdb: rdb}
} }
......
...@@ -16,17 +16,17 @@ type apiKeyRepository struct { ...@@ -16,17 +16,17 @@ type apiKeyRepository struct {
client *dbent.Client client *dbent.Client
} }
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository { func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
return &apiKeyRepository{client: client} return &apiKeyRepository{client: client}
} }
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery { func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
// 默认过滤已软删除记录,避免删除后仍被查询到。 // 默认过滤已软删除记录,避免删除后仍被查询到。
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil()) return r.client.APIKey.Query().Where(apikey.DeletedAtIsNil())
} }
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error { func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
created, err := r.client.ApiKey.Create(). created, err := r.client.APIKey.Create().
SetUserID(key.UserID). SetUserID(key.UserID).
SetKey(key.Key). SetKey(key.Key).
SetName(key.Name). SetName(key.Name).
...@@ -38,10 +38,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) erro ...@@ -38,10 +38,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) erro
key.CreatedAt = created.CreatedAt key.CreatedAt = created.CreatedAt
key.UpdatedAt = created.UpdatedAt key.UpdatedAt = created.UpdatedAt
} }
return translatePersistenceError(err, nil, service.ErrApiKeyExists) return translatePersistenceError(err, nil, service.ErrAPIKeyExists)
} }
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
m, err := r.activeQuery(). m, err := r.activeQuery().
Where(apikey.IDEQ(id)). Where(apikey.IDEQ(id)).
WithUser(). WithUser().
...@@ -49,7 +49,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK ...@@ -49,7 +49,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
Only(ctx) Only(ctx)
if err != nil { if err != nil {
if dbent.IsNotFound(err) { if dbent.IsNotFound(err) {
return nil, service.ErrApiKeyNotFound return nil, service.ErrAPIKeyNotFound
} }
return nil, err return nil, err
} }
...@@ -59,7 +59,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK ...@@ -59,7 +59,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。 // GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
// 相比 GetByID,此方法性能更优,因为: // 相比 GetByID,此方法性能更优,因为:
// - 使用 Select() 只查询 user_id 字段,减少数据传输量 // - 使用 Select() 只查询 user_id 字段,减少数据传输量
// - 不加载完整的 ApiKey 实体及其关联数据(User、Group 等) // - 不加载完整的 API Key 实体及其关联数据(User、Group 等)
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查) // - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) { func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
m, err := r.activeQuery(). m, err := r.activeQuery().
...@@ -68,14 +68,14 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err ...@@ -68,14 +68,14 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
Only(ctx) Only(ctx)
if err != nil { if err != nil {
if dbent.IsNotFound(err) { if dbent.IsNotFound(err) {
return 0, service.ErrApiKeyNotFound return 0, service.ErrAPIKeyNotFound
} }
return 0, err return 0, err
} }
return m.UserID, nil return m.UserID, nil
} }
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
m, err := r.activeQuery(). m, err := r.activeQuery().
Where(apikey.KeyEQ(key)). Where(apikey.KeyEQ(key)).
WithUser(). WithUser().
...@@ -83,21 +83,21 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A ...@@ -83,21 +83,21 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
Only(ctx) Only(ctx)
if err != nil { if err != nil {
if dbent.IsNotFound(err) { if dbent.IsNotFound(err) {
return nil, service.ErrApiKeyNotFound return nil, service.ErrAPIKeyNotFound
} }
return nil, err return nil, err
} }
return apiKeyEntityToService(m), nil return apiKeyEntityToService(m), nil
} }
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error { func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。 // 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除, // 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
// 则会更新已删除的记录。 // 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。 // 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。 // 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。
now := time.Now() now := time.Now()
builder := r.client.ApiKey.Update(). builder := r.client.APIKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()). Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name). SetName(key.Name).
SetStatus(key.Status). SetStatus(key.Status).
...@@ -114,7 +114,7 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro ...@@ -114,7 +114,7 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
} }
if affected == 0 { if affected == 0 {
// 更新影响行数为 0,说明记录不存在或已被软删除。 // 更新影响行数为 0,说明记录不存在或已被软删除。
return service.ErrApiKeyNotFound return service.ErrAPIKeyNotFound
} }
// 使用同一时间戳回填,避免并发删除导致二次查询失败。 // 使用同一时间戳回填,避免并发删除导致二次查询失败。
...@@ -124,18 +124,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro ...@@ -124,18 +124,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error { func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。 // 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
affected, err := r.client.ApiKey.Update(). affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetDeletedAt(time.Now()). SetDeletedAt(time.Now()).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
if dbent.IsNotFound(err) { if dbent.IsNotFound(err) {
return service.ErrApiKeyNotFound return service.ErrAPIKeyNotFound
} }
return err return err
} }
if affected == 0 { if affected == 0 {
exists, err := r.client.ApiKey.Query(). exists, err := r.client.APIKey.Query().
Where(apikey.IDEQ(id)). Where(apikey.IDEQ(id)).
Exist(mixins.SkipSoftDelete(ctx)) Exist(mixins.SkipSoftDelete(ctx))
if err != nil { if err != nil {
...@@ -144,12 +144,12 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error { ...@@ -144,12 +144,12 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
if exists { if exists {
return nil return nil
} }
return service.ErrApiKeyNotFound return service.ErrAPIKeyNotFound
} }
return nil return nil
} }
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.UserIDEQ(userID)) q := r.activeQuery().Where(apikey.UserIDEQ(userID))
total, err := q.Count(ctx) total, err := q.Count(ctx)
...@@ -167,7 +167,7 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param ...@@ -167,7 +167,7 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err return nil, nil, err
} }
outKeys := make([]service.ApiKey, 0, len(keys)) outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys { for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i])) outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
} }
...@@ -180,7 +180,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap ...@@ -180,7 +180,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
return []int64{}, nil return []int64{}, nil
} }
ids, err := r.client.ApiKey.Query(). ids, err := r.client.APIKey.Query().
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()). Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
IDs(ctx) IDs(ctx)
if err != nil { if err != nil {
...@@ -199,7 +199,7 @@ func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e ...@@ -199,7 +199,7 @@ func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
return count > 0, err return count > 0, err
} }
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID)) q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
total, err := q.Count(ctx) total, err := q.Count(ctx)
...@@ -217,7 +217,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par ...@@ -217,7 +217,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err return nil, nil, err
} }
outKeys := make([]service.ApiKey, 0, len(keys)) outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys { for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i])) outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
} }
...@@ -225,8 +225,8 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par ...@@ -225,8 +225,8 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return outKeys, paginationResultFromTotal(int64(total), params), nil return outKeys, paginationResultFromTotal(int64(total), params), nil
} }
// SearchApiKeys searches API keys by user ID and/or keyword (name) // SearchAPIKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
q := r.activeQuery() q := r.activeQuery()
if userID > 0 { if userID > 0 {
q = q.Where(apikey.UserIDEQ(userID)) q = q.Where(apikey.UserIDEQ(userID))
...@@ -241,7 +241,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw ...@@ -241,7 +241,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
return nil, err return nil, err
} }
outKeys := make([]service.ApiKey, 0, len(keys)) outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys { for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i])) outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
} }
...@@ -250,7 +250,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw ...@@ -250,7 +250,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil // ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
n, err := r.client.ApiKey.Update(). n, err := r.client.APIKey.Update().
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()). Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
ClearGroupID(). ClearGroupID().
Save(ctx) Save(ctx)
...@@ -263,11 +263,11 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i ...@@ -263,11 +263,11 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
return int64(count), err return int64(count), err
} }
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey { func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
if m == nil { if m == nil {
return nil return nil
} }
out := &service.ApiKey{ out := &service.APIKey{
ID: m.ID, ID: m.ID,
UserID: m.UserID, UserID: m.UserID,
Key: m.Key, Key: m.Key,
......
...@@ -12,30 +12,30 @@ import ( ...@@ -12,30 +12,30 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
type ApiKeyRepoSuite struct { type APIKeyRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
client *dbent.Client client *dbent.Client
repo *apiKeyRepository repo *apiKeyRepository
} }
func (s *ApiKeyRepoSuite) SetupTest() { func (s *APIKeyRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
tx := testEntTx(s.T()) tx := testEntTx(s.T())
s.client = tx.Client() s.client = tx.Client()
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository) s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
} }
func TestApiKeyRepoSuite(t *testing.T) { func TestAPIKeyRepoSuite(t *testing.T) {
suite.Run(t, new(ApiKeyRepoSuite)) suite.Run(t, new(APIKeyRepoSuite))
} }
// --- Create / GetByID / GetByKey --- // --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() { func (s *APIKeyRepoSuite) TestCreate() {
user := s.mustCreateUser("create@test.com") user := s.mustCreateUser("create@test.com")
key := &service.ApiKey{ key := &service.APIKey{
UserID: user.ID, UserID: user.ID,
Key: "sk-create-test", Key: "sk-create-test",
Name: "Test Key", Name: "Test Key",
...@@ -51,16 +51,16 @@ func (s *ApiKeyRepoSuite) TestCreate() { ...@@ -51,16 +51,16 @@ func (s *ApiKeyRepoSuite) TestCreate() {
s.Require().Equal("sk-create-test", got.Key) s.Require().Equal("sk-create-test", got.Key)
} }
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() { func (s *APIKeyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999) _, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID") s.Require().Error(err, "expected error for non-existent ID")
} }
func (s *ApiKeyRepoSuite) TestGetByKey() { func (s *APIKeyRepoSuite) TestGetByKey() {
user := s.mustCreateUser("getbykey@test.com") user := s.mustCreateUser("getbykey@test.com")
group := s.mustCreateGroup("g-key") group := s.mustCreateGroup("g-key")
key := &service.ApiKey{ key := &service.APIKey{
UserID: user.ID, UserID: user.ID,
Key: "sk-getbykey", Key: "sk-getbykey",
Name: "My Key", Name: "My Key",
...@@ -78,16 +78,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey() { ...@@ -78,16 +78,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey() {
s.Require().Equal(group.ID, got.Group.ID) s.Require().Equal(group.ID, got.Group.ID)
} }
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() { func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
_, err := s.repo.GetByKey(s.ctx, "non-existent-key") _, err := s.repo.GetByKey(s.ctx, "non-existent-key")
s.Require().Error(err, "expected error for non-existent key") s.Require().Error(err, "expected error for non-existent key")
} }
// --- Update --- // --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() { func (s *APIKeyRepoSuite) TestUpdate() {
user := s.mustCreateUser("update@test.com") user := s.mustCreateUser("update@test.com")
key := &service.ApiKey{ key := &service.APIKey{
UserID: user.ID, UserID: user.ID,
Key: "sk-update", Key: "sk-update",
Name: "Original", Name: "Original",
...@@ -108,10 +108,10 @@ func (s *ApiKeyRepoSuite) TestUpdate() { ...@@ -108,10 +108,10 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
s.Require().Equal(service.StatusDisabled, got.Status) s.Require().Equal(service.StatusDisabled, got.Status)
} }
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() { func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
user := s.mustCreateUser("cleargroup@test.com") user := s.mustCreateUser("cleargroup@test.com")
group := s.mustCreateGroup("g-clear") group := s.mustCreateGroup("g-clear")
key := &service.ApiKey{ key := &service.APIKey{
UserID: user.ID, UserID: user.ID,
Key: "sk-clear-group", Key: "sk-clear-group",
Name: "Group Key", Name: "Group Key",
...@@ -131,9 +131,9 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() { ...@@ -131,9 +131,9 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete --- // --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() { func (s *APIKeyRepoSuite) TestDelete() {
user := s.mustCreateUser("delete@test.com") user := s.mustCreateUser("delete@test.com")
key := &service.ApiKey{ key := &service.APIKey{
UserID: user.ID, UserID: user.ID,
Key: "sk-delete", Key: "sk-delete",
Name: "Delete Me", Name: "Delete Me",
...@@ -150,7 +150,7 @@ func (s *ApiKeyRepoSuite) TestDelete() { ...@@ -150,7 +150,7 @@ func (s *ApiKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID --- // --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() { func (s *APIKeyRepoSuite) TestListByUserID() {
user := s.mustCreateUser("listbyuser@test.com") user := s.mustCreateUser("listbyuser@test.com")
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil) s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil) s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
...@@ -161,7 +161,7 @@ func (s *ApiKeyRepoSuite) TestListByUserID() { ...@@ -161,7 +161,7 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
s.Require().Equal(int64(2), page.Total) s.Require().Equal(int64(2), page.Total)
} }
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() { func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
user := s.mustCreateUser("paging@test.com") user := s.mustCreateUser("paging@test.com")
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil) s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
...@@ -174,7 +174,7 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() { ...@@ -174,7 +174,7 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
s.Require().Equal(3, page.Pages) s.Require().Equal(3, page.Pages)
} }
func (s *ApiKeyRepoSuite) TestCountByUserID() { func (s *APIKeyRepoSuite) TestCountByUserID() {
user := s.mustCreateUser("count@test.com") user := s.mustCreateUser("count@test.com")
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil) s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil) s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
...@@ -186,7 +186,7 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() { ...@@ -186,7 +186,7 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID --- // --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() { func (s *APIKeyRepoSuite) TestListByGroupID() {
user := s.mustCreateUser("listbygroup@test.com") user := s.mustCreateUser("listbygroup@test.com")
group := s.mustCreateGroup("g-list") group := s.mustCreateGroup("g-list")
...@@ -202,7 +202,7 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() { ...@@ -202,7 +202,7 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
s.Require().NotNil(keys[0].User) s.Require().NotNil(keys[0].User)
} }
func (s *ApiKeyRepoSuite) TestCountByGroupID() { func (s *APIKeyRepoSuite) TestCountByGroupID() {
user := s.mustCreateUser("countgroup@test.com") user := s.mustCreateUser("countgroup@test.com")
group := s.mustCreateGroup("g-count") group := s.mustCreateGroup("g-count")
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID) s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
...@@ -214,7 +214,7 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() { ...@@ -214,7 +214,7 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey --- // --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() { func (s *APIKeyRepoSuite) TestExistsByKey() {
user := s.mustCreateUser("exists@test.com") user := s.mustCreateUser("exists@test.com")
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil) s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
...@@ -227,41 +227,41 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() { ...@@ -227,41 +227,41 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
s.Require().False(notExists) s.Require().False(notExists)
} }
// --- SearchApiKeys --- // --- SearchAPIKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() { func (s *APIKeyRepoSuite) TestSearchAPIKeys() {
user := s.mustCreateUser("search@test.com") user := s.mustCreateUser("search@test.com")
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil) s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil) s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10) found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys") s.Require().NoError(err, "SearchAPIKeys")
s.Require().Len(found, 1) s.Require().Len(found, 1)
s.Require().Contains(found[0].Name, "Production") s.Require().Contains(found[0].Name, "Production")
} }
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() { func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoKeyword() {
user := s.mustCreateUser("searchnokw@test.com") user := s.mustCreateUser("searchnokw@test.com")
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil) s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil) s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10) found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(found, 2) s.Require().Len(found, 2)
} }
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() { func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoUserID() {
user := s.mustCreateUser("searchnouid@test.com") user := s.mustCreateUser("searchnouid@test.com")
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil) s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10) found, err := s.repo.SearchAPIKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(found, 1) s.Require().Len(found, 1)
} }
// --- ClearGroupIDByGroupID --- // --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() { func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
user := s.mustCreateUser("cleargrp@test.com") user := s.mustCreateUser("cleargrp@test.com")
group := s.mustCreateGroup("g-clear-bulk") group := s.mustCreateGroup("g-clear-bulk")
...@@ -284,7 +284,7 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() { ...@@ -284,7 +284,7 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) --- // --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := s.mustCreateUser("k@example.com") user := s.mustCreateUser("k@example.com")
group := s.mustCreateGroup("g-k") group := s.mustCreateGroup("g-k")
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID) key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
...@@ -320,8 +320,8 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { ...@@ -320,8 +320,8 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().NoError(err, "ExistsByKey") s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists, "expected key to exist") s.Require().True(exists, "expected key to exist")
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10) found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchApiKeys") s.Require().NoError(err, "SearchAPIKeys")
s.Require().Len(found, 1) s.Require().Len(found, 1)
s.Require().Equal(key.ID, found[0].ID) s.Require().Equal(key.ID, found[0].ID)
...@@ -346,7 +346,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { ...@@ -346,7 +346,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear") s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
} }
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User { func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
s.T().Helper() s.T().Helper()
u, err := s.client.User.Create(). u, err := s.client.User.Create().
...@@ -359,7 +359,7 @@ func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User { ...@@ -359,7 +359,7 @@ func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
return userEntityToService(u) return userEntityToService(u)
} }
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group { func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper() s.T().Helper()
g, err := s.client.Group.Create(). g, err := s.client.Group.Create().
...@@ -370,10 +370,10 @@ func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group { ...@@ -370,10 +370,10 @@ func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
return groupEntityToService(g) return groupEntityToService(g)
} }
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey { func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.APIKey {
s.T().Helper() s.T().Helper()
k := &service.ApiKey{ k := &service.APIKey{
UserID: userID, UserID: userID,
Key: key, Key: key,
Name: name, Name: name,
......
...@@ -5,28 +5,20 @@ import ( ...@@ -5,28 +5,20 @@ import (
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
"net/http/httptest"
"strings" "strings"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
type ClaudeOAuthServiceSuite struct { type ClaudeOAuthServiceSuite struct {
suite.Suite suite.Suite
srv *httptest.Server
client *claudeOAuthService client *claudeOAuthService
} }
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
// requestCapture holds captured request data for assertions in the main goroutine. // requestCapture holds captured request data for assertions in the main goroutine.
type requestCapture struct { type requestCapture struct {
path string path string
...@@ -37,6 +29,12 @@ type requestCapture struct { ...@@ -37,6 +29,12 @@ type requestCapture struct {
contentType string contentType string
} }
func newTestReqClient(rt http.RoundTripper) *req.Client {
c := req.C()
c.GetClient().Transport = rt
return c
}
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() { func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
tests := []struct { tests := []struct {
name string name string
...@@ -83,17 +81,17 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() { ...@@ -83,17 +81,17 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
s.Run(tt.name, func() { s.Run(tt.name, func() {
var captured requestCapture var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path captured.path = r.URL.Path
captured.cookies = r.Cookies() captured.cookies = r.Cookies()
tt.handler(w, r) tt.handler(w, r)
})) }), nil)
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService) client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed") require.True(s.T(), ok, "type assertion failed")
s.client = client s.client = client
s.client.baseURL = s.srv.URL s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "") got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
...@@ -158,20 +156,20 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() { ...@@ -158,20 +156,20 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
s.Run(tt.name, func() { s.Run(tt.name, func() {
var captured requestCapture var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path captured.path = r.URL.Path
captured.method = r.Method captured.method = r.Method
captured.cookies = r.Cookies() captured.cookies = r.Cookies()
captured.body, _ = io.ReadAll(r.Body) captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON) _ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r) tt.handler(w, r)
})) }), nil)
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService) client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed") require.True(s.T(), ok, "type assertion failed")
s.client = client s.client = client
s.client.baseURL = s.srv.URL s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "") code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
...@@ -266,19 +264,19 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { ...@@ -266,19 +264,19 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
s.Run(tt.name, func() { s.Run(tt.name, func() {
var captured requestCapture var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type") captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body) captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON) _ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r) tt.handler(w, r)
})) }), nil)
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService) client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed") require.True(s.T(), ok, "type assertion failed")
s.client = client s.client = client
s.client.tokenURL = s.srv.URL s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken) resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
...@@ -362,19 +360,19 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { ...@@ -362,19 +360,19 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
s.Run(tt.name, func() { s.Run(tt.name, func() {
var captured requestCapture var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type") captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body) captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON) _ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r) tt.handler(w, r)
})) }), nil)
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService) client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed") require.True(s.T(), ok, "type assertion failed")
s.client = client s.client = client
s.client.tokenURL = s.srv.URL s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
resp, err := s.client.RefreshToken(context.Background(), "rt", "") resp, err := s.client.RefreshToken(context.Background(), "rt", "")
......
...@@ -33,7 +33,7 @@ type usageRequestCapture struct { ...@@ -33,7 +33,7 @@ type usageRequestCapture struct {
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
var captured usageRequestCapture var captured usageRequestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.authorization = r.Header.Get("Authorization") captured.authorization = r.Header.Get("Authorization")
captured.anthropicBeta = r.Header.Get("anthropic-beta") captured.anthropicBeta = r.Header.Get("anthropic-beta")
...@@ -59,7 +59,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { ...@@ -59,7 +59,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
} }
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() { func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "nope") _, _ = io.WriteString(w, "nope")
})) }))
...@@ -73,7 +73,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() { ...@@ -73,7 +73,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
} }
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() { func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json") _, _ = io.WriteString(w, "not-json")
})) }))
...@@ -86,7 +86,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() { ...@@ -86,7 +86,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
} }
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() { func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Never respond - simulate slow server // Never respond - simulate slow server
<-r.Context().Done() <-r.Context().Done()
})) }))
......
...@@ -309,7 +309,7 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) ...@@ -309,7 +309,7 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := waitQueueKey(userID) key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int() result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil { if err != nil {
return false, err return false, err
} }
......
// Package infrastructure 提供应用程序的基础设施层组件。 // Package repository 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。 // 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package repository package repository
......
...@@ -243,7 +243,7 @@ func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) * ...@@ -243,7 +243,7 @@ func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *
return a return a
} }
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *service.ApiKey { func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *service.APIKey {
t.Helper() t.Helper()
ctx := context.Background() ctx := context.Background()
...@@ -257,7 +257,7 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *se ...@@ -257,7 +257,7 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *se
k.Name = "default" k.Name = "default"
} }
create := client.ApiKey.Create(). create := client.APIKey.Create().
SetUserID(k.UserID). SetUserID(k.UserID).
SetKey(k.Key). SetKey(k.Key).
SetName(k.Name). SetName(k.Name).
......
...@@ -30,6 +30,7 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c ...@@ -30,6 +30,7 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
// Use different OAuth clients based on oauthType: // Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public) // - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client
// - ai_studio: requires a user-provided OAuth client // - ai_studio: requires a user-provided OAuth client
oauthCfgInput := geminicli.OAuthConfig{ oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID, ClientID: c.cfg.Gemini.OAuth.ClientID,
......
...@@ -49,7 +49,7 @@ func (s *GitHubReleaseServiceSuite) TearDownTest() { ...@@ -49,7 +49,7 @@ func (s *GitHubReleaseServiceSuite) TearDownTest() {
} }
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() { func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "100") w.Header().Set("Content-Length", "100")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write(bytes.Repeat([]byte("a"), 100)) _, _ = w.Write(bytes.Repeat([]byte("a"), 100))
...@@ -68,7 +68,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng ...@@ -68,7 +68,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
} }
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() { func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Force chunked encoding (unknown Content-Length) by flushing headers before writing. // Force chunked encoding (unknown Content-Length) by flushing headers before writing.
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok { if fl, ok := w.(http.Flusher); ok {
...@@ -95,7 +95,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() { ...@@ -95,7 +95,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
} }
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() { func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok { if fl, ok := w.(http.Flusher); ok {
fl.Flush() fl.Flush()
...@@ -123,7 +123,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() { ...@@ -123,7 +123,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
} }
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() { func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
})) }))
...@@ -140,7 +140,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() { ...@@ -140,7 +140,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
} }
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() { func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("sum")) _, _ = w.Write([]byte("sum"))
})) }))
...@@ -155,7 +155,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() { ...@@ -155,7 +155,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
} }
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() { func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
})) }))
...@@ -168,7 +168,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() { ...@@ -168,7 +168,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
} }
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() { func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done() <-r.Context().Done()
})) }))
...@@ -195,7 +195,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() { ...@@ -195,7 +195,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
} }
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() { func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("content")) _, _ = w.Write([]byte("content"))
})) }))
...@@ -233,7 +233,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() { ...@@ -233,7 +233,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
] ]
}` }`
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path) require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept")) require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent")) require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
...@@ -258,7 +258,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() { ...@@ -258,7 +258,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
} }
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() { func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
})) }))
...@@ -274,7 +274,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() { ...@@ -274,7 +274,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
} }
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() { func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("not valid json")) _, _ = w.Write([]byte("not valid json"))
})) }))
...@@ -290,7 +290,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() { ...@@ -290,7 +290,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
} }
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() { func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done() <-r.Context().Done()
})) }))
...@@ -308,7 +308,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() { ...@@ -308,7 +308,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
} }
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() { func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done() <-r.Context().Done()
})) }))
......
...@@ -293,8 +293,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, ...@@ -293,8 +293,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
// 2. Clear group_id for api keys bound to this group. // 2. Clear group_id for api keys bound to this group.
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。 // 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
// 与 ApiKeyRepository 的软删除语义保持一致,减少跨模块行为差异。 // 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
if _, err := txClient.ApiKey.Update(). if _, err := txClient.APIKey.Update().
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()). Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
ClearGroupID(). ClearGroupID().
Save(ctx); err != nil { Save(ctx); err != nil {
......
...@@ -3,7 +3,6 @@ package repository ...@@ -3,7 +3,6 @@ package repository
import ( import (
"io" "io"
"net/http" "net/http"
"net/http/httptest"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
...@@ -93,7 +92,7 @@ func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() { ...@@ -93,7 +92,7 @@ func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() {
// 验证空代理 URL 时请求直接发送到目标服务器 // 验证空代理 URL 时请求直接发送到目标服务器
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() { func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
// 创建模拟上游服务器 // 创建模拟上游服务器
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct") _, _ = io.WriteString(w, "direct")
})) }))
s.T().Cleanup(upstream.Close) s.T().Cleanup(upstream.Close)
...@@ -115,7 +114,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() { ...@@ -115,7 +114,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
// 用于接收代理请求的通道 // 用于接收代理请求的通道
seen := make(chan string, 1) seen := make(chan string, 1)
// 创建模拟代理服务器 // 创建模拟代理服务器
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { proxySrv := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI // 记录请求 URI seen <- r.RequestURI // 记录请求 URI
_, _ = io.WriteString(w, "proxied") _, _ = io.WriteString(w, "proxied")
})) }))
...@@ -145,7 +144,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() { ...@@ -145,7 +144,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
// TestDo_EmptyProxy_UsesDirect 测试空代理字符串 // TestDo_EmptyProxy_UsesDirect 测试空代理字符串
// 验证空字符串代理等同于直连 // 验证空字符串代理等同于直连
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() { func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct-empty") _, _ = io.WriteString(w, "direct-empty")
})) }))
s.T().Cleanup(upstream.Close) s.T().Cleanup(upstream.Close)
......
package repository
import (
"bytes"
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
// newInProcessTransport adapts an http.HandlerFunc into an http.RoundTripper without opening sockets.
// It captures the request body (if any) and then rewinds it before invoking the handler.
func newInProcessTransport(handler http.HandlerFunc, capture func(r *http.Request, body []byte)) http.RoundTripper {
return roundTripFunc(func(r *http.Request) (*http.Response, error) {
var body []byte
if r.Body != nil {
body, _ = io.ReadAll(r.Body)
_ = r.Body.Close()
r.Body = io.NopCloser(bytes.NewReader(body))
}
if capture != nil {
capture(r, body)
}
rec := httptest.NewRecorder()
handler(rec, r)
return rec.Result(), nil
})
}
var (
canListenOnce sync.Once
canListen bool
canListenErr error
)
func localListenerAvailable() bool {
canListenOnce.Do(func() {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
canListenErr = err
canListen = false
return
}
_ = ln.Close()
canListen = true
})
return canListen
}
func newLocalTestServer(tb testing.TB, handler http.Handler) *httptest.Server {
tb.Helper()
if !localListenerAvailable() {
tb.Skipf("local listeners are not permitted in this environment: %v", canListenErr)
}
return httptest.NewServer(handler)
}
...@@ -34,7 +34,7 @@ func (s *OpenAIOAuthServiceSuite) TearDownTest() { ...@@ -34,7 +34,7 @@ func (s *OpenAIOAuthServiceSuite) TearDownTest() {
} }
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) { func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler) s.srv = newLocalTestServer(s.T(), handler)
s.svc = &openaiOAuthService{tokenURL: s.srv.URL} s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
} }
......
...@@ -32,7 +32,7 @@ func (s *PricingServiceSuite) TearDownTest() { ...@@ -32,7 +32,7 @@ func (s *PricingServiceSuite) TearDownTest() {
} }
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) { func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler) s.srv = newLocalTestServer(s.T(), handler)
} }
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() { func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {
......
...@@ -31,7 +31,7 @@ func (s *ProxyProbeServiceSuite) TearDownTest() { ...@@ -31,7 +31,7 @@ func (s *ProxyProbeServiceSuite) TearDownTest() {
} }
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) { func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
s.proxySrv = httptest.NewServer(handler) s.proxySrv = newLocalTestServer(s.T(), handler)
} }
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() { func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() {
......
...@@ -41,8 +41,8 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) { ...@@ -41,8 +41,8 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com") u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
repo := NewApiKeyRepository(client) repo := NewAPIKeyRepository(client)
key := &service.ApiKey{ key := &service.APIKey{
UserID: u.ID, UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"), Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
Name: "soft-delete", Name: "soft-delete",
...@@ -53,13 +53,13 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) { ...@@ -53,13 +53,13 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key") require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
_, err := repo.GetByID(ctx, key.ID) _, err := repo.GetByID(ctx, key.ID)
require.ErrorIs(t, err, service.ErrApiKeyNotFound, "deleted rows should be hidden by default") require.ErrorIs(t, err, service.ErrAPIKeyNotFound, "deleted rows should be hidden by default")
_, err = client.ApiKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx) _, err = client.APIKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
require.Error(t, err, "default ent query should not see soft-deleted rows") require.Error(t, err, "default ent query should not see soft-deleted rows")
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter") require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
got, err := client.ApiKey.Query(). got, err := client.APIKey.Query().
Where(apikey.IDEQ(key.ID)). Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx)) Only(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows") require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
...@@ -73,8 +73,8 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) { ...@@ -73,8 +73,8 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com") u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
repo := NewApiKeyRepository(client) repo := NewAPIKeyRepository(client)
key := &service.ApiKey{ key := &service.APIKey{
UserID: u.ID, UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"), Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
Name: "soft-delete2", Name: "soft-delete2",
...@@ -93,8 +93,8 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) { ...@@ -93,8 +93,8 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com") u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
repo := NewApiKeyRepository(client) repo := NewAPIKeyRepository(client)
key := &service.ApiKey{ key := &service.APIKey{
UserID: u.ID, UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"), Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
Name: "soft-delete3", Name: "soft-delete3",
...@@ -105,10 +105,10 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) { ...@@ -105,10 +105,10 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key") require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at. // Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
_, err := client.ApiKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx)) _, err := client.APIKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "hard delete") require.NoError(t, err, "hard delete")
_, err = client.ApiKey.Query(). _, err = client.APIKey.Query().
Where(apikey.IDEQ(key.ID)). Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx)) Only(mixins.SkipSoftDelete(ctx))
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted") require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")
......
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const tempUnschedPrefix = "temp_unsched:account:"
var tempUnschedSetScript = redis.NewScript(`
local key = KEYS[1]
local new_until = tonumber(ARGV[1])
local new_value = ARGV[2]
local new_ttl = tonumber(ARGV[3])
local existing = redis.call('GET', key)
if existing then
local ok, existing_data = pcall(cjson.decode, existing)
if ok and existing_data and existing_data.until_unix then
local existing_until = tonumber(existing_data.until_unix)
if existing_until and new_until <= existing_until then
return 0
end
end
end
redis.call('SET', key, new_value, 'EX', new_ttl)
return 1
`)
type tempUnschedCache struct {
rdb *redis.Client
}
func NewTempUnschedCache(rdb *redis.Client) service.TempUnschedCache {
return &tempUnschedCache{rdb: rdb}
}
// SetTempUnsched 设置临时不可调度状态(只延长不缩短)
func (c *tempUnschedCache) SetTempUnsched(ctx context.Context, accountID int64, state *service.TempUnschedState) error {
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
stateJSON, err := json.Marshal(state)
if err != nil {
return fmt.Errorf("marshal state: %w", err)
}
ttl := time.Until(time.Unix(state.UntilUnix, 0))
if ttl <= 0 {
return nil // 已过期,不设置
}
ttlSeconds := int(ttl.Seconds())
if ttlSeconds < 1 {
ttlSeconds = 1
}
_, err = tempUnschedSetScript.Run(ctx, c.rdb, []string{key}, state.UntilUnix, string(stateJSON), ttlSeconds).Result()
return err
}
// GetTempUnsched 获取临时不可调度状态
func (c *tempUnschedCache) GetTempUnsched(ctx context.Context, accountID int64) (*service.TempUnschedState, error) {
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
val, err := c.rdb.Get(ctx, key).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, err
}
var state service.TempUnschedState
if err := json.Unmarshal([]byte(val), &state); err != nil {
return nil, fmt.Errorf("unmarshal state: %w", err)
}
return &state, nil
}
// DeleteTempUnsched 删除临时不可调度状态
func (c *tempUnschedCache) DeleteTempUnsched(ctx context.Context, accountID int64) error {
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
return c.rdb.Del(ctx, key).Err()
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment