Commit a4953785 authored by IanShaw027's avatar IanShaw027
Browse files

fix(lint): 修复所有 Go 命名规范问题

- 全局替换 ApiKey → APIKey(类型、字段、方法、变量)
- 修复所有 initialism 命名(API, SMTP, HTML, URL 等)
- 添加所有缺失的包注释
- 修复导出符号的注释格式

主要修改:
- ApiKey → APIKey(所有出现的地方)
- ApiKeyID → APIKeyID
- ApiKeyIDs → APIKeyIDs
- TestSmtpConnection → TestSMTPConnection
- HtmlURL → HTMLURL
- 添加 20+ 个包注释
- 修复 10+ 个导出符号注释格式

验证结果:
- ✓ golangci-lint: 0 issues
- ✓ 单元测试: 通过
- ✓ 集成测试: 通过
parent d92e71a1
...@@ -91,7 +91,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -91,7 +91,7 @@ func TestAPIContracts(t *testing.T) {
name: "GET /api/v1/keys (paginated)", name: "GET /api/v1/keys (paginated)",
setup: func(t *testing.T, deps *contractDeps) { setup: func(t *testing.T, deps *contractDeps) {
t.Helper() t.Helper()
deps.apiKeyRepo.MustSeed(&service.ApiKey{ deps.apiKeyRepo.MustSeed(&service.APIKey{
ID: 100, ID: 100,
UserID: 1, UserID: 1,
Key: "sk_custom_1234567890", Key: "sk_custom_1234567890",
...@@ -135,7 +135,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -135,7 +135,7 @@ func TestAPIContracts(t *testing.T) {
{ {
ID: 1, ID: 1,
UserID: 1, UserID: 1,
ApiKeyID: 100, APIKeyID: 100,
AccountID: 200, AccountID: 200,
Model: "claude-3", Model: "claude-3",
InputTokens: 10, InputTokens: 10,
...@@ -150,7 +150,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -150,7 +150,7 @@ func TestAPIContracts(t *testing.T) {
{ {
ID: 2, ID: 2,
UserID: 1, UserID: 1,
ApiKeyID: 100, APIKeyID: 100,
AccountID: 200, AccountID: 200,
Model: "claude-3", Model: "claude-3",
InputTokens: 5, InputTokens: 5,
...@@ -188,7 +188,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -188,7 +188,7 @@ func TestAPIContracts(t *testing.T) {
{ {
ID: 1, ID: 1,
UserID: 1, UserID: 1,
ApiKeyID: 100, APIKeyID: 100,
AccountID: 200, AccountID: 200,
RequestID: "req_123", RequestID: "req_123",
Model: "claude-3", Model: "claude-3",
...@@ -259,13 +259,13 @@ func TestAPIContracts(t *testing.T) { ...@@ -259,13 +259,13 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyRegistrationEnabled: "true", service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false", service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeySmtpHost: "smtp.example.com", service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySmtpPort: "587", service.SettingKeySMTPPort: "587",
service.SettingKeySmtpUsername: "user", service.SettingKeySMTPUsername: "user",
service.SettingKeySmtpPassword: "secret", service.SettingKeySMTPPassword: "secret",
service.SettingKeySmtpFrom: "no-reply@example.com", service.SettingKeySMTPFrom: "no-reply@example.com",
service.SettingKeySmtpFromName: "Sub2API", service.SettingKeySMTPFromName: "Sub2API",
service.SettingKeySmtpUseTLS: "true", service.SettingKeySMTPUseTLS: "true",
service.SettingKeyTurnstileEnabled: "true", service.SettingKeyTurnstileEnabled: "true",
service.SettingKeyTurnstileSiteKey: "site-key", service.SettingKeyTurnstileSiteKey: "site-key",
...@@ -274,9 +274,9 @@ func TestAPIContracts(t *testing.T) { ...@@ -274,9 +274,9 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeySiteName: "Sub2API", service.SettingKeySiteName: "Sub2API",
service.SettingKeySiteLogo: "", service.SettingKeySiteLogo: "",
service.SettingKeySiteSubtitle: "Subtitle", service.SettingKeySiteSubtitle: "Subtitle",
service.SettingKeyApiBaseUrl: "https://api.example.com", service.SettingKeyAPIBaseURL: "https://api.example.com",
service.SettingKeyContactInfo: "support", service.SettingKeyContactInfo: "support",
service.SettingKeyDocUrl: "https://docs.example.com", service.SettingKeyDocURL: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5", service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25", service.SettingKeyDefaultBalance: "1.25",
...@@ -371,13 +371,13 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -371,13 +371,13 @@ func newContractDeps(t *testing.T) *contractDeps {
cfg := &config.Config{ cfg := &config.Config{
Default: config.DefaultConfig{ Default: config.DefaultConfig{
ApiKeyPrefix: "sk-", APIKeyPrefix: "sk-",
}, },
RunMode: config.RunModeStandard, RunMode: config.RunModeStandard,
} }
userService := service.NewUserService(userRepo) userService := service.NewUserService(userRepo)
apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo() usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil) usageService := service.NewUsageService(usageRepo, userRepo, nil)
...@@ -669,20 +669,20 @@ type stubApiKeyRepo struct { ...@@ -669,20 +669,20 @@ type stubApiKeyRepo struct {
now time.Time now time.Time
nextID int64 nextID int64
byID map[int64]*service.ApiKey byID map[int64]*service.APIKey
byKey map[string]*service.ApiKey byKey map[string]*service.APIKey
} }
func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo { func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo {
return &stubApiKeyRepo{ return &stubApiKeyRepo{
now: now, now: now,
nextID: 100, nextID: 100,
byID: make(map[int64]*service.ApiKey), byID: make(map[int64]*service.APIKey),
byKey: make(map[string]*service.ApiKey), byKey: make(map[string]*service.APIKey),
} }
} }
func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) { func (r *stubApiKeyRepo) MustSeed(key *service.APIKey) {
if key == nil { if key == nil {
return return
} }
...@@ -691,7 +691,7 @@ func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) { ...@@ -691,7 +691,7 @@ func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
r.byKey[clone.Key] = &clone r.byKey[clone.Key] = &clone
} }
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
if key == nil { if key == nil {
return errors.New("nil key") return errors.New("nil key")
} }
...@@ -711,10 +711,10 @@ func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error ...@@ -711,10 +711,10 @@ func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error
return nil return nil
} }
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
key, ok := r.byID[id] key, ok := r.byID[id]
if !ok { if !ok {
return nil, service.ErrApiKeyNotFound return nil, service.ErrAPIKeyNotFound
} }
clone := *key clone := *key
return &clone, nil return &clone, nil
...@@ -723,26 +723,26 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey ...@@ -723,26 +723,26 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
key, ok := r.byID[id] key, ok := r.byID[id]
if !ok { if !ok {
return 0, service.ErrApiKeyNotFound return 0, service.ErrAPIKeyNotFound
} }
return key.UserID, nil return key.UserID, nil
} }
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
found, ok := r.byKey[key] found, ok := r.byKey[key]
if !ok { if !ok {
return nil, service.ErrApiKeyNotFound return nil, service.ErrAPIKeyNotFound
} }
clone := *found clone := *found
return &clone, nil return &clone, nil
} }
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
if key == nil { if key == nil {
return errors.New("nil key") return errors.New("nil key")
} }
if _, ok := r.byID[key.ID]; !ok { if _, ok := r.byID[key.ID]; !ok {
return service.ErrApiKeyNotFound return service.ErrAPIKeyNotFound
} }
if key.UpdatedAt.IsZero() { if key.UpdatedAt.IsZero() {
key.UpdatedAt = r.now key.UpdatedAt = r.now
...@@ -756,14 +756,14 @@ func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error ...@@ -756,14 +756,14 @@ func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
key, ok := r.byID[id] key, ok := r.byID[id]
if !ok { if !ok {
return service.ErrApiKeyNotFound return service.ErrAPIKeyNotFound
} }
delete(r.byID, id) delete(r.byID, id)
delete(r.byKey, key.Key) delete(r.byKey, key.Key)
return nil return nil
} }
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
ids := make([]int64, 0, len(r.byID)) ids := make([]int64, 0, len(r.byID))
for id := range r.byID { for id := range r.byID {
if r.byID[id].UserID == userID { if r.byID[id].UserID == userID {
...@@ -781,7 +781,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params ...@@ -781,7 +781,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params
end = len(ids) end = len(ids)
} }
out := make([]service.ApiKey, 0, end-start) out := make([]service.APIKey, 0, end-start)
for _, id := range ids[start:end] { for _, id := range ids[start:end] {
clone := *r.byID[id] clone := *r.byID[id]
out = append(out, clone) out = append(out, clone)
...@@ -835,11 +835,11 @@ func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, err ...@@ -835,11 +835,11 @@ func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, err
return ok, nil return ok, nil
} }
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { func (r *stubApiKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
...@@ -882,7 +882,7 @@ func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params ...@@ -882,7 +882,7 @@ func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params
return out, paginationResult(total, params), nil return out, paginationResult(total, params), nil
} }
func (r *stubUsageLogRepo) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
...@@ -895,7 +895,7 @@ func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID in ...@@ -895,7 +895,7 @@ func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID in
return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil
} }
func (r *stubUsageLogRepo) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
...@@ -927,7 +927,7 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi ...@@ -927,7 +927,7 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) { func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
...@@ -980,7 +980,7 @@ func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID in ...@@ -980,7 +980,7 @@ func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID in
}, nil }, nil
} }
func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { func (r *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
...@@ -1000,7 +1000,7 @@ func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs [ ...@@ -1000,7 +1000,7 @@ func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs [
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) { func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
...@@ -1022,8 +1022,8 @@ func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params paginatio ...@@ -1022,8 +1022,8 @@ func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params paginatio
// Apply filters // Apply filters
var filtered []service.UsageLog var filtered []service.UsageLog
for _, log := range logs { for _, log := range logs {
// Apply ApiKeyID filter // Apply APIKeyID filter
if filters.ApiKeyID > 0 && log.ApiKeyID != filters.ApiKeyID { if filters.APIKeyID > 0 && log.APIKeyID != filters.APIKeyID {
continue continue
} }
// Apply Model filter // Apply Model filter
...@@ -1156,8 +1156,8 @@ func paginationResult(total int64, params pagination.PaginationParams) *paginati ...@@ -1156,8 +1156,8 @@ func paginationResult(total int64, params pagination.PaginationParams) *paginati
// Ensure compile-time interface compliance. // Ensure compile-time interface compliance.
var ( var (
_ service.UserRepository = (*stubUserRepo)(nil) _ service.UserRepository = (*stubUserRepo)(nil)
_ service.ApiKeyRepository = (*stubApiKeyRepo)(nil) _ service.APIKeyRepository = (*stubApiKeyRepo)(nil)
_ service.ApiKeyCache = (*stubApiKeyCache)(nil) _ service.APIKeyCache = (*stubApiKeyCache)(nil)
_ service.GroupRepository = (*stubGroupRepo)(nil) _ service.GroupRepository = (*stubGroupRepo)(nil)
_ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil) _ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil)
_ service.UsageLogRepository = (*stubUsageLogRepo)(nil) _ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
......
// Package server provides HTTP server initialization and configuration.
package server package server
import ( import (
...@@ -25,8 +26,8 @@ func ProvideRouter( ...@@ -25,8 +26,8 @@ func ProvideRouter(
handlers *handler.Handlers, handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware, jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware, adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
) *gin.Engine { ) *gin.Engine {
if cfg.Server.Mode == "release" { if cfg.Server.Mode == "release" {
......
// Package middleware provides HTTP middleware for authentication, authorization, and request processing.
package middleware package middleware
import ( import (
...@@ -32,7 +33,7 @@ func adminAuth( ...@@ -32,7 +33,7 @@ func adminAuth(
// 检查 x-api-key header(Admin API Key 认证) // 检查 x-api-key header(Admin API Key 认证)
apiKey := c.GetHeader("x-api-key") apiKey := c.GetHeader("x-api-key")
if apiKey != "" { if apiKey != "" {
if !validateAdminApiKey(c, apiKey, settingService, userService) { if !validateAdminAPIKey(c, apiKey, settingService, userService) {
return return
} }
c.Next() c.Next()
...@@ -57,14 +58,14 @@ func adminAuth( ...@@ -57,14 +58,14 @@ func adminAuth(
} }
} }
// validateAdminApiKey 验证管理员 API Key // validateAdminAPIKey 验证管理员 API Key
func validateAdminApiKey( func validateAdminAPIKey(
c *gin.Context, c *gin.Context,
key string, key string,
settingService *service.SettingService, settingService *service.SettingService,
userService *service.UserService, userService *service.UserService,
) bool { ) bool {
storedKey, err := settingService.GetAdminApiKey(c.Request.Context()) storedKey, err := settingService.GetAdminAPIKey(c.Request.Context())
if err != nil { if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error") AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
return false return false
......
...@@ -11,13 +11,13 @@ import ( ...@@ -11,13 +11,13 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件 // NewAPIKeyAuthMiddleware 创建 API Key 认证中间件
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware { func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) APIKeyAuthMiddleware {
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg)) return APIKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
} }
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证) // apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 尝试从Authorization header中提取API key (Bearer scheme) // 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
...@@ -60,7 +60,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -60,7 +60,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 从数据库验证API key // 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil { if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) { if errors.Is(err, service.ErrAPIKeyNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key") AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return return
} }
...@@ -88,7 +88,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -88,7 +88,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
if cfg.RunMode == config.RunModeSimple { if cfg.RunMode == config.RunModeSimple {
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文 // 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
...@@ -146,7 +146,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -146,7 +146,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
} }
// 将API key和用户信息存入上下文 // 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
...@@ -157,13 +157,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -157,13 +157,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
} }
} }
// GetApiKeyFromContext 从上下文中获取API key // GetAPIKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) { func GetAPIKeyFromContext(c *gin.Context) (*service.APIKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey)) value, exists := c.Get(string(ContextKeyAPIKey))
if !exists { if !exists {
return nil, false return nil, false
} }
apiKey, ok := value.(*service.ApiKey) apiKey, ok := value.(*service.APIKey)
return apiKey, ok return apiKey, ok
} }
......
...@@ -11,16 +11,16 @@ import ( ...@@ -11,16 +11,16 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth. // APIKeyAuthGoogle is a Google-style error wrapper for API key auth.
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc { func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) gin.HandlerFunc {
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg) return APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
} }
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors: // APIKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}} // {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
// //
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations. // It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
apiKeyString := extractAPIKeyFromRequest(c) apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" { if apiKeyString == "" {
...@@ -30,7 +30,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs ...@@ -30,7 +30,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil { if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) { if errors.Is(err, service.ErrAPIKeyNotFound) {
abortWithGoogleError(c, 401, "Invalid API key") abortWithGoogleError(c, 401, "Invalid API key")
return return
} }
...@@ -53,7 +53,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs ...@@ -53,7 +53,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
// 简易模式:跳过余额和订阅检查 // 简易模式:跳过余额和订阅检查
if cfg.RunMode == config.RunModeSimple { if cfg.RunMode == config.RunModeSimple {
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
...@@ -92,7 +92,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs ...@@ -92,7 +92,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
} }
} }
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
......
...@@ -16,53 +16,53 @@ import ( ...@@ -16,53 +16,53 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type fakeApiKeyRepo struct { type fakeAPIKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error) getByKey func(ctx context.Context, key string) (*service.APIKey, error)
} }
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if f.getByKey == nil { if f.getByKey == nil {
return nil, errors.New("unexpected call") return nil, errors.New("unexpected call")
} }
return f.getByKey(ctx, key) return f.getByKey(ctx, key)
} }
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error { func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (f fakeAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { func (f fakeAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented") return false, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (f fakeAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { func (f fakeAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
...@@ -74,8 +74,8 @@ type googleErrorResponse struct { ...@@ -74,8 +74,8 @@ type googleErrorResponse struct {
} `json:"error"` } `json:"error"`
} }
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService { func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService {
return service.NewApiKeyService( return service.NewAPIKeyService(
repo, repo,
nil, // userRepo (unused in GetByKey) nil, // userRepo (unused in GetByKey)
nil, // groupRepo nil, // groupRepo
...@@ -89,12 +89,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) { ...@@ -89,12 +89,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("should not be called") return nil, errors.New("should not be called")
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -113,12 +113,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) { ...@@ -113,12 +113,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, service.ErrApiKeyNotFound return nil, service.ErrAPIKeyNotFound
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -138,12 +138,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) { ...@@ -138,12 +138,12 @@ func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("db down") return nil, errors.New("db down")
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -163,9 +163,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) { ...@@ -163,9 +163,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.ApiKey{ return &service.APIKey{
ID: 1, ID: 1,
Key: key, Key: key,
Status: service.StatusDisabled, Status: service.StatusDisabled,
...@@ -176,7 +176,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) { ...@@ -176,7 +176,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
}, nil }, nil
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -196,9 +196,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) { ...@@ -196,9 +196,9 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.ApiKey{ return &service.APIKey{
ID: 1, ID: 1,
Key: key, Key: key,
Status: service.StatusActive, Status: service.StatusActive,
...@@ -210,7 +210,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) { ...@@ -210,7 +210,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
}, nil }, nil
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
......
...@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
Balance: 10, Balance: 10,
Concurrency: 3, Concurrency: 3,
} }
apiKey := &service.ApiKey{ apiKey := &service.APIKey{
ID: 100, ID: 100,
UserID: user.ID, UserID: user.ID,
Key: "test-key", Key: "test-key",
...@@ -46,9 +46,9 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -46,9 +46,9 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
apiKey.GroupID = &group.ID apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{ apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key { if key != apiKey.Key {
return nil, service.ErrApiKeyNotFound return nil, service.ErrAPIKeyNotFound
} }
clone := *apiKey clone := *apiKey
return &clone, nil return &clone, nil
...@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
...@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) { t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard} cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
now := time.Now() now := time.Now()
sub := &service.UserSubscription{ sub := &service.UserSubscription{
...@@ -110,9 +110,9 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -110,9 +110,9 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
}) })
} }
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New() router := gin.New()
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
router.GET("/t", func(c *gin.Context) { router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true}) c.JSON(http.StatusOK, gin.H{"ok": true})
}) })
...@@ -120,14 +120,14 @@ func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService ...@@ -120,14 +120,14 @@ func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService
} }
type stubApiKeyRepo struct { type stubApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error) getByKey func(ctx context.Context, key string) (*service.APIKey, error)
} }
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
...@@ -135,14 +135,14 @@ func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error ...@@ -135,14 +135,14 @@ func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if r.getByKey != nil { if r.getByKey != nil {
return r.getByKey(ctx, key) return r.getByKey(ctx, key)
} }
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
...@@ -150,7 +150,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { ...@@ -150,7 +150,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
...@@ -166,11 +166,11 @@ func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, err ...@@ -166,11 +166,11 @@ func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, err
return false, errors.New("not implemented") return false, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { func (r *stubApiKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
......
...@@ -15,8 +15,8 @@ const ( ...@@ -15,8 +15,8 @@ const (
ContextKeyUser ContextKey = "user" ContextKeyUser ContextKey = "user"
// ContextKeyUserRole 当前用户角色(string) // ContextKeyUserRole 当前用户角色(string)
ContextKeyUserRole ContextKey = "user_role" ContextKeyUserRole ContextKey = "user_role"
// ContextKeyApiKey API密钥上下文键 // ContextKeyAPIKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key" ContextKeyAPIKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键 // ContextKeySubscription 订阅上下文键
ContextKeySubscription ContextKey = "subscription" ContextKeySubscription ContextKey = "subscription"
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由) // ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
......
...@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc ...@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc
// AdminAuthMiddleware 管理员认证中间件类型 // AdminAuthMiddleware 管理员认证中间件类型
type AdminAuthMiddleware gin.HandlerFunc type AdminAuthMiddleware gin.HandlerFunc
// ApiKeyAuthMiddleware API Key 认证中间件类型 // APIKeyAuthMiddleware API Key 认证中间件类型
type ApiKeyAuthMiddleware gin.HandlerFunc type APIKeyAuthMiddleware gin.HandlerFunc
// ProviderSet 中间件层的依赖注入 // ProviderSet 中间件层的依赖注入
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
NewJWTAuthMiddleware, NewJWTAuthMiddleware,
NewAdminAuthMiddleware, NewAdminAuthMiddleware,
NewApiKeyAuthMiddleware, NewAPIKeyAuthMiddleware,
) )
...@@ -17,8 +17,8 @@ func SetupRouter( ...@@ -17,8 +17,8 @@ func SetupRouter(
handlers *handler.Handlers, handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware, jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware, adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
) *gin.Engine { ) *gin.Engine {
...@@ -43,8 +43,8 @@ func registerRoutes( ...@@ -43,8 +43,8 @@ func registerRoutes(
h *handler.Handlers, h *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware, jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware, adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
) { ) {
......
// Package routes provides HTTP route registration and handlers.
package routes package routes
import ( import (
...@@ -67,10 +68,10 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -67,10 +68,10 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics) dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend) dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats) dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend) dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend) dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage) dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage) dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
} }
} }
...@@ -205,12 +206,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -205,12 +206,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{ {
adminSettings.GET("", h.Admin.Setting.GetSettings) adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings) adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection) adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail) adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
// Admin API Key 管理 // Admin API Key 管理
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey) adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey) adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey) adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
} }
} }
...@@ -250,7 +251,7 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -250,7 +251,7 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage.GET("", h.Admin.Usage.List) usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats) usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers) usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys) usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys)
} }
} }
......
...@@ -13,8 +13,8 @@ import ( ...@@ -13,8 +13,8 @@ import (
func RegisterGatewayRoutes( func RegisterGatewayRoutes(
r *gin.Engine, r *gin.Engine,
h *handler.Handlers, h *handler.Handlers,
apiKeyAuth middleware.ApiKeyAuthMiddleware, apiKeyAuth middleware.APIKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
) { ) {
...@@ -36,7 +36,7 @@ func RegisterGatewayRoutes( ...@@ -36,7 +36,7 @@ func RegisterGatewayRoutes(
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta") gemini := r.Group("/v1beta")
gemini.Use(bodyLimit) gemini.Use(bodyLimit)
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{ {
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
...@@ -65,7 +65,7 @@ func RegisterGatewayRoutes( ...@@ -65,7 +65,7 @@ func RegisterGatewayRoutes(
antigravityV1Beta := r.Group("/antigravity/v1beta") antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(bodyLimit) antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{ {
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels) antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
......
...@@ -50,7 +50,7 @@ func RegisterUserRoutes( ...@@ -50,7 +50,7 @@ func RegisterUserRoutes(
usage.GET("/dashboard/stats", h.Usage.DashboardStats) usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend) usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels) usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage) usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
} }
// 卡密兑换 // 卡密兑换
......
// Package service provides business logic and domain services for the application.
package service package service
import ( import (
...@@ -324,7 +325,7 @@ func (a *Account) GetMappedModel(requestedModel string) string { ...@@ -324,7 +325,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
} }
func (a *Account) GetBaseURL() string { func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey { if a.Type != AccountTypeAPIKey {
return "" return ""
} }
baseURL := a.GetCredential("base_url") baseURL := a.GetCredential("base_url")
...@@ -347,7 +348,7 @@ func (a *Account) GetExtraString(key string) string { ...@@ -347,7 +348,7 @@ func (a *Account) GetExtraString(key string) string {
} }
func (a *Account) IsCustomErrorCodesEnabled() bool { func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil { if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false return false
} }
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok { if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
...@@ -419,14 +420,14 @@ func (a *Account) IsOpenAIOAuth() bool { ...@@ -419,14 +420,14 @@ func (a *Account) IsOpenAIOAuth() bool {
} }
func (a *Account) IsOpenAIApiKey() bool { func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey return a.IsOpenAI() && a.Type == AccountTypeAPIKey
} }
func (a *Account) GetOpenAIBaseURL() string { func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() { if !a.IsOpenAI() {
return "" return ""
} }
if a.Type == AccountTypeApiKey { if a.Type == AccountTypeAPIKey {
baseURL := a.GetCredential("base_url") baseURL := a.GetCredential("base_url")
if baseURL != "" { if baseURL != "" {
return baseURL return baseURL
......
...@@ -369,7 +369,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account ...@@ -369,7 +369,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
} }
// For API Key accounts with model mapping, map the model // For API Key accounts with model mapping, map the model
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
mapping := account.GetModelMapping() mapping := account.GetModelMapping()
if len(mapping) > 0 { if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists { if mappedModel, exists := mapping[testModelID]; exists {
...@@ -393,7 +393,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account ...@@ -393,7 +393,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
var err error var err error
switch account.Type { switch account.Type {
case AccountTypeApiKey: case AccountTypeAPIKey:
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload) req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case AccountTypeOAuth: case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload) req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
......
...@@ -19,11 +19,11 @@ type UsageLogRepository interface { ...@@ -19,11 +19,11 @@ type UsageLogRepository interface {
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
...@@ -34,10 +34,10 @@ type UsageLogRepository interface { ...@@ -34,10 +34,10 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
// User dashboard stats // User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
...@@ -53,7 +53,7 @@ type UsageLogRepository interface { ...@@ -53,7 +53,7 @@ type UsageLogRepository interface {
// Aggregated stats (optimized) // Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
......
...@@ -20,7 +20,7 @@ type AdminService interface { ...@@ -20,7 +20,7 @@ type AdminService interface {
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management // Group management
...@@ -31,7 +31,7 @@ type AdminService interface { ...@@ -31,7 +31,7 @@ type AdminService interface {
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
...@@ -66,7 +66,7 @@ type AdminService interface { ...@@ -66,7 +66,7 @@ type AdminService interface {
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
} }
// Input types for admin operations // CreateUserInput represents input for creating a new user via admin operations.
type CreateUserInput struct { type CreateUserInput struct {
Email string Email string
Password string Password string
...@@ -228,7 +228,7 @@ type adminServiceImpl struct { ...@@ -228,7 +228,7 @@ type adminServiceImpl struct {
groupRepo GroupRepository groupRepo GroupRepository
accountRepo AccountRepository accountRepo AccountRepository
proxyRepo ProxyRepository proxyRepo ProxyRepository
apiKeyRepo ApiKeyRepository apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber proxyProber ProxyExitInfoProber
...@@ -240,7 +240,7 @@ func NewAdminService( ...@@ -240,7 +240,7 @@ func NewAdminService(
groupRepo GroupRepository, groupRepo GroupRepository,
accountRepo AccountRepository, accountRepo AccountRepository,
proxyRepo ProxyRepository, proxyRepo ProxyRepository,
apiKeyRepo ApiKeyRepository, apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository, redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber, proxyProber ProxyExitInfoProber,
...@@ -438,7 +438,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, ...@@ -438,7 +438,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil return user, nil
} }
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) { func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
...@@ -591,7 +591,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { ...@@ -591,7 +591,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
return nil return nil
} }
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) { func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil { if err != nil {
......
...@@ -2,7 +2,7 @@ package service ...@@ -2,7 +2,7 @@ package service
import "time" import "time"
type ApiKey struct { type APIKey struct {
ID int64 ID int64
UserID int64 UserID int64
Key string Key string
...@@ -15,6 +15,6 @@ type ApiKey struct { ...@@ -15,6 +15,6 @@ type ApiKey struct {
Group *Group Group *Group
} }
func (k *ApiKey) IsActive() bool { func (k *APIKey) IsActive() bool {
return k.Status == StatusActive return k.Status == StatusActive
} }
...@@ -14,39 +14,39 @@ import ( ...@@ -14,39 +14,39 @@ import (
) )
var ( var (
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found") ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group") ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists") ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters") ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens") ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later") ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
) )
const ( const (
apiKeyMaxErrorsPerHour = 20 apiKeyMaxErrorsPerHour = 20
) )
type ApiKeyRepository interface { type APIKeyRepository interface {
Create(ctx context.Context, key *ApiKey) error Create(ctx context.Context, key *APIKey) error
GetByID(ctx context.Context, id int64) (*ApiKey, error) GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证 // GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error) GetOwnerID(ctx context.Context, id int64) (int64, error)
GetByKey(ctx context.Context, key string) (*ApiKey, error) GetByKey(ctx context.Context, key string) (*APIKey, error)
Update(ctx context.Context, key *ApiKey) error Update(ctx context.Context, key *APIKey) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error) CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error) ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error) CountByGroupID(ctx context.Context, groupID int64) (int64, error)
} }
// ApiKeyCache defines cache operations for API key service // APIKeyCache defines cache operations for API key service
type ApiKeyCache interface { type APIKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error DeleteCreateAttemptCount(ctx context.Context, userID int64) error
...@@ -55,40 +55,40 @@ type ApiKeyCache interface { ...@@ -55,40 +55,40 @@ type ApiKeyCache interface {
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
} }
// CreateApiKeyRequest 创建API Key请求 // CreateAPIKeyRequest 创建API Key请求
type CreateApiKeyRequest struct { type CreateAPIKeyRequest struct {
Name string `json:"name"` Name string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key CustomKey *string `json:"custom_key"` // 可选的自定义key
} }
// UpdateApiKeyRequest 更新API Key请求 // UpdateAPIKeyRequest 更新API Key请求
type UpdateApiKeyRequest struct { type UpdateAPIKeyRequest struct {
Name *string `json:"name"` Name *string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
Status *string `json:"status"` Status *string `json:"status"`
} }
// ApiKeyService API Key服务 // APIKeyService API Key服务
type ApiKeyService struct { type APIKeyService struct {
apiKeyRepo ApiKeyRepository apiKeyRepo APIKeyRepository
userRepo UserRepository userRepo UserRepository
groupRepo GroupRepository groupRepo GroupRepository
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
cache ApiKeyCache cache APIKeyCache
cfg *config.Config cfg *config.Config
} }
// NewApiKeyService 创建API Key服务实例 // NewAPIKeyService 创建API Key服务实例
func NewApiKeyService( func NewAPIKeyService(
apiKeyRepo ApiKeyRepository, apiKeyRepo APIKeyRepository,
userRepo UserRepository, userRepo UserRepository,
groupRepo GroupRepository, groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
cache ApiKeyCache, cache APIKeyCache,
cfg *config.Config, cfg *config.Config,
) *ApiKeyService { ) *APIKeyService {
return &ApiKeyService{ return &APIKeyService{
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
userRepo: userRepo, userRepo: userRepo,
groupRepo: groupRepo, groupRepo: groupRepo,
...@@ -99,7 +99,7 @@ func NewApiKeyService( ...@@ -99,7 +99,7 @@ func NewApiKeyService(
} }
// GenerateKey 生成随机API Key // GenerateKey 生成随机API Key
func (s *ApiKeyService) GenerateKey() (string, error) { func (s *APIKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据 // 生成32字节随机数据
bytes := make([]byte, 32) bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil { if _, err := rand.Read(bytes); err != nil {
...@@ -107,7 +107,7 @@ func (s *ApiKeyService) GenerateKey() (string, error) { ...@@ -107,7 +107,7 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
} }
// 转换为十六进制字符串并添加前缀 // 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.ApiKeyPrefix prefix := s.cfg.Default.APIKeyPrefix
if prefix == "" { if prefix == "" {
prefix = "sk-" prefix = "sk-"
} }
...@@ -117,10 +117,10 @@ func (s *ApiKeyService) GenerateKey() (string, error) { ...@@ -117,10 +117,10 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
} }
// ValidateCustomKey 验证自定义API Key格式 // ValidateCustomKey 验证自定义API Key格式
func (s *ApiKeyService) ValidateCustomKey(key string) error { func (s *APIKeyService) ValidateCustomKey(key string) error {
// 检查长度 // 检查长度
if len(key) < 16 { if len(key) < 16 {
return ErrApiKeyTooShort return ErrAPIKeyTooShort
} }
// 检查字符:只允许字母、数字、下划线、连字符 // 检查字符:只允许字母、数字、下划线、连字符
...@@ -131,14 +131,14 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error { ...@@ -131,14 +131,14 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
c == '_' || c == '-' { c == '_' || c == '-' {
continue continue
} }
return ErrApiKeyInvalidChars return ErrAPIKeyInvalidChars
} }
return nil return nil
} }
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限 // checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error { func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil { if s.cache == nil {
return nil return nil
} }
...@@ -150,14 +150,14 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) ...@@ -150,14 +150,14 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
} }
if count >= apiKeyMaxErrorsPerHour { if count >= apiKeyMaxErrorsPerHour {
return ErrApiKeyRateLimited return ErrAPIKeyRateLimited
} }
return nil return nil
} }
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数 // incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) { func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
if s.cache == nil { if s.cache == nil {
return return
} }
...@@ -168,7 +168,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in ...@@ -168,7 +168,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
// canUserBindGroup 检查用户是否可以绑定指定分组 // canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅 // 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑 // 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool { func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅 // 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID) _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
...@@ -179,7 +179,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group ...@@ -179,7 +179,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group
} }
// Create 创建API Key // Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) { func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
// 验证用户存在 // 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
...@@ -204,7 +204,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -204,7 +204,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
// 判断是否使用自定义Key // 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" { if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流(仅对自定义key进行限流) // 检查限流(仅对自定义key进行限流)
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil { if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
return nil, err return nil, err
} }
...@@ -220,8 +220,8 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -220,8 +220,8 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
} }
if exists { if exists {
// Key已存在,增加错误计数 // Key已存在,增加错误计数
s.incrementApiKeyErrorCount(ctx, userID) s.incrementAPIKeyErrorCount(ctx, userID)
return nil, ErrApiKeyExists return nil, ErrAPIKeyExists
} }
key = *req.CustomKey key = *req.CustomKey
...@@ -235,7 +235,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -235,7 +235,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
} }
// 创建API Key记录 // 创建API Key记录
apiKey := &ApiKey{ apiKey := &APIKey{
UserID: userID, UserID: userID,
Key: key, Key: key,
Name: req.Name, Name: req.Name,
...@@ -251,7 +251,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -251,7 +251,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
} }
// List 获取用户的API Key列表 // List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) { func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err) return nil, nil, fmt.Errorf("list api keys: %w", err)
...@@ -259,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio ...@@ -259,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
return keys, pagination, nil return keys, pagination, nil
} }
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 { if len(apiKeyIDs) == 0 {
return []int64{}, nil return []int64{}, nil
} }
...@@ -272,7 +272,7 @@ func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe ...@@ -272,7 +272,7 @@ func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
} }
// GetByID 根据ID获取API Key // GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) { func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
...@@ -281,7 +281,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) ...@@ -281,7 +281,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error)
} }
// GetByKey 根据Key字符串获取API Key(用于认证) // GetByKey 根据Key字符串获取API Key(用于认证)
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) { func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
// 尝试从Redis缓存获取 // 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key) cacheKey := fmt.Sprintf("apikey:%s", key)
...@@ -301,7 +301,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, erro ...@@ -301,7 +301,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, erro
} }
// Update 更新API Key // Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) { func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
...@@ -353,8 +353,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req ...@@ -353,8 +353,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
// Delete 删除API Key // Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证, // 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 ApiKey 对象及其关联数据(User、Group),提升删除操作的性能 // 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error { func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象 // 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id) ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
if err != nil { if err != nil {
...@@ -379,7 +379,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro ...@@ -379,7 +379,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
} }
// ValidateKey 验证API Key是否有效(用于认证中间件) // ValidateKey 验证API Key是否有效(用于认证中间件)
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) { func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
// 获取API Key // 获取API Key
apiKey, err := s.GetByKey(ctx, key) apiKey, err := s.GetByKey(ctx, key)
if err != nil { if err != nil {
...@@ -406,7 +406,7 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, * ...@@ -406,7 +406,7 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *
} }
// IncrementUsage 增加API Key使用次数(可选:用于统计) // IncrementUsage 增加API Key使用次数(可选:用于统计)
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error { func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器 // 使用Redis计数器
if s.cache != nil { if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02")) cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
...@@ -423,7 +423,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error { ...@@ -423,7 +423,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 返回用户可以选择的分组: // 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的 // - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的 // - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) { func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息 // 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
...@@ -460,7 +460,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([ ...@@ -460,7 +460,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
} }
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据) // canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool { func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅 // 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID] return subscribedGroupIDs[group.ID]
...@@ -469,8 +469,8 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subsc ...@@ -469,8 +469,8 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
return user.CanBindGroup(group.ID, group.IsExclusive) return user.CanBindGroup(group.ID, group.IsExclusive)
} }
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) { func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit) keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
if err != nil { if err != nil {
return nil, fmt.Errorf("search api keys: %w", err) return nil, fmt.Errorf("search api keys: %w", err)
} }
......
//go:build unit //go:build unit
// API Key 服务删除方法的单元测试 // API Key 服务删除方法的单元测试
// 测试 ApiKeyService.Delete 方法在各种场景下的行为, // 测试 APIKeyService.Delete 方法在各种场景下的行为,
// 包括权限验证、缓存清理和错误处理 // 包括权限验证、缓存清理和错误处理
package service package service
...@@ -16,12 +16,12 @@ import ( ...@@ -16,12 +16,12 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// apiKeyRepoStub 是 ApiKeyRepository 接口的测试桩实现。 // apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
// 用于隔离测试 ApiKeyService.Delete 方法,避免依赖真实数据库。 // 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
// //
// 设计说明: // 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID // - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrApiKeyNotFound) // - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound)
// - deleteErr: 模拟 Delete 返回的错误 // - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证 // - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
type apiKeyRepoStub struct { type apiKeyRepoStub struct {
...@@ -33,11 +33,11 @@ type apiKeyRepoStub struct { ...@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题 // 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error { func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
panic("unexpected Create call") panic("unexpected Create call")
} }
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) { func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
panic("unexpected GetByID call") panic("unexpected GetByID call")
} }
...@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error ...@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
return s.ownerID, s.ownerErr return s.ownerID, s.ownerErr
} }
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) { func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call") panic("unexpected GetByKey call")
} }
func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error { func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call") panic("unexpected Update call")
} }
...@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error { ...@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
// 以下是接口要求实现但本测试不关心的方法 // 以下是接口要求实现但本测试不关心的方法
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) { func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call") panic("unexpected ListByUserID call")
} }
...@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err ...@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
panic("unexpected ExistsByKey call") panic("unexpected ExistsByKey call")
} }
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) { func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call") panic("unexpected ListByGroupID call")
} }
func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) { func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchApiKeys call") panic("unexpected SearchAPIKeys call")
} }
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
...@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int ...@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call") panic("unexpected CountByGroupID call")
} }
// apiKeyCacheStub 是 ApiKeyCache 接口的测试桩实现。 // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。
// //
// 设计说明: // 设计说明:
...@@ -142,7 +142,7 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string ...@@ -142,7 +142,7 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) { func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1} repo := &apiKeyRepoStub{ownerID: 1}
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache} svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2 err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
require.ErrorIs(t, err, ErrInsufficientPerms) require.ErrorIs(t, err, ErrInsufficientPerms)
...@@ -160,7 +160,7 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) { ...@@ -160,7 +160,7 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
func TestApiKeyService_Delete_Success(t *testing.T) { func TestApiKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7} repo := &apiKeyRepoStub{ownerID: 7}
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache} svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7 err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
require.NoError(t, err) require.NoError(t, err)
...@@ -170,17 +170,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) { ...@@ -170,17 +170,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。 // TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为: // 预期行为:
// - GetOwnerID 返回 ErrApiKeyNotFound 错误 // - GetOwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrApiKeyNotFound 错误(被 fmt.Errorf 包装) // - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用 // - Delete 方法不被调用
// - 缓存不被清除 // - 缓存不被清除
func TestApiKeyService_Delete_NotFound(t *testing.T) { func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound} repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache} svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 99, 1) err := svc.Delete(context.Background(), 99, 1)
require.ErrorIs(t, err, ErrApiKeyNotFound) require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Empty(t, repo.deletedIDs) require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated) require.Empty(t, cache.invalidated)
} }
...@@ -198,7 +198,7 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) { ...@@ -198,7 +198,7 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
deleteErr: errors.New("delete failed"), deleteErr: errors.New("delete failed"),
} }
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache} svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3 err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
require.Error(t, err) require.Error(t, err)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment