Unverified Commit 4587c3e5 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #670 from DaydreamCoding/feat/admin-apikey-group-update

feat(admin): 添加管理员直接修改用户 API Key 分组的功能 
parents be18bc6f 6f9e6903
...@@ -103,7 +103,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -103,7 +103,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyRepository := repository.NewProxyRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
...@@ -192,7 +192,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -192,7 +192,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient) errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache) errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService) errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
......
...@@ -109,6 +109,7 @@ require ( ...@@ -109,6 +109,7 @@ require (
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect
...@@ -177,6 +178,7 @@ require ( ...@@ -177,6 +178,7 @@ require (
golang.org/x/mod v0.32.0 // indirect golang.org/x/mod v0.32.0 // indirect
golang.org/x/sys v0.41.0 // indirect golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect golang.org/x/text v0.34.0 // indirect
golang.org/x/tools v0.41.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.67.6 // indirect modernc.org/libc v1.67.6 // indirect
......
...@@ -182,6 +182,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 ...@@ -182,6 +182,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
......
...@@ -403,5 +403,23 @@ func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates [] ...@@ -403,5 +403,23 @@ func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates []
return nil return nil
} }
func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
for i := range s.apiKeys {
if s.apiKeys[i].ID == keyID {
k := s.apiKeys[i]
if groupID != nil {
if *groupID == 0 {
k.GroupID = nil
} else {
gid := *groupID
k.GroupID = &gid
}
}
return &service.AdminUpdateAPIKeyGroupIDResult{APIKey: &k}, nil
}
}
return nil, service.ErrAPIKeyNotFound
}
// Ensure stub implements interface. // Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil) var _ service.AdminService = (*stubAdminService)(nil)
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AdminAPIKeyHandler handles admin API key management
type AdminAPIKeyHandler struct {
adminService service.AdminService
}
// NewAdminAPIKeyHandler creates a new admin API key handler
func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandler {
return &AdminAPIKeyHandler{
adminService: adminService,
}
}
// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
type AdminUpdateAPIKeyGroupRequest struct {
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
}
// UpdateGroup handles updating an API key's group binding
// PUT /api/v1/admin/api-keys/:id
func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid API key ID")
return
}
var req AdminUpdateAPIKeyGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
resp := struct {
APIKey *dto.APIKey `json:"api_key"`
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
GrantedGroupID *int64 `json:"granted_group_id,omitempty"`
GrantedGroupName string `json:"granted_group_name,omitempty"`
}{
APIKey: dto.APIKeyFromService(result.APIKey),
AutoGrantedGroupAccess: result.AutoGrantedGroupAccess,
GrantedGroupID: result.GrantedGroupID,
GrantedGroupName: result.GrantedGroupName,
}
response.Success(c, resp)
}
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAPIKeyHandler(adminSvc service.AdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
h := NewAdminAPIKeyHandler(adminSvc)
router.PUT("/api/v1/admin/api-keys/:id", h.UpdateGroup)
return router
}
func TestAdminAPIKeyHandler_UpdateGroup_InvalidID(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/abc", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "Invalid API key ID")
}
func TestAdminAPIKeyHandler_UpdateGroup_InvalidJSON(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{bad json`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "Invalid request")
}
func TestAdminAPIKeyHandler_UpdateGroup_KeyNotFound(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/999", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
// ErrAPIKeyNotFound maps to 404
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestAdminAPIKeyHandler_UpdateGroup_BindGroup(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Code int `json:"code"`
Data json.RawMessage `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
var data struct {
APIKey struct {
ID int64 `json:"id"`
GroupID *int64 `json:"group_id"`
} `json:"api_key"`
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
}
require.NoError(t, json.Unmarshal(resp.Data, &data))
require.Equal(t, int64(10), data.APIKey.ID)
require.NotNil(t, data.APIKey.GroupID)
require.Equal(t, int64(2), *data.APIKey.GroupID)
}
func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
svc := newStubAdminService()
gid := int64(2)
svc.apiKeys[0].GroupID = &gid
router := setupAPIKeyHandler(svc)
body := `{"group_id": 0}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data struct {
APIKey struct {
GroupID *int64 `json:"group_id"`
} `json:"api_key"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Nil(t, resp.Data.APIKey.GroupID)
}
func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: errors.New("internal failure"),
}
router := setupAPIKeyHandler(svc)
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// H2: empty body → group_id is nil → no-op, returns original key
func TestAdminAPIKeyHandler_UpdateGroup_EmptyBody_NoChange(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Code int `json:"code"`
Data struct {
APIKey struct {
ID int64 `json:"id"`
} `json:"api_key"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, int64(10), resp.Data.APIKey.ID)
}
// M2: service returns GROUP_NOT_ACTIVE → handler maps to 400
func TestAdminAPIKeyHandler_UpdateGroup_GroupNotActive(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active"),
}
router := setupAPIKeyHandler(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": 5}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "GROUP_NOT_ACTIVE")
}
// M2: service returns INVALID_GROUP_ID → handler maps to 400
func TestAdminAPIKeyHandler_UpdateGroup_NegativeGroupID(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative"),
}
router := setupAPIKeyHandler(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": -5}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "INVALID_GROUP_ID")
}
// failingUpdateGroupService overrides AdminUpdateAPIKeyGroupID to return an error.
type failingUpdateGroupService struct {
*stubAdminService
err error
}
func (f *failingUpdateGroupService) AdminUpdateAPIKeyGroupID(_ context.Context, _ int64, _ *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
return nil, f.err
}
...@@ -26,6 +26,7 @@ type AdminHandlers struct { ...@@ -26,6 +26,7 @@ type AdminHandlers struct {
Usage *admin.UsageHandler Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler ErrorPassthrough *admin.ErrorPassthroughHandler
APIKey *admin.AdminAPIKeyHandler
} }
// Handlers contains all HTTP handlers // Handlers contains all HTTP handlers
......
...@@ -945,6 +945,9 @@ func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, i ...@@ -945,6 +945,9 @@ func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, i
func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil } func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil } func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil } func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
// ==================== NewSoraClientHandler ==================== // ==================== NewSoraClientHandler ====================
......
...@@ -29,6 +29,7 @@ func ProvideAdminHandlers( ...@@ -29,6 +29,7 @@ func ProvideAdminHandlers(
usageHandler *admin.UsageHandler, usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler, userAttributeHandler *admin.UserAttributeHandler,
errorPassthroughHandler *admin.ErrorPassthroughHandler, errorPassthroughHandler *admin.ErrorPassthroughHandler,
apiKeyHandler *admin.AdminAPIKeyHandler,
) *AdminHandlers { ) *AdminHandlers {
return &AdminHandlers{ return &AdminHandlers{
Dashboard: dashboardHandler, Dashboard: dashboardHandler,
...@@ -51,6 +52,7 @@ func ProvideAdminHandlers( ...@@ -51,6 +52,7 @@ func ProvideAdminHandlers(
Usage: usageHandler, Usage: usageHandler,
UserAttribute: userAttributeHandler, UserAttribute: userAttributeHandler,
ErrorPassthrough: errorPassthroughHandler, ErrorPassthrough: errorPassthroughHandler,
APIKey: apiKeyHandler,
} }
} }
...@@ -138,6 +140,7 @@ var ProviderSet = wire.NewSet( ...@@ -138,6 +140,7 @@ var ProviderSet = wire.NewSet(
admin.NewUsageHandler, admin.NewUsageHandler,
admin.NewUserAttributeHandler, admin.NewUserAttributeHandler,
admin.NewErrorPassthroughHandler, admin.NewErrorPassthroughHandler,
admin.NewAdminAPIKeyHandler,
// AdminHandlers and Handlers constructors // AdminHandlers and Handlers constructors
ProvideAdminHandlers, ProvideAdminHandlers,
......
...@@ -171,8 +171,9 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro ...@@ -171,8 +171,9 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
// 则会更新已删除的记录。 // 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。 // 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。 // 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。
client := clientFromContext(ctx, r.client)
now := time.Now() now := time.Now()
builder := r.client.APIKey.Update(). builder := client.APIKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()). Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name). SetName(key.Name).
SetStatus(key.Status). SetStatus(key.Status).
......
...@@ -429,6 +429,16 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, ...@@ -429,6 +429,16 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx) return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
} }
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
client := clientFromContext(ctx, r.client)
return client.UserAllowedGroup.Create().
SetUserID(userID).
SetGroupID(groupID).
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx)
}
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。 // 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
affected, err := r.client.UserAllowedGroup.Delete(). affected, err := r.client.UserAllowedGroup.Delete().
......
...@@ -619,7 +619,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -619,7 +619,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo() settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
...@@ -779,6 +779,10 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID ...@@ -779,6 +779,10 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
......
...@@ -181,6 +181,10 @@ func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID ...@@ -181,6 +181,10 @@ func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call") panic("unexpected RemoveGroupFromAllowedGroups call")
} }
func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
panic("unexpected AddGroupToAllowedGroups call")
}
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call") panic("unexpected UpdateTotpSecret call")
} }
......
...@@ -75,6 +75,16 @@ func RegisterAdminRoutes( ...@@ -75,6 +75,16 @@ func RegisterAdminRoutes(
// 错误透传规则管理 // 错误透传规则管理
registerErrorPassthroughRoutes(admin, h) registerErrorPassthroughRoutes(admin, h)
// API Key 管理
registerAdminAPIKeyRoutes(admin, h)
}
}
func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
apiKeys := admin.Group("/api-keys")
{
apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup)
} }
} }
......
...@@ -9,6 +9,8 @@ import ( ...@@ -9,6 +9,8 @@ import (
"strings" "strings"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
...@@ -42,6 +44,9 @@ type AdminService interface { ...@@ -42,6 +44,9 @@ type AdminService interface {
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// API Key management (admin)
AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error) GetAccount(ctx context.Context, id int64) (*Account, error)
...@@ -242,6 +247,14 @@ type BulkUpdateAccountResult struct { ...@@ -242,6 +247,14 @@ type BulkUpdateAccountResult struct {
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
} }
// AdminUpdateAPIKeyGroupIDResult is the result of AdminUpdateAPIKeyGroupID.
type AdminUpdateAPIKeyGroupIDResult struct {
APIKey *APIKey
AutoGrantedGroupAccess bool // true if a new exclusive group permission was auto-added
GrantedGroupID *int64 // the group ID that was auto-granted
GrantedGroupName string // the group name that was auto-granted
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates. // BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct { type BulkUpdateAccountsResult struct {
Success int `json:"success"` Success int `json:"success"`
...@@ -406,6 +419,7 @@ type adminServiceImpl struct { ...@@ -406,6 +419,7 @@ type adminServiceImpl struct {
proxyProber ProxyExitInfoProber proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache proxyLatencyCache ProxyLatencyCache
authCacheInvalidator APIKeyAuthCacheInvalidator authCacheInvalidator APIKeyAuthCacheInvalidator
entClient *dbent.Client // 用于开启数据库事务
} }
type userGroupRateBatchReader interface { type userGroupRateBatchReader interface {
...@@ -430,6 +444,7 @@ func NewAdminService( ...@@ -430,6 +444,7 @@ func NewAdminService(
proxyProber ProxyExitInfoProber, proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache, proxyLatencyCache ProxyLatencyCache,
authCacheInvalidator APIKeyAuthCacheInvalidator, authCacheInvalidator APIKeyAuthCacheInvalidator,
entClient *dbent.Client,
) AdminService { ) AdminService {
return &adminServiceImpl{ return &adminServiceImpl{
userRepo: userRepo, userRepo: userRepo,
...@@ -444,6 +459,7 @@ func NewAdminService( ...@@ -444,6 +459,7 @@ func NewAdminService(
proxyProber: proxyProber, proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache, proxyLatencyCache: proxyLatencyCache,
authCacheInvalidator: authCacheInvalidator, authCacheInvalidator: authCacheInvalidator,
entClient: entClient,
} }
} }
...@@ -1185,6 +1201,103 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates [] ...@@ -1185,6 +1201,103 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
return s.groupRepo.UpdateSortOrders(ctx, updates) return s.groupRepo.UpdateSortOrders(ctx, updates)
} }
// AdminUpdateAPIKeyGroupID 管理员修改 API Key 分组绑定
// groupID: nil=不修改, 指向0=解绑, 指向正整数=绑定到目标分组
func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
if err != nil {
return nil, err
}
if groupID == nil {
// nil 表示不修改,直接返回
return &AdminUpdateAPIKeyGroupIDResult{APIKey: apiKey}, nil
}
if *groupID < 0 {
return nil, infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative")
}
result := &AdminUpdateAPIKeyGroupIDResult{}
if *groupID == 0 {
// 0 表示解绑分组(不修改 user_allowed_groups,避免影响用户其他 Key)
apiKey.GroupID = nil
apiKey.Group = nil
} else {
// 验证目标分组存在且状态为 active
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, err
}
if group.Status != StatusActive {
return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active")
}
// 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程
if group.IsSubscriptionType() {
return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow")
}
gid := *groupID
apiKey.GroupID = &gid
apiKey.Group = group
// 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性
if group.IsExclusive {
opCtx := ctx
var tx *dbent.Tx
if s.entClient == nil {
logger.LegacyPrintf("service.admin", "Warning: entClient is nil, skipping transaction protection for exclusive group binding")
} else {
var txErr error
tx, txErr = s.entClient.Tx(ctx)
if txErr != nil {
return nil, fmt.Errorf("begin transaction: %w", txErr)
}
defer func() { _ = tx.Rollback() }()
opCtx = dbent.NewTxContext(ctx, tx)
}
if addErr := s.userRepo.AddGroupToAllowedGroups(opCtx, apiKey.UserID, gid); addErr != nil {
return nil, fmt.Errorf("add group to user allowed groups: %w", addErr)
}
if err := s.apiKeyRepo.Update(opCtx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err)
}
if tx != nil {
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit transaction: %w", err)
}
}
result.AutoGrantedGroupAccess = true
result.GrantedGroupID = &gid
result.GrantedGroupName = group.Name
// 失效认证缓存(在事务提交后执行)
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
}
result.APIKey = apiKey
return result, nil
}
}
// 非专属分组 / 解绑:无需事务,单步更新即可
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err)
}
// 失效认证缓存
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
}
result.APIKey = apiKey
return result, nil
}
// Account management implementations // Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) { func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
......
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Stubs
// ---------------------------------------------------------------------------
// userRepoStubForGroupUpdate implements UserRepository for AdminUpdateAPIKeyGroupID tests.
type userRepoStubForGroupUpdate struct {
addGroupErr error
addGroupCalled bool
addedUserID int64
addedGroupID int64
}
func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context, userID int64, groupID int64) error {
s.addGroupCalled = true
s.addedUserID = userID
s.addedGroupID = groupID
return s.addGroupErr
}
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
type apiKeyRepoStubForGroupUpdate struct {
key *APIKey
getErr error
updateErr error
updated *APIKey // captures what was passed to Update
}
func (s *apiKeyRepoStubForGroupUpdate) GetByID(_ context.Context, _ int64) (*APIKey, error) {
if s.getErr != nil {
return nil, s.getErr
}
clone := *s.key
return &clone, nil
}
func (s *apiKeyRepoStubForGroupUpdate) Update(_ context.Context, key *APIKey) error {
if s.updateErr != nil {
return s.updateErr
}
clone := *key
s.updated = &clone
return nil
}
// Unused methods – panic on unexpected call.
func (s *apiKeyRepoStubForGroupUpdate) Create(context.Context, *APIKey) error { panic("unexpected") }
func (s *apiKeyRepoStubForGroupUpdate) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) GetByKey(context.Context, string) (*APIKey, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) CountByUserID(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ExistsByKey(context.Context, string) (bool, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListKeysByUserID(context.Context, int64) ([]string, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListKeysByGroupID(context.Context, int64) ([]string, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) UpdateLastUsed(context.Context, int64, time.Time) error {
panic("unexpected")
}
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
type groupRepoStubForGroupUpdate struct {
group *Group
getErr error
lastGetByIDArg int64
}
func (s *groupRepoStubForGroupUpdate) GetByID(_ context.Context, id int64) (*Group, error) {
s.lastGetByIDArg = id
if s.getErr != nil {
return nil, s.getErr
}
clone := *s.group
return &clone, nil
}
// Unused methods – panic on unexpected call.
func (s *groupRepoStubForGroupUpdate) Create(context.Context, *Group) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) GetByIDLite(context.Context, int64) (*Group, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) Update(context.Context, *Group) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) DeleteCascade(context.Context, int64) ([]int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListActive(context.Context) ([]Group, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, string) ([]Group, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) BindAccountsToGroup(context.Context, int64, []int64) error {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
panic("unexpected")
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
func TestAdminService_AdminUpdateAPIKeyGroupID_KeyNotFound(t *testing.T) {
repo := &apiKeyRepoStubForGroupUpdate{getErr: ErrAPIKeyNotFound}
svc := &adminServiceImpl{apiKeyRepo: repo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 999, int64Ptr(1))
require.ErrorIs(t, err, ErrAPIKeyNotFound)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NilGroupID_NoOp(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5)}
repo := &apiKeyRepoStubForGroupUpdate{key: existing}
svc := &adminServiceImpl{apiKeyRepo: repo}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, nil)
require.NoError(t, err)
require.Equal(t, int64(1), got.APIKey.ID)
// Update should NOT have been called (updated stays nil)
require.Nil(t, repo.updated)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5), Group: &Group{ID: 5, Name: "Old"}}
repo := &apiKeyRepoStubForGroupUpdate{key: existing}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: repo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
require.NoError(t, err)
require.Nil(t, got.APIKey.GroupID, "group_id should be nil after unbind")
require.Nil(t, got.APIKey.Group, "group object should be nil after unbind")
require.NotNil(t, repo.updated, "Update should have been called")
require.Nil(t, repo.updated.GroupID)
require.Equal(t, []string{"sk-test"}, cache.keys, "cache should be invalidated")
}
func TestAdminService_AdminUpdateAPIKeyGroupID_BindActiveGroup(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(10), *got.APIKey.GroupID)
require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
require.Equal(t, []string{"sk-test"}, cache.keys)
// M3: verify correct group ID was passed to repo
require.Equal(t, int64(10), groupRepo.lastGetByIDArg)
// C1 fix: verify Group object is populated
require.NotNil(t, got.APIKey.Group)
require.Equal(t, "Pro", got.APIKey.Group.Name)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SameGroup_Idempotent(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Pro"}}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(10), *got.APIKey.GroupID)
// Update is still called (current impl doesn't short-circuit on same group)
require.NotNil(t, apiKeyRepo.updated)
require.Equal(t, []string{"sk-test"}, cache.keys)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotFound(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{getErr: ErrGroupNotFound}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(99))
require.ErrorIs(t, err, ErrGroupNotFound)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotActive(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 5, Status: StatusDisabled}}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(5))
require.Error(t, err)
require.Equal(t, "GROUP_NOT_ACTIVE", infraerrors.Reason(err))
}
func TestAdminService_AdminUpdateAPIKeyGroupID_UpdateFails(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(3)}
repo := &apiKeyRepoStubForGroupUpdate{key: existing, updateErr: errors.New("db write error")}
svc := &adminServiceImpl{apiKeyRepo: repo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
require.Error(t, err)
require.Contains(t, err.Error(), "update api key")
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NegativeGroupID(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(-5))
require.Error(t, err)
require.Equal(t, "INVALID_GROUP_ID", infraerrors.Reason(err))
}
func TestAdminService_AdminUpdateAPIKeyGroupID_PointerIsolation(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
inputGID := int64(10)
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, &inputGID)
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
// Mutating the input pointer must NOT affect the stored value
inputGID = 999
require.Equal(t, int64(10), *got.APIKey.GroupID)
require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NilCacheInvalidator(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 7, Status: StatusActive}}
// authCacheInvalidator is nil – should not panic
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(7))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(7), *got.APIKey.GroupID)
}
// ---------------------------------------------------------------------------
// Tests: AllowedGroup auto-sync
// ---------------------------------------------------------------------------
func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AddsAllowedGroup(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
userRepo := &userRepoStubForGroupUpdate{}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(10), *got.APIKey.GroupID)
// 验证 AddGroupToAllowedGroups 被调用,且参数正确
require.True(t, userRepo.addGroupCalled)
require.Equal(t, int64(42), userRepo.addedUserID)
require.Equal(t, int64(10), userRepo.addedGroupID)
// 验证 result 标记了自动授权
require.True(t, got.AutoGrantedGroupAccess)
require.NotNil(t, got.GrantedGroupID)
require.Equal(t, int64(10), *got.GrantedGroupID)
require.Equal(t, "Exclusive", got.GrantedGroupName)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupUpdate(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Public", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeStandard}}
userRepo := &userRepoStubForGroupUpdate{}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
// 非专属分组不触发 AddGroupToAllowedGroups
require.False(t, userRepo.addGroupCalled)
require.False(t, got.AutoGrantedGroupAccess)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
// 订阅类型分组应被阻止绑定
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err))
require.False(t, userRepo.addGroupCalled)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AllowedGroupAddFails_ReturnsError(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
userRepo := &userRepoStubForGroupUpdate{addGroupErr: errors.New("db error")}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
// 严格模式:AddGroupToAllowedGroups 失败时,整体操作报错
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
require.Contains(t, err.Error(), "add group to user allowed groups")
require.True(t, userRepo.addGroupCalled)
// apiKey 不应被更新
require.Nil(t, apiKeyRepo.updated)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind_NoAllowedGroupUpdate(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Exclusive"}}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
userRepo := &userRepoStubForGroupUpdate{}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, userRepo: userRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
require.NoError(t, err)
require.Nil(t, got.APIKey.GroupID)
// 解绑时不修改 allowed_groups
require.False(t, userRepo.addGroupCalled)
require.False(t, got.AutoGrantedGroupAccess)
}
...@@ -93,6 +93,10 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID ...@@ -93,6 +93,10 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call") panic("unexpected RemoveGroupFromAllowedGroups call")
} }
func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
panic("unexpected AddGroupToAllowedGroups call")
}
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call") panic("unexpected UpdateTotpSecret call")
} }
......
...@@ -165,6 +165,9 @@ func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int ...@@ -165,6 +165,9 @@ func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int
func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil } func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil } func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil } func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ==================== // ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
......
...@@ -40,6 +40,8 @@ type UserRepository interface { ...@@ -40,6 +40,8 @@ type UserRepository interface {
UpdateConcurrency(ctx context.Context, id int64, amount int) error UpdateConcurrency(ctx context.Context, id int64, amount int) error
ExistsByEmail(ctx context.Context, email string) (bool, error) ExistsByEmail(ctx context.Context, email string) (bool, error)
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
// AddGroupToAllowedGroups 将指定分组增量添加到用户的 allowed_groups(幂等,冲突忽略)
AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error
// TOTP 双因素认证 // TOTP 双因素认证
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
......
...@@ -45,7 +45,8 @@ func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { re ...@@ -45,7 +45,8 @@ func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { re
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil return 0, nil
} }
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment