Commit e5a77853 authored by Forest's avatar Forest
Browse files

refactor: 调整项目结构为单向依赖

parent b3463769
......@@ -6,7 +6,6 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
......@@ -15,29 +14,29 @@ var (
)
type AccountRepository interface {
Create(ctx context.Context, account *model.Account) error
GetByID(ctx context.Context, id int64) (*model.Account, error)
Create(ctx context.Context, account *Account) error
GetByID(ctx context.Context, id int64) (*Account, error)
// GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error)
Update(ctx context.Context, account *model.Account) error
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
Update(ctx context.Context, account *Account) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error)
ListActive(ctx context.Context) ([]model.Account, error)
ListByPlatform(ctx context.Context, platform string) ([]model.Account, error)
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
UpdateLastUsed(ctx context.Context, id int64) error
SetError(ctx context.Context, id int64, errorMsg string) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]model.Account, error)
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error)
ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error)
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error)
ListSchedulable(ctx context.Context) ([]Account, error)
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
......@@ -99,7 +98,7 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository)
}
// Create 创建账号
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*model.Account, error) {
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
// 验证分组是否存在(如果指定了分组)
if len(req.GroupIDs) > 0 {
for _, groupID := range req.GroupIDs {
......@@ -111,7 +110,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
}
// 创建账号
account := &model.Account{
account := &Account{
Name: req.Name,
Platform: req.Platform,
Type: req.Type,
......@@ -120,7 +119,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: model.StatusActive,
Status: StatusActive,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
......@@ -138,7 +137,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
}
// GetByID 根据ID获取账号
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) {
func (s *AccountService) GetByID(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get account: %w", err)
......@@ -147,7 +146,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account,
}
// List 获取账号列表
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
accounts, pagination, err := s.accountRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list accounts: %w", err)
......@@ -156,7 +155,7 @@ func (s *AccountService) List(ctx context.Context, params pagination.PaginationP
}
// ListByPlatform 根据平台获取账号列表
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
accounts, err := s.accountRepo.ListByPlatform(ctx, platform)
if err != nil {
return nil, fmt.Errorf("list accounts by platform: %w", err)
......@@ -165,7 +164,7 @@ func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([
}
// ListByGroup 根据分组获取账号列表
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("list accounts by group: %w", err)
......@@ -174,7 +173,7 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode
}
// Update 更新账号
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) {
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get account: %w", err)
......@@ -290,13 +289,13 @@ func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
// 根据平台执行不同的测试逻辑
switch account.Platform {
case model.PlatformAnthropic:
case PlatformAnthropic:
// TODO: 测试Anthropic API凭证
return nil
case model.PlatformOpenAI:
case PlatformOpenAI:
// TODO: 测试OpenAI API凭证
return nil
case model.PlatformGemini:
case PlatformGemini:
// TODO: 测试Gemini API凭证
return nil
default:
......
......@@ -14,7 +14,6 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/gin-gonic/gin"
......@@ -127,7 +126,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
}
// testClaudeAccountConnection tests an Anthropic Claude account's connection
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context()
// Determine the model to use
......@@ -254,7 +253,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
}
// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context()
// Default to openai.DefaultTestModel for OpenAI testing
......
......@@ -7,24 +7,23 @@ import (
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
type UsageLogRepository interface {
Create(ctx context.Context, log *model.UsageLog) error
GetByID(ctx context.Context, id int64) (*model.UsageLog, error)
Create(ctx context.Context, log *UsageLog) error
GetByID(ctx context.Context, id int64) (*UsageLog, error)
Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
......@@ -44,7 +43,7 @@ type UsageLogRepository interface {
GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error)
// Admin usage listing/stats
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error)
GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error)
// Account stats
......@@ -163,7 +162,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
}
// Setup Token账号:根据session_window推算(没有profile scope,无法调用usage API)
if account.Type == model.AccountTypeSetupToken {
if account.Type == AccountTypeSetupToken {
usage := s.estimateSetupTokenUsage(account)
// 添加窗口统计
s.addWindowStats(ctx, account, usage)
......@@ -175,7 +174,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
}
// addWindowStats 为usage数据添加窗口期统计
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model.Account, usage *UsageInfo) {
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
if usage.FiveHour == nil {
return
}
......@@ -225,7 +224,7 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
}
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return nil, fmt.Errorf("no access token available")
......@@ -320,7 +319,7 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
}
// estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量
func (s *AccountUsageService) estimateSetupTokenUsage(account *model.Account) *UsageInfo {
func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageInfo {
info := &UsageInfo{}
// 如果有session_window信息
......
......@@ -7,62 +7,61 @@ import (
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// AdminService interface defines admin management operations
type AdminService interface {
// User management
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error)
GetUser(ctx context.Context, id int64) (*model.User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error)
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error)
GetAllGroups(ctx context.Context) ([]model.Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error)
GetGroup(ctx context.Context, id int64) (*model.Group, error)
CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error)
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*Group, error)
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error)
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error)
GetAccount(ctx context.Context, id int64) (*model.Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error)
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
DeleteAccount(ctx context.Context, id int64) error
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error)
ClearAccountError(ctx context.Context, id int64) (*model.Account, error)
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error)
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*Account, error)
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
// Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error)
GetAllProxies(ctx context.Context) ([]model.Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*model.Proxy, error)
CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error)
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*Proxy, error)
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
DeleteProxy(ctx context.Context, id int64) error
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error)
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error)
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
// Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error)
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error)
DeleteRedeemCode(ctx context.Context, id int64) error
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error)
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
}
// Input types for admin operations
......@@ -252,7 +251,7 @@ func NewAdminService(
}
// User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
if err != nil {
......@@ -261,20 +260,21 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, st
return users, result.Total, nil
}
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*model.User, error) {
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
return s.userRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) {
user := &model.User{
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
user := &User{
Email: input.Email,
Username: input.Username,
Wechat: input.Wechat,
Notes: input.Notes,
Role: "user", // Always create as regular user, never admin
Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
Status: model.StatusActive,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
}
if err := user.SetPassword(input.Password); err != nil {
return nil, err
......@@ -285,7 +285,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
return user, nil
}
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error) {
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
......@@ -335,16 +335,16 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 {
code, err := model.GenerateRedeemCode()
code, err := GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &model.RedeemCode{
adjustmentRecord := &RedeemCode{
Code: code,
Type: model.AdjustmentTypeAdminConcurrency,
Type: AdjustmentTypeAdminConcurrency,
Value: float64(concurrencyDiff),
Status: model.StatusUsed,
Status: StatusUsed,
UsedBy: &user.ID,
}
now := time.Now()
......@@ -369,7 +369,7 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
return s.userRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) {
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, err
......@@ -406,17 +406,17 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
code, err := model.GenerateRedeemCode()
code, err := GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &model.RedeemCode{
adjustmentRecord := &RedeemCode{
Code: code,
Type: model.AdjustmentTypeAdminBalance,
Type: AdjustmentTypeAdminBalance,
Value: balanceDiff,
Status: model.StatusUsed,
Status: StatusUsed,
UsedBy: &user.ID,
Notes: notes,
}
......@@ -431,7 +431,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil
}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
......@@ -452,7 +452,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
}
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
if err != nil {
......@@ -461,36 +461,36 @@ func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, p
return groups, result.Total, nil
}
func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]model.Group, error) {
func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) {
return s.groupRepo.ListActive(ctx)
}
func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) {
return s.groupRepo.ListActiveByPlatform(ctx, platform)
}
func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*model.Group, error) {
func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) {
return s.groupRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error) {
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
platform := input.Platform
if platform == "" {
platform = model.PlatformAnthropic
platform = PlatformAnthropic
}
subscriptionType := input.SubscriptionType
if subscriptionType == "" {
subscriptionType = model.SubscriptionTypeStandard
subscriptionType = SubscriptionTypeStandard
}
group := &model.Group{
group := &Group{
Name: input.Name,
Description: input.Description,
Platform: platform,
RateMultiplier: input.RateMultiplier,
IsExclusive: input.IsExclusive,
Status: model.StatusActive,
Status: StatusActive,
SubscriptionType: subscriptionType,
DailyLimitUSD: input.DailyLimitUSD,
WeeklyLimitUSD: input.WeeklyLimitUSD,
......@@ -502,7 +502,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return group, nil
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error) {
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return nil, err
......@@ -571,7 +571,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
return nil
}
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil {
......@@ -581,7 +581,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
}
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
if err != nil {
......@@ -590,21 +590,21 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
return accounts, result.Total, nil
}
func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*model.Account, error) {
func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, error) {
return s.accountRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error) {
account := &model.Account{
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
account := &Account{
Name: input.Name,
Platform: input.Platform,
Type: input.Type,
Credentials: model.JSONB(input.Credentials),
Extra: model.JSONB(input.Extra),
Credentials: input.Credentials,
Extra: input.Extra,
ProxyID: input.ProxyID,
Concurrency: input.Concurrency,
Priority: input.Priority,
Status: model.StatusActive,
Status: StatusActive,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
......@@ -618,7 +618,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
return account, nil
}
func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error) {
func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
......@@ -631,10 +631,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Type = input.Type
}
if len(input.Credentials) > 0 {
account.Credentials = model.JSONB(input.Credentials)
account.Credentials = input.Credentials
}
if len(input.Extra) > 0 {
account.Extra = model.JSONB(input.Extra)
account.Extra = input.Extra
}
if input.ProxyID != nil {
account.ProxyID = input.ProxyID
......@@ -730,7 +730,7 @@ func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error) {
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
......@@ -739,12 +739,12 @@ func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int
return account, nil
}
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*model.Account, error) {
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
account.Status = model.StatusActive
account.Status = StatusActive
account.ErrorMessage = ""
if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, err
......@@ -752,7 +752,7 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*mo
return account, nil
}
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error) {
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err
}
......@@ -760,7 +760,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
}
// Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil {
......@@ -769,27 +769,27 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
return proxies, result.Total, nil
}
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]model.Proxy, error) {
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
return s.proxyRepo.ListActive(ctx)
}
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
return s.proxyRepo.ListActiveWithAccountCount(ctx)
}
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*model.Proxy, error) {
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
return s.proxyRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error) {
proxy := &model.Proxy{
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) {
proxy := &Proxy{
Name: input.Name,
Protocol: input.Protocol,
Host: input.Host,
Port: input.Port,
Username: input.Username,
Password: input.Password,
Status: model.StatusActive,
Status: StatusActive,
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err
......@@ -797,7 +797,7 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
return proxy, nil
}
func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error) {
func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
return nil, err
......@@ -835,9 +835,9 @@ func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
return s.proxyRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error) {
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) {
// Return mock data for now - would need a dedicated repository method
return []model.Account{}, 0, nil
return []Account{}, 0, nil
}
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
......@@ -845,7 +845,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
}
// Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil {
......@@ -854,13 +854,13 @@ func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize i
return codes, result.Total, nil
}
func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) {
func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
return s.redeemCodeRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error) {
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) {
// 如果是订阅类型,验证必须有 GroupID
if input.Type == model.RedeemTypeSubscription {
if input.Type == RedeemTypeSubscription {
if input.GroupID == nil {
return nil, errors.New("group_id is required for subscription type")
}
......@@ -874,20 +874,20 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
}
}
codes := make([]model.RedeemCode, 0, input.Count)
codes := make([]RedeemCode, 0, input.Count)
for i := 0; i < input.Count; i++ {
codeValue, err := model.GenerateRedeemCode()
codeValue, err := GenerateRedeemCode()
if err != nil {
return nil, err
}
code := model.RedeemCode{
code := RedeemCode{
Code: codeValue,
Type: input.Type,
Value: input.Value,
Status: model.StatusUnused,
Status: StatusUnused,
}
// 订阅类型专用字段
if input.Type == model.RedeemTypeSubscription {
if input.Type == RedeemTypeSubscription {
code.GroupID = input.GroupID
code.ValidityDays = input.ValidityDays
if code.ValidityDays <= 0 {
......@@ -916,12 +916,12 @@ func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int
return deleted, nil
}
func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) {
func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
code, err := s.redeemCodeRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
code.Status = model.StatusExpired
code.Status = StatusExpired
if err := s.redeemCodeRepo.Update(ctx, code); err != nil {
return nil, err
}
......
package service
import "time"
type ApiKey struct {
ID int64
UserID int64
Key string
Name string
GroupID *int64
Status string
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
}
func (k *ApiKey) IsActive() bool {
return k.Status == StatusActive
}
......@@ -10,7 +10,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/redis/go-redis/v9"
......@@ -30,17 +29,17 @@ const (
)
type ApiKeyRepository interface {
Create(ctx context.Context, key *model.ApiKey) error
GetByID(ctx context.Context, id int64) (*model.ApiKey, error)
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
Update(ctx context.Context, key *model.ApiKey) error
Create(ctx context.Context, key *ApiKey) error
GetByID(ctx context.Context, id int64) (*ApiKey, error)
GetByKey(ctx context.Context, key string) (*ApiKey, error)
Update(ctx context.Context, key *ApiKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
}
......@@ -168,7 +167,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
// canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User, group *model.Group) bool {
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
......@@ -179,7 +178,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User,
}
// Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*model.ApiKey, error) {
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
......@@ -235,12 +234,12 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
// 创建API Key记录
apiKey := &model.ApiKey{
apiKey := &ApiKey{
UserID: userID,
Key: key,
Name: req.Name,
GroupID: req.GroupID,
Status: model.StatusActive,
Status: StatusActive,
}
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
......@@ -251,7 +250,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
// List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
......@@ -260,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
}
// GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
......@@ -269,7 +268,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, e
}
// GetByKey 根据Key字符串获取API Key(用于认证)
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
......@@ -289,7 +288,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
}
// Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) {
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
......@@ -364,7 +363,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
}
// ValidateKey 验证API Key是否有效(用于认证中间件)
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.ApiKey, *model.User, error) {
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
// 获取API Key
apiKey, err := s.GetByKey(ctx, key)
if err != nil {
......@@ -408,7 +407,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]model.Group, error) {
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
......@@ -434,7 +433,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
}
// 过滤出用户有权限的分组
availableGroups := make([]model.Group, 0)
availableGroups := make([]Group, 0)
for _, group := range allGroups {
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
availableGroups = append(availableGroups, group)
......@@ -445,7 +444,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
}
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.Group, subscribedGroupIDs map[int64]bool) bool {
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID]
......@@ -454,7 +453,7 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.
return user.CanBindGroup(group.ID, group.IsExclusive)
}
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
if err != nil {
return nil, fmt.Errorf("search api keys: %w", err)
......
......@@ -9,7 +9,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
......@@ -64,12 +63,12 @@ func NewAuthService(
}
// Register 用户注册,返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) {
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "")
}
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *model.User, error) {
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
// 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
......@@ -113,13 +112,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
}
// 创建用户
user := &model.User{
user := &User{
Email: email,
PasswordHash: hashedPassword,
Role: model.RoleUser,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Status: model.StatusActive,
Status: StatusActive,
}
if err := s.userRepo.Create(ctx, user); err != nil {
......@@ -251,7 +250,7 @@ func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool {
}
// Login 用户登录,返回JWT token
func (s *AuthService) Login(ctx context.Context, email, password string) (string, *model.User, error) {
func (s *AuthService) Login(ctx context.Context, email, password string) (string, *User, error) {
// 查找用户
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
......@@ -307,7 +306,7 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
}
// GenerateToken 生成JWT token
func (s *AuthService) GenerateToken(user *model.User) (string, error) {
func (s *AuthService) GenerateToken(user *User) (string, error) {
now := time.Now()
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
......
......@@ -7,7 +7,6 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// 错误定义
......@@ -224,7 +223,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *model.User, apiKey *model.ApiKey, group *model.Group, subscription *model.UserSubscription) error {
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
// 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
......@@ -252,7 +251,7 @@ func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userI
}
// checkSubscriptionEligibility 检查订阅模式资格
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *model.Group, subscription *model.UserSubscription) error {
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *Group, subscription *UserSubscription) error {
// 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil {
......@@ -262,7 +261,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
}
// 检查订阅状态
if subData.Status != model.SubscriptionStatusActive {
if subData.Status != SubscriptionStatusActive {
return ErrSubscriptionInvalid
}
......@@ -288,7 +287,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
}
// checkSubscriptionLimitsFallback 降级检查订阅限额
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *model.UserSubscription, group *model.Group) error {
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
if subscription == nil {
return ErrSubscriptionInvalid
}
......
......@@ -11,8 +11,6 @@ import (
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
)
type CRSSyncService struct {
......@@ -180,7 +178,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
),
}
var proxies []model.Proxy
var proxies []Proxy
if input.SyncProxies {
proxies, _ = s.proxyRepo.ListActive(ctx)
}
......@@ -197,7 +195,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
if targetType == "" {
targetType = "oauth"
}
if targetType != model.AccountTypeOAuth && targetType != model.AccountTypeSetupToken {
if targetType != AccountTypeOAuth && targetType != AccountTypeSetupToken {
item.Action = "skipped"
item.Error = "unsupported authType: " + targetType
result.Skipped++
......@@ -268,12 +266,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
account := &model.Account{
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformAnthropic,
Platform: PlatformAnthropic,
Type: targetType,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
......@@ -288,7 +286,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue
}
// 🔄 Refresh OAuth token after creation
if targetType == model.AccountTypeOAuth {
if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, account)
......@@ -301,11 +299,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
// Update existing
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformAnthropic
existing.Platform = PlatformAnthropic
existing.Type = targetType
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
......@@ -323,7 +321,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
// 🔄 Refresh OAuth token after update
if targetType == model.AccountTypeOAuth {
if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, existing)
......@@ -385,12 +383,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
account := &model.Account{
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformAnthropic,
Type: model.AccountTypeApiKey,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
Platform: PlatformAnthropic,
Type: AccountTypeApiKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
......@@ -410,11 +408,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformAnthropic
existing.Type = model.AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
existing.Platform = PlatformAnthropic
existing.Type = AccountTypeApiKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
......@@ -508,12 +506,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
account := &model.Account{
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformOpenAI,
Type: model.AccountTypeOAuth,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
......@@ -538,11 +536,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformOpenAI
existing.Type = model.AccountTypeOAuth
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
existing.Platform = PlatformOpenAI
existing.Type = AccountTypeOAuth
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
......@@ -629,12 +627,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
account := &model.Account{
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformOpenAI,
Type: model.AccountTypeApiKey,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
Platform: PlatformOpenAI,
Type: AccountTypeApiKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
......@@ -654,11 +652,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformOpenAI
existing.Type = model.AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
existing.Platform = PlatformOpenAI
existing.Type = AccountTypeApiKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
......@@ -683,9 +681,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
return result, nil
}
// mergeJSONB merges two JSONB maps without removing keys that are absent in updates.
func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB {
out := make(model.JSONB)
func mergeMap(existing map[string]any, updates map[string]any) map[string]any {
out := make(map[string]any, len(existing)+len(updates))
for k, v := range existing {
out[k] = v
}
......@@ -695,7 +692,7 @@ func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB {
return out
}
func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]model.Proxy, src *crsProxy, defaultName string) (*int64, error) {
func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]Proxy, src *crsProxy, defaultName string) (*int64, error) {
if !enabled || src == nil {
return nil, nil
}
......@@ -731,14 +728,14 @@ func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cac
}
// Create new proxy
proxy := &model.Proxy{
proxy := &Proxy{
Name: defaultProxyName(defaultName, protocol, host, port),
Protocol: protocol,
Host: host,
Port: port,
Username: username,
Password: password,
Status: model.StatusActive,
Status: StatusActive,
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err
......@@ -897,8 +894,8 @@ func crsExportAccounts(ctx context.Context, client *http.Client, baseURL, adminT
// refreshOAuthToken attempts to refresh OAuth token for a synced account
// Returns updated credentials or nil if refresh failed/not applicable
func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.Account) model.JSONB {
if account.Type != model.AccountTypeOAuth {
func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account) map[string]any {
if account.Type != AccountTypeOAuth {
return nil
}
......@@ -906,7 +903,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
var err error
switch account.Platform {
case model.PlatformAnthropic:
case PlatformAnthropic:
if s.oauthService == nil {
return nil
}
......@@ -931,7 +928,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
newCredentials["scope"] = tokenInfo.Scope
}
}
case model.PlatformOpenAI:
case PlatformOpenAI:
if s.openaiOAuthService == nil {
return nil
}
......@@ -956,5 +953,5 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
return nil
}
return model.JSONB(newCredentials)
return newCredentials
}
package model
package service
import (
"time"
// Status constants
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// Role constants
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// Platform constants
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
)
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
AccountTypeApiKey = "apikey" // API Key类型账号
)
// Setting 系统设置模型(Key-Value存储)
type Setting struct {
ID int64 `gorm:"primaryKey" json:"id"`
Key string `gorm:"uniqueIndex;size:100;not null" json:"key"`
Value string `gorm:"type:text;not null" json:"value"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
}
// Redeem type constants
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
)
func (Setting) TableName() string {
return "settings"
}
// Admin adjustment type constants
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)
// 设置Key常量
// Group subscription type constants
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
// Subscription status constants
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
)
// Setting keys
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
......@@ -52,53 +92,5 @@ const (
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
)
// 管理员 API Key 前缀(与用户 sk- 前缀区分)
// Admin API Key prefix (distinct from user "sk-" keys)
const AdminApiKeyPrefix = "admin-"
// SystemSettings 系统设置结构体(用于API响应)
type SystemSettings struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
// 邮件服务设置
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"` // 不返回明文密码
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
// Cloudflare Turnstile 设置
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` // 不返回明文密钥
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
// PublicSettings 公开设置(无需登录即可获取)
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
Version string `json:"version"`
}
......@@ -11,7 +11,6 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
)
var (
......@@ -69,13 +68,13 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
// GetSmtpConfig 从数据库获取SMTP配置
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
keys := []string{
model.SettingKeySmtpHost,
model.SettingKeySmtpPort,
model.SettingKeySmtpUsername,
model.SettingKeySmtpPassword,
model.SettingKeySmtpFrom,
model.SettingKeySmtpFromName,
model.SettingKeySmtpUseTLS,
SettingKeySmtpHost,
SettingKeySmtpPort,
SettingKeySmtpUsername,
SettingKeySmtpPassword,
SettingKeySmtpFrom,
SettingKeySmtpFromName,
SettingKeySmtpUseTLS,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
......@@ -83,27 +82,27 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err)
}
host := settings[model.SettingKeySmtpHost]
host := settings[SettingKeySmtpHost]
if host == "" {
return nil, ErrEmailNotConfigured
}
port := 587 // 默认端口
if portStr := settings[model.SettingKeySmtpPort]; portStr != "" {
if portStr := settings[SettingKeySmtpPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
}
}
useTLS := settings[model.SettingKeySmtpUseTLS] == "true"
useTLS := settings[SettingKeySmtpUseTLS] == "true"
return &SmtpConfig{
Host: host,
Port: port,
Username: settings[model.SettingKeySmtpUsername],
Password: settings[model.SettingKeySmtpPassword],
From: settings[model.SettingKeySmtpFrom],
FromName: settings[model.SettingKeySmtpFromName],
Username: settings[SettingKeySmtpUsername],
Password: settings[SettingKeySmtpPassword],
From: settings[SettingKeySmtpFrom],
FromName: settings[SettingKeySmtpFromName],
UseTLS: useTLS,
}, nil
}
......
......@@ -17,7 +17,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
......@@ -265,12 +264,12 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
}
// SelectAccount 选择账号(粘性会话+优先级)
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) {
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
}
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
......@@ -289,19 +288,19 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
}
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台)
var accounts []model.Account
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformAnthropic)
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformAnthropic)
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 3. 按优先级+最久未用选择(考虑模型支持)
var selected *model.Account
var selected *Account
for i := range accounts {
acc := &accounts[i]
// 检查模型支持
......@@ -341,12 +340,12 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
}
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) {
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
case model.AccountTypeOAuth, model.AccountTypeSetupToken:
case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account)
case model.AccountTypeApiKey:
case AccountTypeApiKey:
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
......@@ -357,7 +356,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco
}
}
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) {
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
......@@ -372,10 +371,7 @@ const (
retryDelay = 3 * time.Second // 重试等待时间
)
// shouldRetryUpstreamError 判断是否应该重试上游错误
// OAuth/Setup Token 账号:仅 403 重试
// API Key 账号:未配置的错误码重试
func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, statusCode int) bool {
func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
// OAuth/Setup Token 账号:仅 403 重试
if account.IsOAuth() {
return statusCode == 403
......@@ -386,7 +382,7 @@ func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, status
}
// Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) {
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
// 解析请求获取model和stream
......@@ -412,7 +408,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
// 应用模型映射(仅对apikey类型账号)
originalModel := req.Model
if account.Type == model.AccountTypeApiKey {
if account.Type == AccountTypeApiKey {
mappedModel := account.GetMappedModel(req.Model)
if mappedModel != req.Model {
// 替换请求体中的模型名
......@@ -504,10 +500,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
}, nil
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == model.AccountTypeApiKey {
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
}
......@@ -631,7 +627,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
return claude.DefaultBetaHeader
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
// 处理上游错误,标记账号状态
......@@ -686,7 +682,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
// handleRetryExhaustedError 处理重试耗尽后的错误
// OAuth 403:标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
statusCode := resp.StatusCode
......@@ -717,7 +713,7 @@ type streamingResult struct {
firstTokenMs *int
}
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
......@@ -856,7 +852,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
}
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
......@@ -915,10 +911,10 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
ApiKey *model.ApiKey
User *model.User
Account *model.Account
Subscription *model.UserSubscription // 可选:订阅信息
ApiKey *ApiKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
......@@ -952,14 +948,14 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := model.BillingTypeBalance
billingType := BillingTypeBalance
if isSubscriptionBilling {
billingType = model.BillingTypeSubscription
billingType = BillingTypeSubscription
}
// 创建使用日志
durationMs := int(result.Duration.Milliseconds())
usageLog := &model.UsageLog{
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
......@@ -1038,9 +1034,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *model.Account, body []byte) error {
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == model.AccountTypeApiKey {
if account.Type == AccountTypeApiKey {
var req struct {
Model string `json:"model"`
}
......@@ -1113,10 +1109,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == model.AccountTypeApiKey {
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}
......
package service
import "time"
type Group struct {
ID int64
Name string
Description string
Platform string
RateMultiplier float64
IsExclusive bool
Status string
SubscriptionType string
DailyLimitUSD *float64
WeeklyLimitUSD *float64
MonthlyLimitUSD *float64
CreatedAt time.Time
UpdatedAt time.Time
AccountGroups []AccountGroup
AccountCount int64
}
func (g *Group) IsActive() bool {
return g.Status == StatusActive
}
func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
func (g *Group) IsFreeSubscription() bool {
return g.IsSubscriptionType() && g.RateMultiplier == 0
}
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
func (g *Group) HasWeeklyLimit() bool {
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
}
func (g *Group) HasMonthlyLimit() bool {
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
}
......@@ -5,7 +5,6 @@ import (
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
......@@ -15,16 +14,16 @@ var (
)
type GroupRepository interface {
Create(ctx context.Context, group *model.Group) error
GetByID(ctx context.Context, id int64) (*model.Group, error)
Update(ctx context.Context, group *model.Group) error
Create(ctx context.Context, group *Group) error
GetByID(ctx context.Context, id int64) (*Group, error)
Update(ctx context.Context, group *Group) error
Delete(ctx context.Context, id int64) error
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Group, error)
ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error)
List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]Group, error)
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
......@@ -61,7 +60,7 @@ func NewGroupService(groupRepo GroupRepository) *GroupService {
}
// Create 创建分组
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*model.Group, error) {
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) {
// 检查名称是否已存在
exists, err := s.groupRepo.ExistsByName(ctx, req.Name)
if err != nil {
......@@ -72,12 +71,14 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
}
// 创建分组
group := &model.Group{
group := &Group{
Name: req.Name,
Description: req.Description,
Platform: PlatformAnthropic,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: model.StatusActive,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
......@@ -88,7 +89,7 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
}
// GetByID 根据ID获取分组
func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) {
func (s *GroupService) GetByID(ctx context.Context, id int64) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
......@@ -97,7 +98,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err
}
// List 获取分组列表
func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
groups, pagination, err := s.groupRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list groups: %w", err)
......@@ -106,7 +107,7 @@ func (s *GroupService) List(ctx context.Context, params pagination.PaginationPar
}
// ListActive 获取活跃分组列表
func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
func (s *GroupService) ListActive(ctx context.Context) ([]Group, error) {
groups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active groups: %w", err)
......@@ -115,7 +116,7 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
}
// Update 更新分组
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) {
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
......
......@@ -6,7 +6,6 @@ import (
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
......@@ -274,7 +273,7 @@ func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, pr
}
// RefreshAccountToken refreshes token for an account
func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*TokenInfo, error) {
func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" {
return nil, fmt.Errorf("no refresh token available")
......
......@@ -16,7 +16,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/gin-gonic/gin"
)
......@@ -119,12 +118,12 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
}
// SelectAccount selects an OpenAI account with sticky session support
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) {
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
}
// SelectAccountForModel selects an account supporting the requested model
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
// 1. Check sticky session
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
......@@ -139,19 +138,19 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
}
// 2. Get schedulable OpenAI accounts
var accounts []model.Account
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformOpenAI)
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformOpenAI)
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 3. Select by priority + LRU
var selected *model.Account
var selected *Account
for i := range accounts {
acc := &accounts[i]
// Check model support
......@@ -189,15 +188,15 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
}
// GetAccessToken gets the access token for an OpenAI account
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) {
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
case model.AccountTypeOAuth:
case AccountTypeOAuth:
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
}
return accessToken, "oauth", nil
case model.AccountTypeApiKey:
case AccountTypeApiKey:
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
......@@ -209,7 +208,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *mode
}
// Forward forwards request to OpenAI API
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*OpenAIForwardResult, error) {
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
startTime := time.Now()
// Parse request body once (avoid multiple parse/serialize cycles)
......@@ -234,7 +233,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// For OAuth accounts using ChatGPT internal API, add store: false
if account.Type == model.AccountTypeOAuth {
if account.Type == AccountTypeOAuth {
reqBody["store"] = false
bodyModified = true
}
......@@ -296,7 +295,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == model.AccountTypeOAuth {
if account.Type == AccountTypeOAuth {
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
......@@ -312,14 +311,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}, nil
}
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token string, isStream bool) (*http.Request, error) {
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) {
// Determine target URL based on account type
var targetURL string
switch account.Type {
case model.AccountTypeOAuth:
case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL
case model.AccountTypeApiKey:
case AccountTypeApiKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL != "" {
......@@ -340,7 +339,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
req.Header.Set("authorization", "Bearer "+token)
// Set headers specific to OAuth accounts (ChatGPT internal API)
if account.Type == model.AccountTypeOAuth {
if account.Type == AccountTypeOAuth {
// Required: set Host for ChatGPT API (must use req.Host, not Header.Set)
req.Host = "chatgpt.com"
// Required: set chatgpt-account-id header
......@@ -380,7 +379,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
return req, nil
}
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*OpenAIForwardResult, error) {
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
// Check custom error codes
......@@ -436,7 +435,7 @@ type openaiStreamingResult struct {
firstTokenMs *int
}
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
// Set SSE response headers
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
......@@ -552,7 +551,7 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
}
}
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
......@@ -618,10 +617,10 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
ApiKey *model.ApiKey
User *model.User
Account *model.Account
Subscription *model.UserSubscription
ApiKey *ApiKey
User *User
Account *Account
Subscription *UserSubscription
}
// RecordUsage records usage and deducts balance
......@@ -660,14 +659,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Determine billing type
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := model.BillingTypeBalance
billingType := BillingTypeBalance
if isSubscriptionBilling {
billingType = model.BillingTypeSubscription
billingType = BillingTypeSubscription
}
// Create usage log
durationMs := int(result.Duration.Milliseconds())
usageLog := &model.UsageLog{
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
......
......@@ -5,7 +5,6 @@ import (
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
......@@ -200,7 +199,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
}
// RefreshAccountToken refreshes token for an OpenAI account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*OpenAITokenInfo, error) {
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if !account.IsOpenAI() {
return nil, fmt.Errorf("account is not an OpenAI account")
}
......
package service
import (
"fmt"
"time"
)
type Proxy struct {
ID int64
Name string
Protocol string
Host string
Port int
Username string
Password string
Status string
CreatedAt time.Time
UpdatedAt time.Time
}
func (p *Proxy) IsActive() bool {
return p.Status == StatusActive
}
func (p *Proxy) URL() string {
if p.Username != "" && p.Password != "" {
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
}
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
}
type ProxyWithAccountCount struct {
Proxy
AccountCount int64
}
......@@ -5,7 +5,6 @@ import (
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
......@@ -14,15 +13,15 @@ var (
)
type ProxyRepository interface {
Create(ctx context.Context, proxy *model.Proxy) error
GetByID(ctx context.Context, id int64) (*model.Proxy, error)
Update(ctx context.Context, proxy *model.Proxy) error
Create(ctx context.Context, proxy *Proxy) error
GetByID(ctx context.Context, id int64) (*Proxy, error)
Update(ctx context.Context, proxy *Proxy) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Proxy, error)
ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]Proxy, error)
ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
......@@ -62,16 +61,16 @@ func NewProxyService(proxyRepo ProxyRepository) *ProxyService {
}
// Create 创建代理
func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*model.Proxy, error) {
func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*Proxy, error) {
// 创建代理
proxy := &model.Proxy{
proxy := &Proxy{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Port: req.Port,
Username: req.Username,
Password: req.Password,
Status: model.StatusActive,
Status: StatusActive,
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
......@@ -82,7 +81,7 @@ func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*mod
}
// GetByID 根据ID获取代理
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get proxy: %w", err)
......@@ -91,7 +90,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err
}
// List 获取代理列表
func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) {
proxies, pagination, err := s.proxyRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list proxies: %w", err)
......@@ -100,7 +99,7 @@ func (s *ProxyService) List(ctx context.Context, params pagination.PaginationPar
}
// ListActive 获取活跃代理列表
func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
func (s *ProxyService) ListActive(ctx context.Context) ([]Proxy, error) {
proxies, err := s.proxyRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active proxies: %w", err)
......@@ -109,7 +108,7 @@ func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
}
// Update 更新代理
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) {
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get proxy: %w", err)
......
......@@ -8,7 +8,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// RateLimitService 处理限流和过载状态管理
......@@ -27,7 +26,7 @@ func NewRateLimitService(accountRepo AccountRepository, cfg *config.Config) *Rat
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
// apikey 类型账号:检查自定义错误码配置
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
if !account.ShouldHandleErrorCode(statusCode) {
......@@ -60,7 +59,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *mod
}
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) {
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err)
return
......@@ -70,7 +69,7 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.A
// handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态
func (s *RateLimitService) handle429(ctx context.Context, account *model.Account, headers http.Header) {
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header) {
// 解析重置时间戳
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
if resetTimestamp == "" {
......@@ -113,7 +112,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
// handle529 处理529过载错误
// 根据配置设置过载冷却时间
func (s *RateLimitService) handle529(ctx context.Context, account *model.Account) {
func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes
if cooldownMinutes <= 0 {
cooldownMinutes = 10 // 默认10分钟
......@@ -129,7 +128,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *model.Account
}
// UpdateSessionWindow 从成功响应更新5h窗口状态
func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *model.Account, headers http.Header) {
func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Account, headers http.Header) {
status := headers.Get("anthropic-ratelimit-unified-5h-status")
if status == "" {
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment