Unverified Commit 14b155c6 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #7 from NepetaLemon/refactor/ports-pattern

refactor(backend): 引入端口接口模式
parents 7fd94ab7 e99b344b
...@@ -12,7 +12,7 @@ import ( ...@@ -12,7 +12,7 @@ import (
"time" "time"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/repository" "sub2api/internal/service/ports"
) )
// usageCache 用于缓存usage数据 // usageCache 用于缓存usage数据
...@@ -35,10 +35,10 @@ type WindowStats struct { ...@@ -35,10 +35,10 @@ type WindowStats struct {
// UsageProgress 使用量进度 // UsageProgress 使用量进度
type UsageProgress struct { type UsageProgress struct {
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%) Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%)
ResetsAt *time.Time `json:"resets_at"` // 重置时间 ResetsAt *time.Time `json:"resets_at"` // 重置时间
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数 RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量) WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
} }
// UsageInfo 账号使用量信息 // UsageInfo 账号使用量信息
...@@ -67,15 +67,17 @@ type ClaudeUsageResponse struct { ...@@ -67,15 +67,17 @@ type ClaudeUsageResponse struct {
// AccountUsageService 账号使用量查询服务 // AccountUsageService 账号使用量查询服务
type AccountUsageService struct { type AccountUsageService struct {
repos *repository.Repositories accountRepo ports.AccountRepository
usageLogRepo ports.UsageLogRepository
oauthService *OAuthService oauthService *OAuthService
httpClient *http.Client httpClient *http.Client
} }
// NewAccountUsageService 创建AccountUsageService实例 // NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthService) *AccountUsageService { func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService) *AccountUsageService {
return &AccountUsageService{ return &AccountUsageService{
repos: repos, accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
oauthService: oauthService, oauthService: oauthService,
httpClient: &http.Client{ httpClient: &http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
...@@ -88,7 +90,7 @@ func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthS ...@@ -88,7 +90,7 @@ func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthS
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope) // Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
// API Key账号: 不支持usage查询 // API Key账号: 不支持usage查询
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) { func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
account, err := s.repos.Account.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account failed: %w", err) return nil, fmt.Errorf("get account failed: %w", err)
} }
...@@ -148,7 +150,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model ...@@ -148,7 +150,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
startTime = time.Now().Add(-5 * time.Hour) startTime = time.Now().Add(-5 * time.Hour)
} }
stats, err := s.repos.UsageLog.GetAccountWindowStats(ctx, account.ID, startTime) stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil { if err != nil {
log.Printf("Failed to get window stats for account %d: %v", account.ID, err) log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
return return
...@@ -163,7 +165,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model ...@@ -163,7 +165,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
// GetTodayStats 获取账号今日统计 // GetTodayStats 获取账号今日统计
func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) { func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) {
stats, err := s.repos.UsageLog.GetAccountTodayStats(ctx, accountID) stats, err := s.usageLogRepo.GetAccountTodayStats(ctx, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get today stats failed: %w", err) return nil, fmt.Errorf("get today stats failed: %w", err)
} }
......
...@@ -13,7 +13,8 @@ import ( ...@@ -13,7 +13,8 @@ import (
"time" "time"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/repository" "sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -179,35 +180,45 @@ type ProxyTestResult struct { ...@@ -179,35 +180,45 @@ type ProxyTestResult struct {
// adminServiceImpl implements AdminService // adminServiceImpl implements AdminService
type adminServiceImpl struct { type adminServiceImpl struct {
userRepo *repository.UserRepository userRepo ports.UserRepository
groupRepo *repository.GroupRepository groupRepo ports.GroupRepository
accountRepo *repository.AccountRepository accountRepo ports.AccountRepository
proxyRepo *repository.ProxyRepository proxyRepo ports.ProxyRepository
apiKeyRepo *repository.ApiKeyRepository apiKeyRepo ports.ApiKeyRepository
redeemCodeRepo *repository.RedeemCodeRepository redeemCodeRepo ports.RedeemCodeRepository
usageLogRepo *repository.UsageLogRepository usageLogRepo ports.UsageLogRepository
userSubRepo *repository.UserSubscriptionRepository userSubRepo ports.UserSubscriptionRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
} }
// NewAdminService creates a new AdminService // NewAdminService creates a new AdminService
func NewAdminService(repos *repository.Repositories, billingCacheService *BillingCacheService) AdminService { func NewAdminService(
userRepo ports.UserRepository,
groupRepo ports.GroupRepository,
accountRepo ports.AccountRepository,
proxyRepo ports.ProxyRepository,
apiKeyRepo ports.ApiKeyRepository,
redeemCodeRepo ports.RedeemCodeRepository,
usageLogRepo ports.UsageLogRepository,
userSubRepo ports.UserSubscriptionRepository,
billingCacheService *BillingCacheService,
) AdminService {
return &adminServiceImpl{ return &adminServiceImpl{
userRepo: repos.User, userRepo: userRepo,
groupRepo: repos.Group, groupRepo: groupRepo,
accountRepo: repos.Account, accountRepo: accountRepo,
proxyRepo: repos.Proxy, proxyRepo: proxyRepo,
apiKeyRepo: repos.ApiKey, apiKeyRepo: apiKeyRepo,
redeemCodeRepo: repos.RedeemCode, redeemCodeRepo: redeemCodeRepo,
usageLogRepo: repos.UsageLog, usageLogRepo: usageLogRepo,
userSubRepo: repos.UserSubscription, userSubRepo: userSubRepo,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
} }
} }
// User management implementations // User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) { func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search) users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -376,7 +387,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, ...@@ -376,7 +387,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
} }
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) { func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -397,7 +408,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, ...@@ -397,7 +408,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
// Group management implementations // Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) { func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive) groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -568,7 +579,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { ...@@ -568,7 +579,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
} }
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) { func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -578,7 +589,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p ...@@ -578,7 +589,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
// Account management implementations // Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) { func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -696,7 +707,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, ...@@ -696,7 +707,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
// Proxy management implementations // Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) { func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
...@@ -781,7 +792,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po ...@@ -781,7 +792,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
// Redeem code management implementations // Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) { func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
......
...@@ -8,8 +8,9 @@ import ( ...@@ -8,8 +8,9 @@ import (
"fmt" "fmt"
"sub2api/internal/config" "sub2api/internal/config"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/timezone" "sub2api/internal/pkg/timezone"
"sub2api/internal/repository" "sub2api/internal/service/ports"
"time" "time"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
...@@ -17,12 +18,12 @@ import ( ...@@ -17,12 +18,12 @@ import (
) )
var ( var (
ErrApiKeyNotFound = errors.New("api key not found") ErrApiKeyNotFound = errors.New("api key not found")
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group") ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
ErrApiKeyExists = errors.New("api key already exists") ErrApiKeyExists = errors.New("api key already exists")
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters") ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens") ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later") ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
) )
const ( const (
...@@ -47,20 +48,20 @@ type UpdateApiKeyRequest struct { ...@@ -47,20 +48,20 @@ type UpdateApiKeyRequest struct {
// ApiKeyService API Key服务 // ApiKeyService API Key服务
type ApiKeyService struct { type ApiKeyService struct {
apiKeyRepo *repository.ApiKeyRepository apiKeyRepo ports.ApiKeyRepository
userRepo *repository.UserRepository userRepo ports.UserRepository
groupRepo *repository.GroupRepository groupRepo ports.GroupRepository
userSubRepo *repository.UserSubscriptionRepository userSubRepo ports.UserSubscriptionRepository
rdb *redis.Client rdb *redis.Client
cfg *config.Config cfg *config.Config
} }
// NewApiKeyService 创建API Key服务实例 // NewApiKeyService 创建API Key服务实例
func NewApiKeyService( func NewApiKeyService(
apiKeyRepo *repository.ApiKeyRepository, apiKeyRepo ports.ApiKeyRepository,
userRepo *repository.UserRepository, userRepo ports.UserRepository,
groupRepo *repository.GroupRepository, groupRepo ports.GroupRepository,
userSubRepo *repository.UserSubscriptionRepository, userSubRepo ports.UserSubscriptionRepository,
rdb *redis.Client, rdb *redis.Client,
cfg *config.Config, cfg *config.Config,
) *ApiKeyService { ) *ApiKeyService {
...@@ -237,7 +238,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -237,7 +238,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
} }
// List 获取用户的API Key列表 // List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.ApiKey, *repository.PaginationResult, error) { func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err) return nil, nil, fmt.Errorf("list api keys: %w", err)
......
...@@ -7,7 +7,7 @@ import ( ...@@ -7,7 +7,7 @@ import (
"log" "log"
"sub2api/internal/config" "sub2api/internal/config"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/repository" "sub2api/internal/service/ports"
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
...@@ -35,7 +35,7 @@ type JWTClaims struct { ...@@ -35,7 +35,7 @@ type JWTClaims struct {
// AuthService 认证服务 // AuthService 认证服务
type AuthService struct { type AuthService struct {
userRepo *repository.UserRepository userRepo ports.UserRepository
cfg *config.Config cfg *config.Config
settingService *SettingService settingService *SettingService
emailService *EmailService emailService *EmailService
...@@ -45,7 +45,7 @@ type AuthService struct { ...@@ -45,7 +45,7 @@ type AuthService struct {
// NewAuthService 创建认证服务实例 // NewAuthService 创建认证服务实例
func NewAuthService( func NewAuthService(
userRepo *repository.UserRepository, userRepo ports.UserRepository,
cfg *config.Config, cfg *config.Config,
settingService *SettingService, settingService *SettingService,
emailService *EmailService, emailService *EmailService,
......
...@@ -9,7 +9,7 @@ import ( ...@@ -9,7 +9,7 @@ import (
"time" "time"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/repository" "sub2api/internal/service/ports"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
...@@ -81,12 +81,12 @@ type subscriptionCacheData struct { ...@@ -81,12 +81,12 @@ type subscriptionCacheData struct {
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct { type BillingCacheService struct {
rdb *redis.Client rdb *redis.Client
userRepo *repository.UserRepository userRepo ports.UserRepository
subRepo *repository.UserSubscriptionRepository subRepo ports.UserSubscriptionRepository
} }
// NewBillingCacheService 创建计费缓存服务 // NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(rdb *redis.Client, userRepo *repository.UserRepository, subRepo *repository.UserSubscriptionRepository) *BillingCacheService { func NewBillingCacheService(rdb *redis.Client, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService {
return &BillingCacheService{ return &BillingCacheService{
rdb: rdb, rdb: rdb,
userRepo: userRepo, userRepo: userRepo,
......
...@@ -11,7 +11,7 @@ import ( ...@@ -11,7 +11,7 @@ import (
"net/smtp" "net/smtp"
"strconv" "strconv"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/repository" "sub2api/internal/service/ports"
"time" "time"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
...@@ -25,9 +25,9 @@ var ( ...@@ -25,9 +25,9 @@ var (
) )
const ( const (
verifyCodeKeyPrefix = "email_verify:" verifyCodeKeyPrefix = "email_verify:"
verifyCodeTTL = 15 * time.Minute verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute verifyCodeCooldown = 1 * time.Minute
maxVerifyCodeAttempts = 5 maxVerifyCodeAttempts = 5
) )
...@@ -51,12 +51,12 @@ type SmtpConfig struct { ...@@ -51,12 +51,12 @@ type SmtpConfig struct {
// EmailService 邮件服务 // EmailService 邮件服务
type EmailService struct { type EmailService struct {
settingRepo *repository.SettingRepository settingRepo ports.SettingRepository
rdb *redis.Client rdb *redis.Client
} }
// NewEmailService 创建邮件服务实例 // NewEmailService 创建邮件服务实例
func NewEmailService(settingRepo *repository.SettingRepository, rdb *redis.Client) *EmailService { func NewEmailService(settingRepo ports.SettingRepository, rdb *redis.Client) *EmailService {
return &EmailService{ return &EmailService{
settingRepo: settingRepo, settingRepo: settingRepo,
rdb: rdb, rdb: rdb,
......
...@@ -21,7 +21,7 @@ import ( ...@@ -21,7 +21,7 @@ import (
"sub2api/internal/config" "sub2api/internal/config"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/claude" "sub2api/internal/pkg/claude"
"sub2api/internal/repository" "sub2api/internal/service/ports"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
...@@ -78,7 +78,10 @@ type ForwardResult struct { ...@@ -78,7 +78,10 @@ type ForwardResult struct {
// GatewayService handles API gateway operations // GatewayService handles API gateway operations
type GatewayService struct { type GatewayService struct {
repos *repository.Repositories accountRepo ports.AccountRepository
usageLogRepo ports.UsageLogRepository
userRepo ports.UserRepository
userSubRepo ports.UserSubscriptionRepository
rdb *redis.Client rdb *redis.Client
cfg *config.Config cfg *config.Config
oauthService *OAuthService oauthService *OAuthService
...@@ -90,7 +93,19 @@ type GatewayService struct { ...@@ -90,7 +93,19 @@ type GatewayService struct {
} }
// NewGatewayService creates a new GatewayService // NewGatewayService creates a new GatewayService
func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config, oauthService *OAuthService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, identityService *IdentityService) *GatewayService { func NewGatewayService(
accountRepo ports.AccountRepository,
usageLogRepo ports.UsageLogRepository,
userRepo ports.UserRepository,
userSubRepo ports.UserSubscriptionRepository,
rdb *redis.Client,
cfg *config.Config,
oauthService *OAuthService,
billingService *BillingService,
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
identityService *IdentityService,
) *GatewayService {
// 计算响应头超时时间 // 计算响应头超时时间
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 { if responseHeaderTimeout == 0 {
...@@ -105,7 +120,10 @@ func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *c ...@@ -105,7 +120,10 @@ func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *c
// 注意:不设置整体 Timeout,让流式响应可以无限时间传输 // 注意:不设置整体 Timeout,让流式响应可以无限时间传输
} }
return &GatewayService{ return &GatewayService{
repos: repos, accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
rdb: rdb, rdb: rdb,
cfg: cfg, cfg: cfg,
oauthService: oauthService, oauthService: oauthService,
...@@ -274,7 +292,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int ...@@ -274,7 +292,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64() accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64()
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
account, err := s.repos.Account.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中 // 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
// 同时检查模型支持 // 同时检查模型支持
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
...@@ -289,9 +307,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int ...@@ -289,9 +307,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
var accounts []model.Account var accounts []model.Account
var err error var err error
if groupID != nil { if groupID != nil {
accounts, err = s.repos.Account.ListSchedulableByGroupID(ctx, *groupID) accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
} else { } else {
accounts, err = s.repos.Account.ListSchedulable(ctx) accounts, err = s.accountRepo.ListSchedulable(ctx)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err) return nil, fmt.Errorf("query accounts failed: %w", err)
...@@ -378,7 +396,7 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Accou ...@@ -378,7 +396,7 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Accou
account.Credentials["refresh_token"] = tokenInfo.RefreshToken account.Credentials["refresh_token"] = tokenInfo.RefreshToken
} }
if err := s.repos.Account.Update(ctx, account); err != nil { if err := s.accountRepo.Update(ctx, account); err != nil {
log.Printf("Failed to update account credentials: %v", err) log.Printf("Failed to update account credentials: %v", err)
} }
...@@ -667,7 +685,7 @@ func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.A ...@@ -667,7 +685,7 @@ func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.A
account.Credentials["refresh_token"] = tokenInfo.RefreshToken account.Credentials["refresh_token"] = tokenInfo.RefreshToken
} }
if err := s.repos.Account.Update(ctx, account); err != nil { if err := s.accountRepo.Update(ctx, account); err != nil {
log.Printf("Failed to update account credentials: %v", err) log.Printf("Failed to update account credentials: %v", err)
} }
...@@ -999,7 +1017,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -999,7 +1017,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog.SubscriptionID = &subscription.ID usageLog.SubscriptionID = &subscription.ID
} }
if err := s.repos.UsageLog.Create(ctx, usageLog); err != nil { if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
log.Printf("Create usage log failed: %v", err) log.Printf("Create usage log failed: %v", err)
} }
...@@ -1007,7 +1025,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -1007,7 +1025,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if isSubscriptionBilling { if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if cost.TotalCost > 0 { if cost.TotalCost > 0 {
if err := s.repos.UserSubscription.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
log.Printf("Increment subscription usage failed: %v", err) log.Printf("Increment subscription usage failed: %v", err)
} }
// 异步更新订阅缓存 // 异步更新订阅缓存
...@@ -1022,7 +1040,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -1022,7 +1040,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} else { } else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if cost.ActualCost > 0 { if cost.ActualCost > 0 {
if err := s.repos.User.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
log.Printf("Deduct balance failed: %v", err) log.Printf("Deduct balance failed: %v", err)
} }
// 异步更新余额缓存 // 异步更新余额缓存
...@@ -1037,7 +1055,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -1037,7 +1055,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} }
// 更新账号最后使用时间 // 更新账号最后使用时间
if err := s.repos.Account.UpdateLastUsed(ctx, account.ID); err != nil { if err := s.accountRepo.UpdateLastUsed(ctx, account.ID); err != nil {
log.Printf("Update last used failed: %v", err) log.Printf("Update last used failed: %v", err)
} }
......
...@@ -5,7 +5,8 @@ import ( ...@@ -5,7 +5,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/repository" "sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -34,11 +35,11 @@ type UpdateGroupRequest struct { ...@@ -34,11 +35,11 @@ type UpdateGroupRequest struct {
// GroupService 分组管理服务 // GroupService 分组管理服务
type GroupService struct { type GroupService struct {
groupRepo *repository.GroupRepository groupRepo ports.GroupRepository
} }
// NewGroupService 创建分组服务实例 // NewGroupService 创建分组服务实例
func NewGroupService(groupRepo *repository.GroupRepository) *GroupService { func NewGroupService(groupRepo ports.GroupRepository) *GroupService {
return &GroupService{ return &GroupService{
groupRepo: groupRepo, groupRepo: groupRepo,
} }
...@@ -84,7 +85,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err ...@@ -84,7 +85,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err
} }
// List 获取分组列表 // List 获取分组列表
func (s *GroupService) List(ctx context.Context, params repository.PaginationParams) ([]model.Group, *repository.PaginationResult, error) { func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
groups, pagination, err := s.groupRepo.List(ctx, params) groups, pagination, err := s.groupRepo.List(ctx, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list groups: %w", err) return nil, nil, fmt.Errorf("list groups: %w", err)
......
...@@ -12,7 +12,7 @@ import ( ...@@ -12,7 +12,7 @@ import (
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/oauth" "sub2api/internal/pkg/oauth"
"sub2api/internal/repository" "sub2api/internal/service/ports"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
) )
...@@ -20,11 +20,11 @@ import ( ...@@ -20,11 +20,11 @@ import (
// OAuthService handles OAuth authentication flows // OAuthService handles OAuth authentication flows
type OAuthService struct { type OAuthService struct {
sessionStore *oauth.SessionStore sessionStore *oauth.SessionStore
proxyRepo *repository.ProxyRepository proxyRepo ports.ProxyRepository
} }
// NewOAuthService creates a new OAuth service // NewOAuthService creates a new OAuth service
func NewOAuthService(proxyRepo *repository.ProxyRepository) *OAuthService { func NewOAuthService(proxyRepo ports.ProxyRepository) *OAuthService {
return &OAuthService{ return &OAuthService{
sessionStore: oauth.NewSessionStore(), sessionStore: oauth.NewSessionStore(),
proxyRepo: proxyRepo, proxyRepo: proxyRepo,
...@@ -459,7 +459,7 @@ func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.A ...@@ -459,7 +459,7 @@ func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.A
// createReqClient creates a req client with Chrome impersonation and optional proxy // createReqClient creates a req client with Chrome impersonation and optional proxy
func (s *OAuthService) createReqClient(proxyURL string) *req.Client { func (s *OAuthService) createReqClient(proxyURL string) *req.Client {
client := req.C(). client := req.C().
ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare
SetTimeout(60 * time.Second) SetTimeout(60 * time.Second)
// Set proxy if specified // Set proxy if specified
......
package ports
import (
"context"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type AccountRepository interface {
Create(ctx context.Context, account *model.Account) error
GetByID(ctx context.Context, id int64) (*model.Account, error)
Update(ctx context.Context, account *model.Account) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error)
ListActive(ctx context.Context) ([]model.Account, error)
ListByPlatform(ctx context.Context, platform string) ([]model.Account, error)
UpdateLastUsed(ctx context.Context, id int64) error
SetError(ctx context.Context, id int64, errorMsg string) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]model.Account, error)
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
ClearRateLimit(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
}
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type ApiKeyRepository interface {
Create(ctx context.Context, key *model.ApiKey) error
GetByID(ctx context.Context, id int64) (*model.ApiKey, error)
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
Update(ctx context.Context, key *model.ApiKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
}
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type GroupRepository interface {
Create(ctx context.Context, group *model.Group) error
GetByID(ctx context.Context, id int64) (*model.Group, error)
Update(ctx context.Context, group *model.Group) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Group, error)
ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error)
ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
DB() *gorm.DB
}
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type ProxyRepository interface {
Create(ctx context.Context, proxy *model.Proxy) error
GetByID(ctx context.Context, id int64) (*model.Proxy, error)
Update(ctx context.Context, proxy *model.Proxy) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Proxy, error)
ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
}
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type RedeemCodeRepository interface {
Create(ctx context.Context, code *model.RedeemCode) error
CreateBatch(ctx context.Context, codes []model.RedeemCode) error
GetByID(ctx context.Context, id int64) (*model.RedeemCode, error)
GetByCode(ctx context.Context, code string) (*model.RedeemCode, error)
Update(ctx context.Context, code *model.RedeemCode) error
Delete(ctx context.Context, id int64) error
Use(ctx context.Context, id, userID int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error)
ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error)
}
package ports
import (
"context"
"sub2api/internal/model"
)
type SettingRepository interface {
Get(ctx context.Context, key string) (*model.Setting, error)
GetValue(ctx context.Context, key string) (string, error)
Set(ctx context.Context, key, value string) error
GetMultiple(ctx context.Context, keys []string) (map[string]string, error)
SetMultiple(ctx context.Context, settings map[string]string) error
GetAll(ctx context.Context) (map[string]string, error)
Delete(ctx context.Context, key string) error
}
package ports
import (
"context"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/usagestats"
)
type UsageLogRepository interface {
Create(ctx context.Context, log *model.UsageLog) error
GetByID(ctx context.Context, id int64) (*model.UsageLog, error)
Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
}
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type UserRepository interface {
Create(ctx context.Context, user *model.User) error
GetByID(ctx context.Context, id int64) (*model.User, error)
GetByEmail(ctx context.Context, email string) (*model.User, error)
Update(ctx context.Context, user *model.User) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error)
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
UpdateConcurrency(ctx context.Context, id int64, amount int) error
ExistsByEmail(ctx context.Context, email string) (bool, error)
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
}
package ports
import (
"context"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type UserSubscriptionRepository interface {
Create(ctx context.Context, sub *model.UserSubscription) error
GetByID(ctx context.Context, id int64) (*model.UserSubscription, error)
GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
Update(ctx context.Context, sub *model.UserSubscription) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error)
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error
UpdateStatus(ctx context.Context, subscriptionID int64, status string) error
UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error
ActivateWindows(ctx context.Context, id int64, start time.Time) error
ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
IncrementUsage(ctx context.Context, id int64, costUSD float64) error
BatchUpdateExpiredStatus(ctx context.Context) (int64, error)
}
...@@ -5,7 +5,8 @@ import ( ...@@ -5,7 +5,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/repository" "sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -37,11 +38,11 @@ type UpdateProxyRequest struct { ...@@ -37,11 +38,11 @@ type UpdateProxyRequest struct {
// ProxyService 代理管理服务 // ProxyService 代理管理服务
type ProxyService struct { type ProxyService struct {
proxyRepo *repository.ProxyRepository proxyRepo ports.ProxyRepository
} }
// NewProxyService 创建代理服务实例 // NewProxyService 创建代理服务实例
func NewProxyService(proxyRepo *repository.ProxyRepository) *ProxyService { func NewProxyService(proxyRepo ports.ProxyRepository) *ProxyService {
return &ProxyService{ return &ProxyService{
proxyRepo: proxyRepo, proxyRepo: proxyRepo,
} }
...@@ -80,7 +81,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err ...@@ -80,7 +81,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err
} }
// List 获取代理列表 // List 获取代理列表
func (s *ProxyService) List(ctx context.Context, params repository.PaginationParams) ([]model.Proxy, *repository.PaginationResult, error) { func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
proxies, pagination, err := s.proxyRepo.List(ctx, params) proxies, pagination, err := s.proxyRepo.List(ctx, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list proxies: %w", err) return nil, nil, fmt.Errorf("list proxies: %w", err)
......
...@@ -9,20 +9,20 @@ import ( ...@@ -9,20 +9,20 @@ import (
"sub2api/internal/config" "sub2api/internal/config"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/repository" "sub2api/internal/service/ports"
) )
// RateLimitService 处理限流和过载状态管理 // RateLimitService 处理限流和过载状态管理
type RateLimitService struct { type RateLimitService struct {
repos *repository.Repositories accountRepo ports.AccountRepository
cfg *config.Config cfg *config.Config
} }
// NewRateLimitService 创建RateLimitService实例 // NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(repos *repository.Repositories, cfg *config.Config) *RateLimitService { func NewRateLimitService(accountRepo ports.AccountRepository, cfg *config.Config) *RateLimitService {
return &RateLimitService{ return &RateLimitService{
repos: repos, accountRepo: accountRepo,
cfg: cfg, cfg: cfg,
} }
} }
...@@ -62,7 +62,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *mod ...@@ -62,7 +62,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *mod
// handleAuthError 处理认证类错误(401/403),停止账号调度 // handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) { func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) {
if err := s.repos.Account.SetError(ctx, account.ID, errorMsg); err != nil { if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err) log.Printf("SetError failed for account %d: %v", account.ID, err)
return return
} }
...@@ -77,7 +77,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account ...@@ -77,7 +77,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
if resetTimestamp == "" { if resetTimestamp == "" {
// 没有重置时间,使用默认5分钟 // 没有重置时间,使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute) resetAt := time.Now().Add(5 * time.Minute)
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil { if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
} }
return return
...@@ -88,7 +88,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account ...@@ -88,7 +88,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
if err != nil { if err != nil {
log.Printf("Parse reset timestamp failed: %v", err) log.Printf("Parse reset timestamp failed: %v", err)
resetAt := time.Now().Add(5 * time.Minute) resetAt := time.Now().Add(5 * time.Minute)
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil { if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
} }
return return
...@@ -97,7 +97,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account ...@@ -97,7 +97,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
resetAt := time.Unix(ts, 0) resetAt := time.Unix(ts, 0)
// 标记限流状态 // 标记限流状态
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil { if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
return return
} }
...@@ -105,7 +105,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account ...@@ -105,7 +105,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
// 根据重置时间反推5h窗口 // 根据重置时间反推5h窗口
windowEnd := resetAt windowEnd := resetAt
windowStart := resetAt.Add(-5 * time.Hour) windowStart := resetAt.Add(-5 * time.Hour)
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil { if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err) log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
} }
...@@ -121,7 +121,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *model.Account ...@@ -121,7 +121,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *model.Account
} }
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.repos.Account.SetOverloaded(ctx, account.ID, until); err != nil { if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err) log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
return return
} }
...@@ -152,13 +152,13 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod ...@@ -152,13 +152,13 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status) log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
} }
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil { if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err) log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
} }
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态 // 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
if status == "allowed" && account.IsRateLimited() { if status == "allowed" && account.IsRateLimited() {
if err := s.repos.Account.ClearRateLimit(ctx, account.ID); err != nil { if err := s.accountRepo.ClearRateLimit(ctx, account.ID); err != nil {
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err) log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
} }
} }
...@@ -166,5 +166,5 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod ...@@ -166,5 +166,5 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod
// ClearRateLimit 清除账号的限流状态 // ClearRateLimit 清除账号的限流状态
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error { func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
return s.repos.Account.ClearRateLimit(ctx, accountID) return s.accountRepo.ClearRateLimit(ctx, accountID)
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment