Commit 112a2d08 authored by ianshaw's avatar ianshaw
Browse files

chore: 更新依赖、配置和代码生成

主要更新:
- 更新 go.mod/go.sum 依赖
- 重新生成 Ent ORM 代码
- 更新 Wire 依赖注入配置
- 添加 docker-compose.override.yml 到 .gitignore
- 更新 README 文档(Simple Mode 说明和已知问题)
- 清理调试日志
- 其他代码优化和格式修复
parent b1702de5
...@@ -2,7 +2,7 @@ package service ...@@ -2,7 +2,7 @@ package service
import "time" import "time"
type APIKey struct { type ApiKey struct {
ID int64 ID int64
UserID int64 UserID int64
Key string Key string
...@@ -15,6 +15,6 @@ type APIKey struct { ...@@ -15,6 +15,6 @@ type APIKey struct {
Group *Group Group *Group
} }
func (k *APIKey) IsActive() bool { func (k *ApiKey) IsActive() bool {
return k.Status == StatusActive return k.Status == StatusActive
} }
...@@ -14,39 +14,39 @@ import ( ...@@ -14,39 +14,39 @@ import (
) )
var ( var (
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found") ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group") ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists") ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters") ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens") ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later") ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
) )
const ( const (
apiKeyMaxErrorsPerHour = 20 apiKeyMaxErrorsPerHour = 20
) )
type APIKeyRepository interface { type ApiKeyRepository interface {
Create(ctx context.Context, key *APIKey) error Create(ctx context.Context, key *ApiKey) error
GetByID(ctx context.Context, id int64) (*APIKey, error) GetByID(ctx context.Context, id int64) (*ApiKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证 // GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error) GetOwnerID(ctx context.Context, id int64) (int64, error)
GetByKey(ctx context.Context, key string) (*APIKey, error) GetByKey(ctx context.Context, key string) (*ApiKey, error)
Update(ctx context.Context, key *APIKey) error Update(ctx context.Context, key *ApiKey) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error) CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error) ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error) CountByGroupID(ctx context.Context, groupID int64) (int64, error)
} }
// APIKeyCache defines cache operations for API key service // ApiKeyCache defines cache operations for API key service
type APIKeyCache interface { type ApiKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error DeleteCreateAttemptCount(ctx context.Context, userID int64) error
...@@ -55,40 +55,40 @@ type APIKeyCache interface { ...@@ -55,40 +55,40 @@ type APIKeyCache interface {
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
} }
// CreateAPIKeyRequest 创建API Key请求 // CreateApiKeyRequest 创建API Key请求
type CreateAPIKeyRequest struct { type CreateApiKeyRequest struct {
Name string `json:"name"` Name string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key CustomKey *string `json:"custom_key"` // 可选的自定义key
} }
// UpdateAPIKeyRequest 更新API Key请求 // UpdateApiKeyRequest 更新API Key请求
type UpdateAPIKeyRequest struct { type UpdateApiKeyRequest struct {
Name *string `json:"name"` Name *string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
Status *string `json:"status"` Status *string `json:"status"`
} }
// APIKeyService API Key服务 // ApiKeyService API Key服务
type APIKeyService struct { type ApiKeyService struct {
apiKeyRepo APIKeyRepository apiKeyRepo ApiKeyRepository
userRepo UserRepository userRepo UserRepository
groupRepo GroupRepository groupRepo GroupRepository
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
cache APIKeyCache cache ApiKeyCache
cfg *config.Config cfg *config.Config
} }
// NewAPIKeyService 创建API Key服务实例 // NewApiKeyService 创建API Key服务实例
func NewAPIKeyService( func NewApiKeyService(
apiKeyRepo APIKeyRepository, apiKeyRepo ApiKeyRepository,
userRepo UserRepository, userRepo UserRepository,
groupRepo GroupRepository, groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
cache APIKeyCache, cache ApiKeyCache,
cfg *config.Config, cfg *config.Config,
) *APIKeyService { ) *ApiKeyService {
return &APIKeyService{ return &ApiKeyService{
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
userRepo: userRepo, userRepo: userRepo,
groupRepo: groupRepo, groupRepo: groupRepo,
...@@ -99,7 +99,7 @@ func NewAPIKeyService( ...@@ -99,7 +99,7 @@ func NewAPIKeyService(
} }
// GenerateKey 生成随机API Key // GenerateKey 生成随机API Key
func (s *APIKeyService) GenerateKey() (string, error) { func (s *ApiKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据 // 生成32字节随机数据
bytes := make([]byte, 32) bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil { if _, err := rand.Read(bytes); err != nil {
...@@ -107,7 +107,7 @@ func (s *APIKeyService) GenerateKey() (string, error) { ...@@ -107,7 +107,7 @@ func (s *APIKeyService) GenerateKey() (string, error) {
} }
// 转换为十六进制字符串并添加前缀 // 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.APIKeyPrefix prefix := s.cfg.Default.ApiKeyPrefix
if prefix == "" { if prefix == "" {
prefix = "sk-" prefix = "sk-"
} }
...@@ -117,10 +117,10 @@ func (s *APIKeyService) GenerateKey() (string, error) { ...@@ -117,10 +117,10 @@ func (s *APIKeyService) GenerateKey() (string, error) {
} }
// ValidateCustomKey 验证自定义API Key格式 // ValidateCustomKey 验证自定义API Key格式
func (s *APIKeyService) ValidateCustomKey(key string) error { func (s *ApiKeyService) ValidateCustomKey(key string) error {
// 检查长度 // 检查长度
if len(key) < 16 { if len(key) < 16 {
return ErrAPIKeyTooShort return ErrApiKeyTooShort
} }
// 检查字符:只允许字母、数字、下划线、连字符 // 检查字符:只允许字母、数字、下划线、连字符
...@@ -131,14 +131,14 @@ func (s *APIKeyService) ValidateCustomKey(key string) error { ...@@ -131,14 +131,14 @@ func (s *APIKeyService) ValidateCustomKey(key string) error {
c == '_' || c == '-' { c == '_' || c == '-' {
continue continue
} }
return ErrAPIKeyInvalidChars return ErrApiKeyInvalidChars
} }
return nil return nil
} }
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限 // checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error { func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil { if s.cache == nil {
return nil return nil
} }
...@@ -150,14 +150,14 @@ func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) ...@@ -150,14 +150,14 @@ func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64)
} }
if count >= apiKeyMaxErrorsPerHour { if count >= apiKeyMaxErrorsPerHour {
return ErrAPIKeyRateLimited return ErrApiKeyRateLimited
} }
return nil return nil
} }
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数 // incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) { func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
if s.cache == nil { if s.cache == nil {
return return
} }
...@@ -168,7 +168,7 @@ func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID in ...@@ -168,7 +168,7 @@ func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID in
// canUserBindGroup 检查用户是否可以绑定指定分组 // canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅 // 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑 // 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool { func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅 // 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID) _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
...@@ -179,7 +179,7 @@ func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group ...@@ -179,7 +179,7 @@ func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group
} }
// Create 创建API Key // Create 创建API Key
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) { func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
// 验证用户存在 // 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
...@@ -204,7 +204,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK ...@@ -204,7 +204,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
// 判断是否使用自定义Key // 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" { if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流(仅对自定义key进行限流) // 检查限流(仅对自定义key进行限流)
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil { if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
return nil, err return nil, err
} }
...@@ -219,9 +219,9 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK ...@@ -219,9 +219,9 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
return nil, fmt.Errorf("check key exists: %w", err) return nil, fmt.Errorf("check key exists: %w", err)
} }
if exists { if exists {
// Key已存在,增加错误计数 // Key已存在增加错误计数
s.incrementAPIKeyErrorCount(ctx, userID) s.incrementApiKeyErrorCount(ctx, userID)
return nil, ErrAPIKeyExists return nil, ErrApiKeyExists
} }
key = *req.CustomKey key = *req.CustomKey
...@@ -235,7 +235,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK ...@@ -235,7 +235,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
} }
// 创建API Key记录 // 创建API Key记录
apiKey := &APIKey{ apiKey := &ApiKey{
UserID: userID, UserID: userID,
Key: key, Key: key,
Name: req.Name, Name: req.Name,
...@@ -251,7 +251,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK ...@@ -251,7 +251,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
} }
// List 获取用户的API Key列表 // List 获取用户的API Key列表
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err) return nil, nil, fmt.Errorf("list api keys: %w", err)
...@@ -259,7 +259,7 @@ func (s *APIKeyService) List(ctx context.Context, userID int64, params paginatio ...@@ -259,7 +259,7 @@ func (s *APIKeyService) List(ctx context.Context, userID int64, params paginatio
return keys, pagination, nil return keys, pagination, nil
} }
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 { if len(apiKeyIDs) == 0 {
return []int64{}, nil return []int64{}, nil
} }
...@@ -272,7 +272,7 @@ func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe ...@@ -272,7 +272,7 @@ func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
} }
// GetByID 根据ID获取API Key // GetByID 根据ID获取API Key
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) { func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
...@@ -281,7 +281,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) ...@@ -281,7 +281,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
} }
// GetByKey 根据Key字符串获取API Key(用于认证) // GetByKey 根据Key字符串获取API Key(用于认证)
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) { func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
// 尝试从Redis缓存获取 // 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key) cacheKey := fmt.Sprintf("apikey:%s", key)
...@@ -301,7 +301,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro ...@@ -301,7 +301,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
} }
// Update 更新API Key // Update 更新API Key
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) { func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
...@@ -353,8 +353,8 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req ...@@ -353,8 +353,8 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
// Delete 删除API Key // Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证, // 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能 // 避免加载完整 ApiKey 对象及其关联数据(User、Group),提升删除操作的性能
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error { func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象 // 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id) ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
if err != nil { if err != nil {
...@@ -379,7 +379,7 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro ...@@ -379,7 +379,7 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
} }
// ValidateKey 验证API Key是否有效(用于认证中间件) // ValidateKey 验证API Key是否有效(用于认证中间件)
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) { func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
// 获取API Key // 获取API Key
apiKey, err := s.GetByKey(ctx, key) apiKey, err := s.GetByKey(ctx, key)
if err != nil { if err != nil {
...@@ -406,7 +406,7 @@ func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, * ...@@ -406,7 +406,7 @@ func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *
} }
// IncrementUsage 增加API Key使用次数(可选:用于统计) // IncrementUsage 增加API Key使用次数(可选:用于统计)
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error { func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器 // 使用Redis计数器
if s.cache != nil { if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02")) cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
...@@ -423,7 +423,7 @@ func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error { ...@@ -423,7 +423,7 @@ func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 返回用户可以选择的分组: // 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的 // - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的 // - 订阅类型分组:用户有有效订阅的
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) { func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息 // 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
...@@ -460,7 +460,7 @@ func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([ ...@@ -460,7 +460,7 @@ func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
} }
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据) // canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool { func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅 // 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID] return subscribedGroupIDs[group.ID]
...@@ -469,8 +469,8 @@ func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subsc ...@@ -469,8 +469,8 @@ func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
return user.CanBindGroup(group.ID, group.IsExclusive) return user.CanBindGroup(group.ID, group.IsExclusive)
} }
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) { func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit) keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
if err != nil { if err != nil {
return nil, fmt.Errorf("search api keys: %w", err) return nil, fmt.Errorf("search api keys: %w", err)
} }
......
//go:build unit //go:build unit
// API Key 服务删除方法的单元测试 // API Key 服务删除方法的单元测试
// 测试 APIKeyService.Delete 方法在各种场景下的行为, // 测试 ApiKeyService.Delete 方法在各种场景下的行为,
// 包括权限验证、缓存清理和错误处理 // 包括权限验证、缓存清理和错误处理
package service package service
...@@ -16,12 +16,12 @@ import ( ...@@ -16,12 +16,12 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。 // apiKeyRepoStub 是 ApiKeyRepository 接口的测试桩实现。
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。 // 用于隔离测试 ApiKeyService.Delete 方法,避免依赖真实数据库。
// //
// 设计说明: // 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID // - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound) // - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrApiKeyNotFound)
// - deleteErr: 模拟 Delete 返回的错误 // - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证 // - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
type apiKeyRepoStub struct { type apiKeyRepoStub struct {
...@@ -33,11 +33,11 @@ type apiKeyRepoStub struct { ...@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题 // 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error { func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error {
panic("unexpected Create call") panic("unexpected Create call")
} }
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) { func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
panic("unexpected GetByID call") panic("unexpected GetByID call")
} }
...@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error ...@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
return s.ownerID, s.ownerErr return s.ownerID, s.ownerErr
} }
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) { func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
panic("unexpected GetByKey call") panic("unexpected GetByKey call")
} }
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error { func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error {
panic("unexpected Update call") panic("unexpected Update call")
} }
...@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error { ...@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
// 以下是接口要求实现但本测试不关心的方法 // 以下是接口要求实现但本测试不关心的方法
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call") panic("unexpected ListByUserID call")
} }
...@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err ...@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
panic("unexpected ExistsByKey call") panic("unexpected ExistsByKey call")
} }
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call") panic("unexpected ListByGroupID call")
} }
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) { func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
panic("unexpected SearchAPIKeys call") panic("unexpected SearchApiKeys call")
} }
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
...@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int ...@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call") panic("unexpected CountByGroupID call")
} }
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。 // apiKeyCacheStub 是 ApiKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。
// //
// 设计说明: // 设计说明:
...@@ -132,17 +132,17 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string ...@@ -132,17 +132,17 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
return nil return nil
} }
// TestAPIKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。 // TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为: // 预期行为:
// - GetOwnerID 返回所有者 ID 为 1 // - GetOwnerID 返回所有者 ID 为 1
// - 调用者 userID 为 2(不匹配) // - 调用者 userID 为 2(不匹配)
// - 返回 ErrInsufficientPerms 错误 // - 返回 ErrInsufficientPerms 错误
// - Delete 方法不被调用 // - Delete 方法不被调用
// - 缓存不被清除 // - 缓存不被清除
func TestAPIKeyService_Delete_OwnerMismatch(t *testing.T) { func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1} repo := &apiKeyRepoStub{ownerID: 1}
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache} svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2 err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
require.ErrorIs(t, err, ErrInsufficientPerms) require.ErrorIs(t, err, ErrInsufficientPerms)
...@@ -150,17 +150,17 @@ func TestAPIKeyService_Delete_OwnerMismatch(t *testing.T) { ...@@ -150,17 +150,17 @@ func TestAPIKeyService_Delete_OwnerMismatch(t *testing.T) {
require.Empty(t, cache.invalidated) // 验证缓存未被清除 require.Empty(t, cache.invalidated) // 验证缓存未被清除
} }
// TestAPIKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。 // TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// 预期行为: // 预期行为:
// - GetOwnerID 返回所有者 ID 为 7 // - GetOwnerID 返回所有者 ID 为 7
// - 调用者 userID 为 7(匹配) // - 调用者 userID 为 7(匹配)
// - Delete 成功执行 // - Delete 成功执行
// - 缓存被正确清除(使用 ownerID) // - 缓存被正确清除(使用 ownerID)
// - 返回 nil 错误 // - 返回 nil 错误
func TestAPIKeyService_Delete_Success(t *testing.T) { func TestApiKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7} repo := &apiKeyRepoStub{ownerID: 7}
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache} svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7 err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
require.NoError(t, err) require.NoError(t, err)
...@@ -168,37 +168,37 @@ func TestAPIKeyService_Delete_Success(t *testing.T) { ...@@ -168,37 +168,37 @@ func TestAPIKeyService_Delete_Success(t *testing.T) {
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除 require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
} }
// TestAPIKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。 // TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为: // 预期行为:
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误 // - GetOwnerID 返回 ErrApiKeyNotFound 错误
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装) // - 返回 ErrApiKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用 // - Delete 方法不被调用
// - 缓存不被清除 // - 缓存不被清除
func TestAPIKeyService_Delete_NotFound(t *testing.T) { func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound} repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound}
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache} svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 99, 1) err := svc.Delete(context.Background(), 99, 1)
require.ErrorIs(t, err, ErrAPIKeyNotFound) require.ErrorIs(t, err, ErrApiKeyNotFound)
require.Empty(t, repo.deletedIDs) require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated) require.Empty(t, cache.invalidated)
} }
// TestAPIKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。 // TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// 预期行为: // 预期行为:
// - GetOwnerID 返回正确的所有者 ID // - GetOwnerID 返回正确的所有者 ID
// - 所有权验证通过 // - 所有权验证通过
// - 缓存被清除(在删除之前) // - 缓存被清除(在删除之前)
// - Delete 被调用但返回错误 // - Delete 被调用但返回错误
// - 返回包含 "delete api key" 的错误信息 // - 返回包含 "delete api key" 的错误信息
func TestAPIKeyService_Delete_DeleteFails(t *testing.T) { func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
repo := &apiKeyRepoStub{ repo := &apiKeyRepoStub{
ownerID: 3, ownerID: 3,
deleteErr: errors.New("delete failed"), deleteErr: errors.New("delete failed"),
} }
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache} svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3 err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
require.Error(t, err) require.Error(t, err)
......
...@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID ...@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// CheckBillingEligibility 检查用户是否有资格发起请求 // CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0 // 余额模式:检查缓存余额 > 0
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入) // 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error { func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
// 简易模式:跳过所有计费检查 // 简易模式:跳过所有计费检查
if s.cfg.RunMode == config.RunModeSimple { if s.cfg.RunMode == config.RunModeSimple {
return nil return nil
......
...@@ -82,7 +82,7 @@ type crsExportResponse struct { ...@@ -82,7 +82,7 @@ type crsExportResponse struct {
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"` OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"` OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"` GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"`
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiAPIKeyAccounts"` GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiApiKeyAccounts"`
} `json:"data"` } `json:"data"`
} }
...@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
Type: AccountTypeAPIKey, Type: AccountTypeApiKey,
Credentials: credentials, Credentials: credentials,
Extra: extra, Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
...@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformAnthropic existing.Platform = PlatformAnthropic
existing.Type = AccountTypeAPIKey existing.Type = AccountTypeApiKey
existing.Credentials = mergeMap(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
...@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI, Platform: PlatformOpenAI,
Type: AccountTypeAPIKey, Type: AccountTypeApiKey,
Credentials: credentials, Credentials: credentials,
Extra: extra, Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
...@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformOpenAI existing.Platform = PlatformOpenAI
existing.Type = AccountTypeAPIKey existing.Type = AccountTypeApiKey
existing.Credentials = mergeMap(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
...@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini, Platform: PlatformGemini,
Type: AccountTypeAPIKey, Type: AccountTypeApiKey,
Credentials: credentials, Credentials: credentials,
Extra: extra, Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
...@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformGemini existing.Platform = PlatformGemini
existing.Type = AccountTypeAPIKey existing.Type = AccountTypeApiKey
existing.Credentials = mergeMap(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
......
...@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi ...@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil return stats, nil
} }
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit) trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key usage trend: %w", err) return nil, fmt.Errorf("get api key usage trend: %w", err)
} }
...@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [ ...@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [
return stats, nil return stats, nil
} }
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs) stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err) return nil, fmt.Errorf("get batch api key usage stats: %w", err)
} }
......
...@@ -28,7 +28,7 @@ const ( ...@@ -28,7 +28,7 @@ const (
const ( const (
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
AccountTypeAPIKey = "apikey" // API Key类型账号 AccountTypeApiKey = "apikey" // API Key类型账号
) )
// Redeem type constants // Redeem type constants
...@@ -64,13 +64,13 @@ const ( ...@@ -64,13 +64,13 @@ const (
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置 // 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySMTPPort = "smtp_port" // SMTP端口 SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySMTPUsername = "smtp_username" // SMTP用户名 SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySMTPPassword = "smtp_password" // SMTP密码(加密存储) SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
SettingKeySMTPFrom = "smtp_from" // 发件人地址 SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySMTPFromName = "smtp_from_name" // 发件人名称 SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySMTPUseTLS = "smtp_use_tls" // 是否使用TLS SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置 // Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证 SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
...@@ -81,20 +81,27 @@ const ( ...@@ -81,20 +81,27 @@ const (
SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64) SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入) SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入)
SettingKeyContactInfo = "contact_info" // 客服联系方式 SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接 SettingKeyDocUrl = "doc_url" // 文档链接
// 默认配置 // 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key // 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
// Gemini 配额策略(JSON) // Gemini 配额策略(JSON)
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy" SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
// Model fallback settings
SettingKeyEnableModelFallback = "enable_model_fallback"
SettingKeyFallbackModelAnthropic = "fallback_model_anthropic"
SettingKeyFallbackModelOpenAI = "fallback_model_openai"
SettingKeyFallbackModelGemini = "fallback_model_gemini"
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
) )
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys) // Admin API Key prefix (distinct from user "sk-" keys)
const AdminAPIKeyPrefix = "admin-" const AdminApiKeyPrefix = "admin-"
...@@ -40,8 +40,8 @@ const ( ...@@ -40,8 +40,8 @@ const (
maxVerifyCodeAttempts = 5 maxVerifyCodeAttempts = 5
) )
// SMTPConfig SMTP配置 // SmtpConfig SMTP配置
type SMTPConfig struct { type SmtpConfig struct {
Host string Host string
Port int Port int
Username string Username string
...@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ ...@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
} }
} }
// GetSMTPConfig 从数据库获取SMTP配置 // GetSmtpConfig 从数据库获取SMTP配置
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) { func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
keys := []string{ keys := []string{
SettingKeySMTPHost, SettingKeySmtpHost,
SettingKeySMTPPort, SettingKeySmtpPort,
SettingKeySMTPUsername, SettingKeySmtpUsername,
SettingKeySMTPPassword, SettingKeySmtpPassword,
SettingKeySMTPFrom, SettingKeySmtpFrom,
SettingKeySMTPFromName, SettingKeySmtpFromName,
SettingKeySMTPUseTLS, SettingKeySmtpUseTLS,
} }
settings, err := s.settingRepo.GetMultiple(ctx, keys) settings, err := s.settingRepo.GetMultiple(ctx, keys)
...@@ -82,34 +82,34 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) { ...@@ -82,34 +82,34 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err) return nil, fmt.Errorf("get smtp settings: %w", err)
} }
host := settings[SettingKeySMTPHost] host := settings[SettingKeySmtpHost]
if host == "" { if host == "" {
return nil, ErrEmailNotConfigured return nil, ErrEmailNotConfigured
} }
port := 587 // 默认端口 port := 587 // 默认端口
if portStr := settings[SettingKeySMTPPort]; portStr != "" { if portStr := settings[SettingKeySmtpPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil { if p, err := strconv.Atoi(portStr); err == nil {
port = p port = p
} }
} }
useTLS := settings[SettingKeySMTPUseTLS] == "true" useTLS := settings[SettingKeySmtpUseTLS] == "true"
return &SMTPConfig{ return &SmtpConfig{
Host: host, Host: host,
Port: port, Port: port,
Username: settings[SettingKeySMTPUsername], Username: settings[SettingKeySmtpUsername],
Password: settings[SettingKeySMTPPassword], Password: settings[SettingKeySmtpPassword],
From: settings[SettingKeySMTPFrom], From: settings[SettingKeySmtpFrom],
FromName: settings[SettingKeySMTPFromName], FromName: settings[SettingKeySmtpFromName],
UseTLS: useTLS, UseTLS: useTLS,
}, nil }, nil
} }
// SendEmail 发送邮件(使用数据库中保存的配置) // SendEmail 发送邮件(使用数据库中保存的配置)
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error { func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
config, err := s.GetSMTPConfig(ctx) config, err := s.GetSmtpConfig(ctx)
if err != nil { if err != nil {
return err return err
} }
...@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) ...@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
} }
// SendEmailWithConfig 使用指定配置发送邮件 // SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error { func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
from := config.From from := config.From
if config.FromName != "" { if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From) from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
...@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string { ...@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
`, siteName, code) `, siteName, code)
} }
// TestSMTPConnectionWithConfig 使用指定配置测试SMTP连接 // TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error { func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
addr := fmt.Sprintf("%s:%d", config.Host, config.Port) addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
if config.UseTLS { if config.UseTLS {
......
...@@ -487,8 +487,8 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco ...@@ -487,8 +487,8 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
return "", "", errors.New("access_token not found in credentials") return "", "", errors.New("access_token not found in credentials")
} }
return accessToken, "oauth", nil return accessToken, "oauth", nil
case AccountTypeAPIKey: case AccountTypeApiKey:
apiKey := account.GetOpenAIAPIKey() apiKey := account.GetOpenAIApiKey()
if apiKey == "" { if apiKey == "" {
return "", "", errors.New("api_key not found in credentials") return "", "", errors.New("api_key not found in credentials")
} }
...@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. ...@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeOAuth: case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API // OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL targetURL = chatgptCodexURL
case AccountTypeAPIKey: case AccountTypeApiKey:
// API Key accounts use Platform API or custom base URL // API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL() baseURL := account.GetOpenAIBaseURL()
if baseURL != "" { if baseURL != "" {
...@@ -703,7 +703,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht ...@@ -703,7 +703,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
} }
// Handle upstream error (mark account status) // Handle upstream error (mark account status)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// Return appropriate error response // Return appropriate error response
var errType, errMsg string var errType, errMsg string
...@@ -940,7 +946,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel ...@@ -940,7 +946,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage // OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct { type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult Result *OpenAIForwardResult
APIKey *APIKey ApiKey *ApiKey
User *User User *User
Account *Account Account *Account
Subscription *UserSubscription Subscription *UserSubscription
...@@ -949,7 +955,7 @@ type OpenAIRecordUsageInput struct { ...@@ -949,7 +955,7 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance // RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result result := input.Result
apiKey := input.APIKey apiKey := input.ApiKey
user := input.User user := input.User
account := input.Account account := input.Account
subscription := input.Subscription subscription := input.Subscription
...@@ -991,7 +997,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -991,7 +997,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: result.RequestID, RequestID: result.RequestID,
Model: result.Model, Model: result.Model,
......
...@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeySiteName, SettingKeySiteName,
SettingKeySiteLogo, SettingKeySiteLogo,
SettingKeySiteSubtitle, SettingKeySiteSubtitle,
SettingKeyAPIBaseURL, SettingKeyApiBaseUrl,
SettingKeyContactInfo, SettingKeyContactInfo,
SettingKeyDocURL, SettingKeyDocUrl,
} }
settings, err := s.settingRepo.GetMultiple(ctx, keys) settings, err := s.settingRepo.GetMultiple(ctx, keys)
...@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo], SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL], ApiBaseUrl: settings[SettingKeyApiBaseUrl],
ContactInfo: settings[SettingKeyContactInfo], ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL], DocUrl: settings[SettingKeyDocUrl],
}, nil }, nil
} }
...@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
// 邮件服务设置(只有非空才更新密码) // 邮件服务设置(只有非空才更新密码)
updates[SettingKeySMTPHost] = settings.SMTPHost updates[SettingKeySmtpHost] = settings.SmtpHost
updates[SettingKeySMTPPort] = strconv.Itoa(settings.SMTPPort) updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
updates[SettingKeySMTPUsername] = settings.SMTPUsername updates[SettingKeySmtpUsername] = settings.SmtpUsername
if settings.SMTPPassword != "" { if settings.SmtpPassword != "" {
updates[SettingKeySMTPPassword] = settings.SMTPPassword updates[SettingKeySmtpPassword] = settings.SmtpPassword
} }
updates[SettingKeySMTPFrom] = settings.SMTPFrom updates[SettingKeySmtpFrom] = settings.SmtpFrom
updates[SettingKeySMTPFromName] = settings.SMTPFromName updates[SettingKeySmtpFromName] = settings.SmtpFromName
updates[SettingKeySMTPUseTLS] = strconv.FormatBool(settings.SMTPUseTLS) updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
// Cloudflare Turnstile 设置(只有非空才更新密钥) // Cloudflare Turnstile 设置(只有非空才更新密钥)
updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled) updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
...@@ -115,14 +115,21 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -115,14 +115,21 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeySiteName] = settings.SiteName updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo updates[SettingKeySiteLogo] = settings.SiteLogo
updates[SettingKeySiteSubtitle] = settings.SiteSubtitle updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[SettingKeyAPIBaseURL] = settings.APIBaseURL updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl
updates[SettingKeyContactInfo] = settings.ContactInfo updates[SettingKeyContactInfo] = settings.ContactInfo
updates[SettingKeyDocURL] = settings.DocURL updates[SettingKeyDocUrl] = settings.DocUrl
// 默认配置 // 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
// Model fallback configuration
updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback)
updates[SettingKeyFallbackModelAnthropic] = settings.FallbackModelAnthropic
updates[SettingKeyFallbackModelOpenAI] = settings.FallbackModelOpenAI
updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini
updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity
return s.settingRepo.SetMultiple(ctx, updates) return s.settingRepo.SetMultiple(ctx, updates)
} }
...@@ -198,8 +205,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -198,8 +205,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "", SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySMTPPort: "587", SettingKeySmtpPort: "587",
SettingKeySMTPUseTLS: "false", SettingKeySmtpUseTLS: "false",
// Model fallback defaults
SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
SettingKeyFallbackModelOpenAI: "gpt-4o",
SettingKeyFallbackModelGemini: "gemini-2.5-pro",
SettingKeyFallbackModelAntigravity: "gemini-2.5-pro",
} }
return s.settingRepo.SetMultiple(ctx, defaults) return s.settingRepo.SetMultiple(ctx, defaults)
...@@ -210,26 +223,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -210,26 +223,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result := &SystemSettings{ result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SMTPHost: settings[SettingKeySMTPHost], SmtpHost: settings[SettingKeySmtpHost],
SMTPUsername: settings[SettingKeySMTPUsername], SmtpUsername: settings[SettingKeySmtpUsername],
SMTPFrom: settings[SettingKeySMTPFrom], SmtpFrom: settings[SettingKeySmtpFrom],
SMTPFromName: settings[SettingKeySMTPFromName], SmtpFromName: settings[SettingKeySmtpFromName],
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true", SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo], SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL], ApiBaseUrl: settings[SettingKeyApiBaseUrl],
ContactInfo: settings[SettingKeyContactInfo], ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL], DocUrl: settings[SettingKeyDocUrl],
} }
// 解析整数类型 // 解析整数类型
if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil { if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil {
result.SMTPPort = port result.SmtpPort = port
} else { } else {
result.SMTPPort = 587 result.SmtpPort = 587
} }
if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil { if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
...@@ -245,10 +258,17 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -245,10 +258,17 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.DefaultBalance = s.cfg.Default.UserBalance result.DefaultBalance = s.cfg.Default.UserBalance
} }
// 敏感信息直接返回,方便测试连接时使用 // 敏感信息直接返回方便测试连接时使用
result.SMTPPassword = settings[SettingKeySMTPPassword] result.SmtpPassword = settings[SettingKeySmtpPassword]
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey] result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
// Model fallback settings
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
result.FallbackModelOpenAI = s.getStringOrDefault(settings, SettingKeyFallbackModelOpenAI, "gpt-4o")
result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro")
result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro")
return result return result
} }
...@@ -278,28 +298,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string { ...@@ -278,28 +298,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return value return value
} }
// GenerateAdminAPIKey 生成新的管理员 API Key // GenerateAdminApiKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) { func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) {
// 生成 32 字节随机数 = 64 位十六进制字符 // 生成 32 字节随机数 = 64 位十六进制字符
bytes := make([]byte, 32) bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil { if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err) return "", fmt.Errorf("generate random bytes: %w", err)
} }
key := AdminAPIKeyPrefix + hex.EncodeToString(bytes) key := AdminApiKeyPrefix + hex.EncodeToString(bytes)
// 存储到 settings 表 // 存储到 settings 表
if err := s.settingRepo.Set(ctx, SettingKeyAdminAPIKey, key); err != nil { if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil {
return "", fmt.Errorf("save admin api key: %w", err) return "", fmt.Errorf("save admin api key: %w", err)
} }
return key, nil return key, nil
} }
// GetAdminAPIKeyStatus 获取管理员 API Key 状态 // GetAdminApiKeyStatus 获取管理员 API Key 状态
// 返回脱敏的 key、是否存在、错误 // 返回脱敏的 key、是否存在、错误
func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) { func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey) key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
if err != nil { if err != nil {
if errors.Is(err, ErrSettingNotFound) { if errors.Is(err, ErrSettingNotFound) {
return "", false, nil return "", false, nil
...@@ -320,10 +340,10 @@ func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey st ...@@ -320,10 +340,10 @@ func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey st
return maskedKey, true, nil return maskedKey, true, nil
} }
// GetAdminAPIKey 获取完整的管理员 API Key(仅供内部验证使用) // GetAdminApiKey 获取完整的管理员 API Key(仅供内部验证使用)
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error // 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) { func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey) key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
if err != nil { if err != nil {
if errors.Is(err, ErrSettingNotFound) { if errors.Is(err, ErrSettingNotFound) {
return "", nil // 未配置,返回空字符串 return "", nil // 未配置,返回空字符串
...@@ -333,7 +353,45 @@ func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) { ...@@ -333,7 +353,45 @@ func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
return key, nil return key, nil
} }
// DeleteAdminAPIKey 删除管理员 API Key // DeleteAdminApiKey 删除管理员 API Key
func (s *SettingService) DeleteAdminAPIKey(ctx context.Context) error { func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, SettingKeyAdminAPIKey) return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey)
}
// IsModelFallbackEnabled 检查是否启用模型兜底机制
func (s *SettingService) IsModelFallbackEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyEnableModelFallback)
if err != nil {
return false // Default: disabled
}
return value == "true"
}
// GetFallbackModel 获取指定平台的兜底模型
func (s *SettingService) GetFallbackModel(ctx context.Context, platform string) string {
var key string
var defaultModel string
switch platform {
case PlatformAnthropic:
key = SettingKeyFallbackModelAnthropic
defaultModel = "claude-3-5-sonnet-20241022"
case PlatformOpenAI:
key = SettingKeyFallbackModelOpenAI
defaultModel = "gpt-4o"
case PlatformGemini:
key = SettingKeyFallbackModelGemini
defaultModel = "gemini-2.5-pro"
case PlatformAntigravity:
key = SettingKeyFallbackModelAntigravity
defaultModel = "gemini-2.5-pro"
default:
return ""
}
value, err := s.settingRepo.GetValue(ctx, key)
if err != nil || value == "" {
return defaultModel
}
return value
} }
...@@ -4,13 +4,13 @@ type SystemSettings struct { ...@@ -4,13 +4,13 @@ type SystemSettings struct {
RegistrationEnabled bool RegistrationEnabled bool
EmailVerifyEnabled bool EmailVerifyEnabled bool
SMTPHost string SmtpHost string
SMTPPort int SmtpPort int
SMTPUsername string SmtpUsername string
SMTPPassword string SmtpPassword string
SMTPFrom string SmtpFrom string
SMTPFromName string SmtpFromName string
SMTPUseTLS bool SmtpUseTLS bool
TurnstileEnabled bool TurnstileEnabled bool
TurnstileSiteKey string TurnstileSiteKey string
...@@ -19,12 +19,19 @@ type SystemSettings struct { ...@@ -19,12 +19,19 @@ type SystemSettings struct {
SiteName string SiteName string
SiteLogo string SiteLogo string
SiteSubtitle string SiteSubtitle string
APIBaseURL string ApiBaseUrl string
ContactInfo string ContactInfo string
DocURL string DocUrl string
DefaultConcurrency int DefaultConcurrency int
DefaultBalance float64 DefaultBalance float64
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
} }
type PublicSettings struct { type PublicSettings struct {
...@@ -35,8 +42,8 @@ type PublicSettings struct { ...@@ -35,8 +42,8 @@ type PublicSettings struct {
SiteName string SiteName string
SiteLogo string SiteLogo string
SiteSubtitle string SiteSubtitle string
APIBaseURL string ApiBaseUrl string
ContactInfo string ContactInfo string
DocURL string DocUrl string
Version string Version string
} }
...@@ -79,7 +79,7 @@ type ReleaseInfo struct { ...@@ -79,7 +79,7 @@ type ReleaseInfo struct {
Name string `json:"name"` Name string `json:"name"`
Body string `json:"body"` Body string `json:"body"`
PublishedAt string `json:"published_at"` PublishedAt string `json:"published_at"`
HTMLURL string `json:"html_url"` HtmlURL string `json:"html_url"`
Assets []Asset `json:"assets,omitempty"` Assets []Asset `json:"assets,omitempty"`
} }
...@@ -96,13 +96,13 @@ type GitHubRelease struct { ...@@ -96,13 +96,13 @@ type GitHubRelease struct {
Name string `json:"name"` Name string `json:"name"`
Body string `json:"body"` Body string `json:"body"`
PublishedAt string `json:"published_at"` PublishedAt string `json:"published_at"`
HTMLURL string `json:"html_url"` HtmlUrl string `json:"html_url"`
Assets []GitHubAsset `json:"assets"` Assets []GitHubAsset `json:"assets"`
} }
type GitHubAsset struct { type GitHubAsset struct {
Name string `json:"name"` Name string `json:"name"`
BrowserDownloadURL string `json:"browser_download_url"` BrowserDownloadUrl string `json:"browser_download_url"`
Size int64 `json:"size"` Size int64 `json:"size"`
} }
...@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er ...@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
for i, a := range release.Assets { for i, a := range release.Assets {
assets[i] = Asset{ assets[i] = Asset{
Name: a.Name, Name: a.Name,
DownloadURL: a.BrowserDownloadURL, DownloadURL: a.BrowserDownloadUrl,
Size: a.Size, Size: a.Size,
} }
} }
...@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er ...@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
Name: release.Name, Name: release.Name,
Body: release.Body, Body: release.Body,
PublishedAt: release.PublishedAt, PublishedAt: release.PublishedAt,
HTMLURL: release.HTMLURL, HtmlURL: release.HtmlUrl,
Assets: assets, Assets: assets,
}, },
Cached: false, Cached: false,
......
package service
import "time"
// clampInt 将整数限制在指定范围内
func clampInt(value, min, max int) int {
if value < min {
return min
}
if value > max {
return max
}
return value
}
// clampFloat64 将浮点数限制在指定范围内
func clampFloat64(value, min, max float64) float64 {
if value < min {
return min
}
if value > max {
return max
}
return value
}
// remainingSecondsUntil 计算到指定时间的剩余秒数,保证非负
func remainingSecondsUntil(t time.Time) int {
seconds := int(time.Until(t).Seconds())
if seconds < 0 {
return 0
}
return seconds
}
...@@ -10,7 +10,7 @@ const ( ...@@ -10,7 +10,7 @@ const (
type UsageLog struct { type UsageLog struct {
ID int64 ID int64
UserID int64 UserID int64
APIKeyID int64 ApiKeyID int64
AccountID int64 AccountID int64
RequestID string RequestID string
Model string Model string
...@@ -42,7 +42,7 @@ type UsageLog struct { ...@@ -42,7 +42,7 @@ type UsageLog struct {
CreatedAt time.Time CreatedAt time.Time
User *User User *User
APIKey *APIKey ApiKey *ApiKey
Account *Account Account *Account
Group *Group Group *Group
Subscription *UserSubscription Subscription *UserSubscription
......
...@@ -17,7 +17,7 @@ var ( ...@@ -17,7 +17,7 @@ var (
// CreateUsageLogRequest 创建使用日志请求 // CreateUsageLogRequest 创建使用日志请求
type CreateUsageLogRequest struct { type CreateUsageLogRequest struct {
UserID int64 `json:"user_id"` UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"` ApiKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"` AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"` RequestID string `json:"request_id"`
Model string `json:"model"` Model string `json:"model"`
...@@ -75,7 +75,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* ...@@ -75,7 +75,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
// 创建使用日志 // 创建使用日志
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: req.UserID, UserID: req.UserID,
APIKeyID: req.APIKeyID, ApiKeyID: req.ApiKeyID,
AccountID: req.AccountID, AccountID: req.AccountID,
RequestID: req.RequestID, RequestID: req.RequestID,
Model: req.Model, Model: req.Model,
...@@ -128,9 +128,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi ...@@ -128,9 +128,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
return logs, pagination, nil return logs, pagination, nil
} }
// ListByAPIKey 获取API Key的使用日志列表 // ListByApiKey 获取API Key的使用日志列表
func (s *UsageService) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) { func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByAPIKey(ctx, apiKeyID, params) logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err) return nil, nil, fmt.Errorf("list usage logs: %w", err)
} }
...@@ -165,9 +165,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi ...@@ -165,9 +165,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi
}, nil }, nil
} }
// GetStatsByAPIKey 获取API Key的使用统计 // GetStatsByApiKey 获取API Key的使用统计
func (s *UsageService) GetStatsByAPIKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) { func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
stats, err := s.usageRepo.GetAPIKeyStatsAggregated(ctx, apiKeyID, startTime, endTime) stats, err := s.usageRepo.GetApiKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key stats: %w", err) return nil, fmt.Errorf("get api key stats: %w", err)
} }
...@@ -270,9 +270,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star ...@@ -270,9 +270,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
return stats, nil return stats, nil
} }
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys. // GetBatchApiKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs) stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err) return nil, fmt.Errorf("get batch api key usage stats: %w", err)
} }
......
...@@ -21,7 +21,7 @@ type User struct { ...@@ -21,7 +21,7 @@ type User struct {
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
APIKeys []APIKey ApiKeys []ApiKey
Subscriptions []UserSubscription Subscriptions []UserSubscription
} }
......
...@@ -56,6 +56,10 @@ func (s *UserAttributeService) CreateDefinition(ctx context.Context, input Creat ...@@ -56,6 +56,10 @@ func (s *UserAttributeService) CreateDefinition(ctx context.Context, input Creat
Enabled: input.Enabled, Enabled: input.Enabled,
} }
if err := validateDefinitionPattern(def); err != nil {
return nil, err
}
if err := s.defRepo.Create(ctx, def); err != nil { if err := s.defRepo.Create(ctx, def); err != nil {
return nil, fmt.Errorf("create definition: %w", err) return nil, fmt.Errorf("create definition: %w", err)
} }
...@@ -108,6 +112,10 @@ func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, i ...@@ -108,6 +112,10 @@ func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, i
def.Enabled = *input.Enabled def.Enabled = *input.Enabled
} }
if err := validateDefinitionPattern(def); err != nil {
return nil, err
}
if err := s.defRepo.Update(ctx, def); err != nil { if err := s.defRepo.Update(ctx, def); err != nil {
return nil, fmt.Errorf("update definition: %w", err) return nil, fmt.Errorf("update definition: %w", err)
} }
...@@ -231,7 +239,10 @@ func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value ...@@ -231,7 +239,10 @@ func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value
// Pattern validation // Pattern validation
if v.Pattern != nil && *v.Pattern != "" && value != "" { if v.Pattern != nil && *v.Pattern != "" && value != "" {
re, err := regexp.Compile(*v.Pattern) re, err := regexp.Compile(*v.Pattern)
if err == nil && !re.MatchString(value) { if err != nil {
return validationError(def.Name + " has an invalid pattern")
}
if !re.MatchString(value) {
msg := def.Name + " format is invalid" msg := def.Name + " format is invalid"
if v.Message != nil && *v.Message != "" { if v.Message != nil && *v.Message != "" {
msg = *v.Message msg = *v.Message
...@@ -293,3 +304,20 @@ func isValidAttributeType(t UserAttributeType) bool { ...@@ -293,3 +304,20 @@ func isValidAttributeType(t UserAttributeType) bool {
} }
return false return false
} }
func validateDefinitionPattern(def *UserAttributeDefinition) error {
if def == nil {
return nil
}
if def.Validation.Pattern == nil {
return nil
}
pattern := strings.TrimSpace(*def.Validation.Pattern)
if pattern == "" {
return nil
}
if _, err := regexp.Compile(pattern); err != nil {
return infraerrors.BadRequest("INVALID_ATTRIBUTE_PATTERN", fmt.Sprintf("invalid pattern for %s: %v", def.Name, err))
}
return nil
}
...@@ -54,18 +54,6 @@ func ProvideTimingWheelService() *TimingWheelService { ...@@ -54,18 +54,6 @@ func ProvideTimingWheelService() *TimingWheelService {
return svc return svc
} }
// ProvideAntigravityQuotaRefresher creates and starts AntigravityQuotaRefresher
func ProvideAntigravityQuotaRefresher(
accountRepo AccountRepository,
proxyRepo ProxyRepository,
oauthSvc *AntigravityOAuthService,
cfg *config.Config,
) *AntigravityQuotaRefresher {
svc := NewAntigravityQuotaRefresher(accountRepo, proxyRepo, oauthSvc, cfg)
svc.Start()
return svc
}
// ProvideDeferredService creates and starts DeferredService // ProvideDeferredService creates and starts DeferredService
func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWheelService) *DeferredService { func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWheelService) *DeferredService {
svc := NewDeferredService(accountRepo, timingWheel, 10*time.Second) svc := NewDeferredService(accountRepo, timingWheel, 10*time.Second)
...@@ -73,20 +61,6 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh ...@@ -73,20 +61,6 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
return svc return svc
} }
// ProvideOpsMetricsCollector creates and starts OpsMetricsCollector.
func ProvideOpsMetricsCollector(opsService *OpsService, concurrencyService *ConcurrencyService) *OpsMetricsCollector {
svc := NewOpsMetricsCollector(opsService, concurrencyService)
svc.Start()
return svc
}
// ProvideOpsAlertService creates and starts OpsAlertService.
func ProvideOpsAlertService(opsService *OpsService, userService *UserService, emailService *EmailService) *OpsAlertService {
svc := NewOpsAlertService(opsService, userService, emailService)
svc.Start()
return svc
}
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker. // ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService { func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
svc := NewConcurrencyService(cache) svc := NewConcurrencyService(cache)
...@@ -101,14 +75,13 @@ var ProviderSet = wire.NewSet( ...@@ -101,14 +75,13 @@ var ProviderSet = wire.NewSet(
// Core services // Core services
NewAuthService, NewAuthService,
NewUserService, NewUserService,
NewAPIKeyService, NewApiKeyService,
NewGroupService, NewGroupService,
NewAccountService, NewAccountService,
NewProxyService, NewProxyService,
NewRedeemService, NewRedeemService,
NewUsageService, NewUsageService,
NewDashboardService, NewDashboardService,
NewOpsService,
ProvidePricingService, ProvidePricingService,
NewBillingService, NewBillingService,
NewBillingCacheService, NewBillingCacheService,
...@@ -139,8 +112,7 @@ var ProviderSet = wire.NewSet( ...@@ -139,8 +112,7 @@ var ProviderSet = wire.NewSet(
ProvideTokenRefreshService, ProvideTokenRefreshService,
ProvideTimingWheelService, ProvideTimingWheelService,
ProvideDeferredService, ProvideDeferredService,
ProvideAntigravityQuotaRefresher, NewAntigravityQuotaFetcher,
ProvideOpsMetricsCollector,
ProvideOpsAlertService,
NewUserAttributeService, NewUserAttributeService,
NewUsageCache,
) )
// Package setup provides CLI-based installation wizard for initial system configuration.
package setup package setup
import ( import (
......
...@@ -345,7 +345,7 @@ func writeConfigFile(cfg *SetupConfig) error { ...@@ -345,7 +345,7 @@ func writeConfigFile(cfg *SetupConfig) error {
Default struct { Default struct {
UserConcurrency int `yaml:"user_concurrency"` UserConcurrency int `yaml:"user_concurrency"`
UserBalance float64 `yaml:"user_balance"` UserBalance float64 `yaml:"user_balance"`
APIKeyPrefix string `yaml:"api_key_prefix"` ApiKeyPrefix string `yaml:"api_key_prefix"`
RateMultiplier float64 `yaml:"rate_multiplier"` RateMultiplier float64 `yaml:"rate_multiplier"`
} `yaml:"default"` } `yaml:"default"`
RateLimit struct { RateLimit struct {
...@@ -367,12 +367,12 @@ func writeConfigFile(cfg *SetupConfig) error { ...@@ -367,12 +367,12 @@ func writeConfigFile(cfg *SetupConfig) error {
Default: struct { Default: struct {
UserConcurrency int `yaml:"user_concurrency"` UserConcurrency int `yaml:"user_concurrency"`
UserBalance float64 `yaml:"user_balance"` UserBalance float64 `yaml:"user_balance"`
APIKeyPrefix string `yaml:"api_key_prefix"` ApiKeyPrefix string `yaml:"api_key_prefix"`
RateMultiplier float64 `yaml:"rate_multiplier"` RateMultiplier float64 `yaml:"rate_multiplier"`
}{ }{
UserConcurrency: 5, UserConcurrency: 5,
UserBalance: 0, UserBalance: 0,
APIKeyPrefix: "sk-", ApiKeyPrefix: "sk-",
RateMultiplier: 1.0, RateMultiplier: 1.0,
}, },
RateLimit: struct { RateLimit: struct {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment