"frontend/src/i18n/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "9d698d9306d7c656c0ac48ba8bf3091dfc77f843"
Commit 22f07a7b authored by shaw's avatar shaw
Browse files

Merge PR #36: refactor: 调整项目结构为单向依赖

parents ecb2c535 e5a77853
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
...@@ -15,29 +14,29 @@ var ( ...@@ -15,29 +14,29 @@ var (
) )
type AccountRepository interface { type AccountRepository interface {
Create(ctx context.Context, account *model.Account) error Create(ctx context.Context, account *Account) error
GetByID(ctx context.Context, id int64) (*model.Account, error) GetByID(ctx context.Context, id int64) (*Account, error)
// GetByCRSAccountID finds an account previously synced from CRS. // GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found. // Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
Update(ctx context.Context, account *model.Account) error Update(ctx context.Context, account *Account) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]model.Account, error) ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error)
UpdateLastUsed(ctx context.Context, id int64) error UpdateLastUsed(ctx context.Context, id int64) error
SetError(ctx context.Context, id int64, errorMsg string) error SetError(ctx context.Context, id int64, errorMsg string) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error SetSchedulable(ctx context.Context, id int64, schedulable bool) error
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]model.Account, error) ListSchedulable(ctx context.Context) ([]Account, error)
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error
...@@ -99,7 +98,7 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) ...@@ -99,7 +98,7 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository)
} }
// Create 创建账号 // Create 创建账号
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*model.Account, error) { func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
// 验证分组是否存在(如果指定了分组) // 验证分组是否存在(如果指定了分组)
if len(req.GroupIDs) > 0 { if len(req.GroupIDs) > 0 {
for _, groupID := range req.GroupIDs { for _, groupID := range req.GroupIDs {
...@@ -111,7 +110,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( ...@@ -111,7 +110,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
} }
// 创建账号 // 创建账号
account := &model.Account{ account := &Account{
Name: req.Name, Name: req.Name,
Platform: req.Platform, Platform: req.Platform,
Type: req.Type, Type: req.Type,
...@@ -120,7 +119,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( ...@@ -120,7 +119,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
Concurrency: req.Concurrency, Concurrency: req.Concurrency,
Priority: req.Priority, Priority: req.Priority,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.accountRepo.Create(ctx, account); err != nil { if err := s.accountRepo.Create(ctx, account); err != nil {
...@@ -138,7 +137,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( ...@@ -138,7 +137,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
} }
// GetByID 根据ID获取账号 // GetByID 根据ID获取账号
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) { func (s *AccountService) GetByID(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account: %w", err) return nil, fmt.Errorf("get account: %w", err)
...@@ -147,7 +146,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, ...@@ -147,7 +146,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account,
} }
// List 获取账号列表 // List 获取账号列表
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) { func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
accounts, pagination, err := s.accountRepo.List(ctx, params) accounts, pagination, err := s.accountRepo.List(ctx, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list accounts: %w", err) return nil, nil, fmt.Errorf("list accounts: %w", err)
...@@ -156,7 +155,7 @@ func (s *AccountService) List(ctx context.Context, params pagination.PaginationP ...@@ -156,7 +155,7 @@ func (s *AccountService) List(ctx context.Context, params pagination.PaginationP
} }
// ListByPlatform 根据平台获取账号列表 // ListByPlatform 根据平台获取账号列表
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) { func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
accounts, err := s.accountRepo.ListByPlatform(ctx, platform) accounts, err := s.accountRepo.ListByPlatform(ctx, platform)
if err != nil { if err != nil {
return nil, fmt.Errorf("list accounts by platform: %w", err) return nil, fmt.Errorf("list accounts by platform: %w", err)
...@@ -165,7 +164,7 @@ func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([ ...@@ -165,7 +164,7 @@ func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([
} }
// ListByGroup 根据分组获取账号列表 // ListByGroup 根据分组获取账号列表
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) { func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
accounts, err := s.accountRepo.ListByGroup(ctx, groupID) accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("list accounts by group: %w", err) return nil, fmt.Errorf("list accounts by group: %w", err)
...@@ -174,7 +173,7 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode ...@@ -174,7 +173,7 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode
} }
// Update 更新账号 // Update 更新账号
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) { func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account: %w", err) return nil, fmt.Errorf("get account: %w", err)
...@@ -290,13 +289,13 @@ func (s *AccountService) TestCredentials(ctx context.Context, id int64) error { ...@@ -290,13 +289,13 @@ func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
// 根据平台执行不同的测试逻辑 // 根据平台执行不同的测试逻辑
switch account.Platform { switch account.Platform {
case model.PlatformAnthropic: case PlatformAnthropic:
// TODO: 测试Anthropic API凭证 // TODO: 测试Anthropic API凭证
return nil return nil
case model.PlatformOpenAI: case PlatformOpenAI:
// TODO: 测试OpenAI API凭证 // TODO: 测试OpenAI API凭证
return nil return nil
case model.PlatformGemini: case PlatformGemini:
// TODO: 测试Gemini API凭证 // TODO: 测试Gemini API凭证
return nil return nil
default: default:
......
...@@ -14,7 +14,6 @@ import ( ...@@ -14,7 +14,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -127,7 +126,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int ...@@ -127,7 +126,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
} }
// testClaudeAccountConnection tests an Anthropic Claude account's connection // testClaudeAccountConnection tests an Anthropic Claude account's connection
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *model.Account, modelID string) error { func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context() ctx := c.Request.Context()
// Determine the model to use // Determine the model to use
...@@ -254,7 +253,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account ...@@ -254,7 +253,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
} }
// testOpenAIAccountConnection tests an OpenAI account's connection // testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *model.Account, modelID string) error { func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context() ctx := c.Request.Context()
// Default to openai.DefaultTestModel for OpenAI testing // Default to openai.DefaultTestModel for OpenAI testing
......
...@@ -7,24 +7,23 @@ import ( ...@@ -7,24 +7,23 @@ import (
"sync" "sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
) )
type UsageLogRepository interface { type UsageLogRepository interface {
Create(ctx context.Context, log *model.UsageLog) error Create(ctx context.Context, log *UsageLog) error
GetByID(ctx context.Context, id int64) (*model.UsageLog, error) GetByID(ctx context.Context, id int64) (*UsageLog, error)
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
...@@ -44,7 +43,7 @@ type UsageLogRepository interface { ...@@ -44,7 +43,7 @@ type UsageLogRepository interface {
GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error)
// Admin usage listing/stats // Admin usage listing/stats
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error)
GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error)
// Account stats // Account stats
...@@ -163,7 +162,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U ...@@ -163,7 +162,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
} }
// Setup Token账号:根据session_window推算(没有profile scope,无法调用usage API) // Setup Token账号:根据session_window推算(没有profile scope,无法调用usage API)
if account.Type == model.AccountTypeSetupToken { if account.Type == AccountTypeSetupToken {
usage := s.estimateSetupTokenUsage(account) usage := s.estimateSetupTokenUsage(account)
// 添加窗口统计 // 添加窗口统计
s.addWindowStats(ctx, account, usage) s.addWindowStats(ctx, account, usage)
...@@ -175,7 +174,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U ...@@ -175,7 +174,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
} }
// addWindowStats 为usage数据添加窗口期统计 // addWindowStats 为usage数据添加窗口期统计
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model.Account, usage *UsageInfo) { func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
if usage.FiveHour == nil { if usage.FiveHour == nil {
return return
} }
...@@ -225,7 +224,7 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI ...@@ -225,7 +224,7 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
} }
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量 // fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) { func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
if accessToken == "" { if accessToken == "" {
return nil, fmt.Errorf("no access token available") return nil, fmt.Errorf("no access token available")
...@@ -320,7 +319,7 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA ...@@ -320,7 +319,7 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
} }
// estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量 // estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量
func (s *AccountUsageService) estimateSetupTokenUsage(account *model.Account) *UsageInfo { func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageInfo {
info := &UsageInfo{} info := &UsageInfo{}
// 如果有session_window信息 // 如果有session_window信息
......
...@@ -7,62 +7,61 @@ import ( ...@@ -7,62 +7,61 @@ import (
"log" "log"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
// AdminService interface defines admin management operations // AdminService interface defines admin management operations
type AdminService interface { type AdminService interface {
// User management // User management
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*model.User, error) GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management // Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]model.Group, error) GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*model.Group, error) GetGroup(ctx context.Context, id int64) (*Group, error)
CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error)
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*model.Account, error) GetAccount(ctx context.Context, id int64) (*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
DeleteAccount(ctx context.Context, id int64) error DeleteAccount(ctx context.Context, id int64) error
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*model.Account, error) ClearAccountError(ctx context.Context, id int64) (*Account, error)
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
// Proxy management // Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
GetAllProxies(ctx context.Context) ([]model.Proxy, error) GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*model.Proxy, error) GetProxy(ctx context.Context, id int64) (*Proxy, error)
CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
DeleteProxy(ctx context.Context, id int64) error DeleteProxy(ctx context.Context, id int64) error
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error)
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
// Redeem code management // Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error)
DeleteRedeemCode(ctx context.Context, id int64) error DeleteRedeemCode(ctx context.Context, id int64) error
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
} }
// Input types for admin operations // Input types for admin operations
...@@ -252,7 +251,7 @@ func NewAdminService( ...@@ -252,7 +251,7 @@ func NewAdminService(
} }
// User management implementations // User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) { func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search) users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
if err != nil { if err != nil {
...@@ -261,20 +260,21 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, st ...@@ -261,20 +260,21 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, st
return users, result.Total, nil return users, result.Total, nil
} }
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*model.User, error) { func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
return s.userRepo.GetByID(ctx, id) return s.userRepo.GetByID(ctx, id)
} }
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) { func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
user := &model.User{ user := &User{
Email: input.Email, Email: input.Email,
Username: input.Username, Username: input.Username,
Wechat: input.Wechat, Wechat: input.Wechat,
Notes: input.Notes, Notes: input.Notes,
Role: "user", // Always create as regular user, never admin Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance, Balance: input.Balance,
Concurrency: input.Concurrency, Concurrency: input.Concurrency,
Status: model.StatusActive, Status: StatusActive,
AllowedGroups: input.AllowedGroups,
} }
if err := user.SetPassword(input.Password); err != nil { if err := user.SetPassword(input.Password); err != nil {
return nil, err return nil, err
...@@ -285,7 +285,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu ...@@ -285,7 +285,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
return user, nil return user, nil
} }
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error) { func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id) user, err := s.userRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -335,16 +335,16 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -335,16 +335,16 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
concurrencyDiff := user.Concurrency - oldConcurrency concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 { if concurrencyDiff != 0 {
code, err := model.GenerateRedeemCode() code, err := GenerateRedeemCode()
if err != nil { if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err) log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil return user, nil
} }
adjustmentRecord := &model.RedeemCode{ adjustmentRecord := &RedeemCode{
Code: code, Code: code,
Type: model.AdjustmentTypeAdminConcurrency, Type: AdjustmentTypeAdminConcurrency,
Value: float64(concurrencyDiff), Value: float64(concurrencyDiff),
Status: model.StatusUsed, Status: StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
} }
now := time.Now() now := time.Now()
...@@ -369,7 +369,7 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { ...@@ -369,7 +369,7 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
return s.userRepo.Delete(ctx, id) return s.userRepo.Delete(ctx, id)
} }
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) { func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) {
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -406,17 +406,17 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, ...@@ -406,17 +406,17 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
balanceDiff := user.Balance - oldBalance balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 { if balanceDiff != 0 {
code, err := model.GenerateRedeemCode() code, err := GenerateRedeemCode()
if err != nil { if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err) log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil return user, nil
} }
adjustmentRecord := &model.RedeemCode{ adjustmentRecord := &RedeemCode{
Code: code, Code: code,
Type: model.AdjustmentTypeAdminBalance, Type: AdjustmentTypeAdminBalance,
Value: balanceDiff, Value: balanceDiff,
Status: model.StatusUsed, Status: StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
Notes: notes, Notes: notes,
} }
...@@ -431,7 +431,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, ...@@ -431,7 +431,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil return user, nil
} }
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) { func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
...@@ -452,7 +452,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, ...@@ -452,7 +452,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
} }
// Group management implementations // Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) { func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive) groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
if err != nil { if err != nil {
...@@ -461,36 +461,36 @@ func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, p ...@@ -461,36 +461,36 @@ func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, p
return groups, result.Total, nil return groups, result.Total, nil
} }
func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]model.Group, error) { func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) {
return s.groupRepo.ListActive(ctx) return s.groupRepo.ListActive(ctx)
} }
func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error) { func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) {
return s.groupRepo.ListActiveByPlatform(ctx, platform) return s.groupRepo.ListActiveByPlatform(ctx, platform)
} }
func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*model.Group, error) { func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) {
return s.groupRepo.GetByID(ctx, id) return s.groupRepo.GetByID(ctx, id)
} }
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error) { func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
platform := input.Platform platform := input.Platform
if platform == "" { if platform == "" {
platform = model.PlatformAnthropic platform = PlatformAnthropic
} }
subscriptionType := input.SubscriptionType subscriptionType := input.SubscriptionType
if subscriptionType == "" { if subscriptionType == "" {
subscriptionType = model.SubscriptionTypeStandard subscriptionType = SubscriptionTypeStandard
} }
group := &model.Group{ group := &Group{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
Platform: platform, Platform: platform,
RateMultiplier: input.RateMultiplier, RateMultiplier: input.RateMultiplier,
IsExclusive: input.IsExclusive, IsExclusive: input.IsExclusive,
Status: model.StatusActive, Status: StatusActive,
SubscriptionType: subscriptionType, SubscriptionType: subscriptionType,
DailyLimitUSD: input.DailyLimitUSD, DailyLimitUSD: input.DailyLimitUSD,
WeeklyLimitUSD: input.WeeklyLimitUSD, WeeklyLimitUSD: input.WeeklyLimitUSD,
...@@ -502,7 +502,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -502,7 +502,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return group, nil return group, nil
} }
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error) { func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -571,7 +571,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { ...@@ -571,7 +571,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
return nil return nil
} }
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) { func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil { if err != nil {
...@@ -581,7 +581,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p ...@@ -581,7 +581,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
} }
// Account management implementations // Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) { func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
if err != nil { if err != nil {
...@@ -590,21 +590,21 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, ...@@ -590,21 +590,21 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
return accounts, result.Total, nil return accounts, result.Total, nil
} }
func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*model.Account, error) { func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, error) {
return s.accountRepo.GetByID(ctx, id) return s.accountRepo.GetByID(ctx, id)
} }
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error) { func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
account := &model.Account{ account := &Account{
Name: input.Name, Name: input.Name,
Platform: input.Platform, Platform: input.Platform,
Type: input.Type, Type: input.Type,
Credentials: model.JSONB(input.Credentials), Credentials: input.Credentials,
Extra: model.JSONB(input.Extra), Extra: input.Extra,
ProxyID: input.ProxyID, ProxyID: input.ProxyID,
Concurrency: input.Concurrency, Concurrency: input.Concurrency,
Priority: input.Priority, Priority: input.Priority,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.accountRepo.Create(ctx, account); err != nil { if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err return nil, err
...@@ -618,7 +618,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou ...@@ -618,7 +618,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
return account, nil return account, nil
} }
func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error) { func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -631,10 +631,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -631,10 +631,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Type = input.Type account.Type = input.Type
} }
if len(input.Credentials) > 0 { if len(input.Credentials) > 0 {
account.Credentials = model.JSONB(input.Credentials) account.Credentials = input.Credentials
} }
if len(input.Extra) > 0 { if len(input.Extra) > 0 {
account.Extra = model.JSONB(input.Extra) account.Extra = input.Extra
} }
if input.ProxyID != nil { if input.ProxyID != nil {
account.ProxyID = input.ProxyID account.ProxyID = input.ProxyID
...@@ -730,7 +730,7 @@ func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { ...@@ -730,7 +730,7 @@ func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id) return s.accountRepo.Delete(ctx, id)
} }
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error) { func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -739,12 +739,12 @@ func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int ...@@ -739,12 +739,12 @@ func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int
return account, nil return account, nil
} }
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*model.Account, error) { func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
account.Status = model.StatusActive account.Status = StatusActive
account.ErrorMessage = "" account.ErrorMessage = ""
if err := s.accountRepo.Update(ctx, account); err != nil { if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, err return nil, err
...@@ -752,7 +752,7 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*mo ...@@ -752,7 +752,7 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*mo
return account, nil return account, nil
} }
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error) { func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err return nil, err
} }
...@@ -760,7 +760,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, ...@@ -760,7 +760,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
} }
// Proxy management implementations // Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) { func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil { if err != nil {
...@@ -769,27 +769,27 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, ...@@ -769,27 +769,27 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
return proxies, result.Total, nil return proxies, result.Total, nil
} }
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]model.Proxy, error) { func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
return s.proxyRepo.ListActive(ctx) return s.proxyRepo.ListActive(ctx)
} }
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) { func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
return s.proxyRepo.ListActiveWithAccountCount(ctx) return s.proxyRepo.ListActiveWithAccountCount(ctx)
} }
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*model.Proxy, error) { func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
return s.proxyRepo.GetByID(ctx, id) return s.proxyRepo.GetByID(ctx, id)
} }
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error) { func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) {
proxy := &model.Proxy{ proxy := &Proxy{
Name: input.Name, Name: input.Name,
Protocol: input.Protocol, Protocol: input.Protocol,
Host: input.Host, Host: input.Host,
Port: input.Port, Port: input.Port,
Username: input.Username, Username: input.Username,
Password: input.Password, Password: input.Password,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.proxyRepo.Create(ctx, proxy); err != nil { if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err return nil, err
...@@ -797,7 +797,7 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn ...@@ -797,7 +797,7 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
return proxy, nil return proxy, nil
} }
func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error) { func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id) proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -835,9 +835,9 @@ func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error { ...@@ -835,9 +835,9 @@ func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
return s.proxyRepo.Delete(ctx, id) return s.proxyRepo.Delete(ctx, id)
} }
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error) { func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) {
// Return mock data for now - would need a dedicated repository method // Return mock data for now - would need a dedicated repository method
return []model.Account{}, 0, nil return []Account{}, 0, nil
} }
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) { func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
...@@ -845,7 +845,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po ...@@ -845,7 +845,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
} }
// Redeem code management implementations // Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) { func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil { if err != nil {
...@@ -854,13 +854,13 @@ func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize i ...@@ -854,13 +854,13 @@ func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize i
return codes, result.Total, nil return codes, result.Total, nil
} }
func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) { func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
return s.redeemCodeRepo.GetByID(ctx, id) return s.redeemCodeRepo.GetByID(ctx, id)
} }
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error) { func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) {
// 如果是订阅类型,验证必须有 GroupID // 如果是订阅类型,验证必须有 GroupID
if input.Type == model.RedeemTypeSubscription { if input.Type == RedeemTypeSubscription {
if input.GroupID == nil { if input.GroupID == nil {
return nil, errors.New("group_id is required for subscription type") return nil, errors.New("group_id is required for subscription type")
} }
...@@ -874,20 +874,20 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener ...@@ -874,20 +874,20 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
} }
} }
codes := make([]model.RedeemCode, 0, input.Count) codes := make([]RedeemCode, 0, input.Count)
for i := 0; i < input.Count; i++ { for i := 0; i < input.Count; i++ {
codeValue, err := model.GenerateRedeemCode() codeValue, err := GenerateRedeemCode()
if err != nil { if err != nil {
return nil, err return nil, err
} }
code := model.RedeemCode{ code := RedeemCode{
Code: codeValue, Code: codeValue,
Type: input.Type, Type: input.Type,
Value: input.Value, Value: input.Value,
Status: model.StatusUnused, Status: StatusUnused,
} }
// 订阅类型专用字段 // 订阅类型专用字段
if input.Type == model.RedeemTypeSubscription { if input.Type == RedeemTypeSubscription {
code.GroupID = input.GroupID code.GroupID = input.GroupID
code.ValidityDays = input.ValidityDays code.ValidityDays = input.ValidityDays
if code.ValidityDays <= 0 { if code.ValidityDays <= 0 {
...@@ -916,12 +916,12 @@ func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int ...@@ -916,12 +916,12 @@ func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int
return deleted, nil return deleted, nil
} }
func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) { func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
code, err := s.redeemCodeRepo.GetByID(ctx, id) code, err := s.redeemCodeRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
code.Status = model.StatusExpired code.Status = StatusExpired
if err := s.redeemCodeRepo.Update(ctx, code); err != nil { if err := s.redeemCodeRepo.Update(ctx, code); err != nil {
return nil, err return nil, err
} }
......
package service
import "time"
type ApiKey struct {
ID int64
UserID int64
Key string
Name string
GroupID *int64
Status string
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
}
func (k *ApiKey) IsActive() bool {
return k.Status == StatusActive
}
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
...@@ -30,17 +29,17 @@ const ( ...@@ -30,17 +29,17 @@ const (
) )
type ApiKeyRepository interface { type ApiKeyRepository interface {
Create(ctx context.Context, key *model.ApiKey) error Create(ctx context.Context, key *ApiKey) error
GetByID(ctx context.Context, id int64) (*model.ApiKey, error) GetByID(ctx context.Context, id int64) (*ApiKey, error)
GetByKey(ctx context.Context, key string) (*model.ApiKey, error) GetByKey(ctx context.Context, key string) (*ApiKey, error)
Update(ctx context.Context, key *model.ApiKey) error Update(ctx context.Context, key *ApiKey) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
CountByUserID(ctx context.Context, userID int64) (int64, error) CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error) ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error) CountByGroupID(ctx context.Context, groupID int64) (int64, error)
} }
...@@ -168,7 +167,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in ...@@ -168,7 +167,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
// canUserBindGroup 检查用户是否可以绑定指定分组 // canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅 // 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑 // 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User, group *model.Group) bool { func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅 // 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID) _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
...@@ -179,7 +178,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User, ...@@ -179,7 +178,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User,
} }
// Create 创建API Key // Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*model.ApiKey, error) { func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
// 验证用户存在 // 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
...@@ -235,12 +234,12 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -235,12 +234,12 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
} }
// 创建API Key记录 // 创建API Key记录
apiKey := &model.ApiKey{ apiKey := &ApiKey{
UserID: userID, UserID: userID,
Key: key, Key: key,
Name: req.Name, Name: req.Name,
GroupID: req.GroupID, GroupID: req.GroupID,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil { if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
...@@ -251,7 +250,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -251,7 +250,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
} }
// List 获取用户的API Key列表 // List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err) return nil, nil, fmt.Errorf("list api keys: %w", err)
...@@ -260,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio ...@@ -260,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
} }
// GetByID 根据ID获取API Key // GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
...@@ -269,7 +268,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, e ...@@ -269,7 +268,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, e
} }
// GetByKey 根据Key字符串获取API Key(用于认证) // GetByKey 根据Key字符串获取API Key(用于认证)
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) { func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
// 尝试从Redis缓存获取 // 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key) cacheKey := fmt.Sprintf("apikey:%s", key)
...@@ -289,7 +288,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey ...@@ -289,7 +288,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
} }
// Update 更新API Key // Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) { func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
...@@ -364,7 +363,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro ...@@ -364,7 +363,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
} }
// ValidateKey 验证API Key是否有效(用于认证中间件) // ValidateKey 验证API Key是否有效(用于认证中间件)
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.ApiKey, *model.User, error) { func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
// 获取API Key // 获取API Key
apiKey, err := s.GetByKey(ctx, key) apiKey, err := s.GetByKey(ctx, key)
if err != nil { if err != nil {
...@@ -408,7 +407,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error { ...@@ -408,7 +407,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 返回用户可以选择的分组: // 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的 // - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的 // - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]model.Group, error) { func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息 // 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
...@@ -434,7 +433,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([ ...@@ -434,7 +433,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
} }
// 过滤出用户有权限的分组 // 过滤出用户有权限的分组
availableGroups := make([]model.Group, 0) availableGroups := make([]Group, 0)
for _, group := range allGroups { for _, group := range allGroups {
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) { if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
availableGroups = append(availableGroups, group) availableGroups = append(availableGroups, group)
...@@ -445,7 +444,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([ ...@@ -445,7 +444,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
} }
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据) // canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.Group, subscribedGroupIDs map[int64]bool) bool { func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅 // 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID] return subscribedGroupIDs[group.ID]
...@@ -454,7 +453,7 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model. ...@@ -454,7 +453,7 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.
return user.CanBindGroup(group.ID, group.IsExclusive) return user.CanBindGroup(group.ID, group.IsExclusive)
} }
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) { func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit) keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
if err != nil { if err != nil {
return nil, fmt.Errorf("search api keys: %w", err) return nil, fmt.Errorf("search api keys: %w", err)
......
...@@ -9,7 +9,6 @@ import ( ...@@ -9,7 +9,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
...@@ -64,12 +63,12 @@ func NewAuthService( ...@@ -64,12 +63,12 @@ func NewAuthService(
} }
// Register 用户注册,返回token和用户 // Register 用户注册,返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) { func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "") return s.RegisterWithVerification(ctx, email, password, "")
} }
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户 // RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *model.User, error) { func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
// 检查是否开放注册 // 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled return "", nil, ErrRegDisabled
...@@ -113,13 +112,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -113,13 +112,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
} }
// 创建用户 // 创建用户
user := &model.User{ user := &User{
Email: email, Email: email,
PasswordHash: hashedPassword, PasswordHash: hashedPassword,
Role: model.RoleUser, Role: RoleUser,
Balance: defaultBalance, Balance: defaultBalance,
Concurrency: defaultConcurrency, Concurrency: defaultConcurrency,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.userRepo.Create(ctx, user); err != nil { if err := s.userRepo.Create(ctx, user); err != nil {
...@@ -251,7 +250,7 @@ func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool { ...@@ -251,7 +250,7 @@ func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool {
} }
// Login 用户登录,返回JWT token // Login 用户登录,返回JWT token
func (s *AuthService) Login(ctx context.Context, email, password string) (string, *model.User, error) { func (s *AuthService) Login(ctx context.Context, email, password string) (string, *User, error) {
// 查找用户 // 查找用户
user, err := s.userRepo.GetByEmail(ctx, email) user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil { if err != nil {
...@@ -307,7 +306,7 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { ...@@ -307,7 +306,7 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
} }
// GenerateToken 生成JWT token // GenerateToken 生成JWT token
func (s *AuthService) GenerateToken(user *model.User) (string, error) { func (s *AuthService) GenerateToken(user *User) (string, error) {
now := time.Now() now := time.Now()
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour) expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
......
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
) )
// 错误定义 // 错误定义
...@@ -224,7 +223,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID ...@@ -224,7 +223,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// CheckBillingEligibility 检查用户是否有资格发起请求 // CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0 // 余额模式:检查缓存余额 > 0
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入) // 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *model.User, apiKey *model.ApiKey, group *model.Group, subscription *model.UserSubscription) error { func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
// 判断计费模式 // 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
...@@ -252,7 +251,7 @@ func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userI ...@@ -252,7 +251,7 @@ func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userI
} }
// checkSubscriptionEligibility 检查订阅模式资格 // checkSubscriptionEligibility 检查订阅模式资格
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *model.Group, subscription *model.UserSubscription) error { func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *Group, subscription *UserSubscription) error {
// 获取订阅缓存数据 // 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID) subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil { if err != nil {
...@@ -262,7 +261,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, ...@@ -262,7 +261,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
} }
// 检查订阅状态 // 检查订阅状态
if subData.Status != model.SubscriptionStatusActive { if subData.Status != SubscriptionStatusActive {
return ErrSubscriptionInvalid return ErrSubscriptionInvalid
} }
...@@ -288,7 +287,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, ...@@ -288,7 +287,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
} }
// checkSubscriptionLimitsFallback 降级检查订阅限额 // checkSubscriptionLimitsFallback 降级检查订阅限额
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *model.UserSubscription, group *model.Group) error { func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
if subscription == nil { if subscription == nil {
return ErrSubscriptionInvalid return ErrSubscriptionInvalid
} }
......
...@@ -11,8 +11,6 @@ import ( ...@@ -11,8 +11,6 @@ import (
"net/url" "net/url"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
) )
type CRSSyncService struct { type CRSSyncService struct {
...@@ -180,7 +178,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -180,7 +178,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
), ),
} }
var proxies []model.Proxy var proxies []Proxy
if input.SyncProxies { if input.SyncProxies {
proxies, _ = s.proxyRepo.ListActive(ctx) proxies, _ = s.proxyRepo.ListActive(ctx)
} }
...@@ -197,7 +195,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -197,7 +195,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
if targetType == "" { if targetType == "" {
targetType = "oauth" targetType = "oauth"
} }
if targetType != model.AccountTypeOAuth && targetType != model.AccountTypeSetupToken { if targetType != AccountTypeOAuth && targetType != AccountTypeSetupToken {
item.Action = "skipped" item.Action = "skipped"
item.Error = "unsupported authType: " + targetType item.Error = "unsupported authType: " + targetType
result.Skipped++ result.Skipped++
...@@ -268,12 +266,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -268,12 +266,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
if existing == nil { if existing == nil {
account := &model.Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: model.PlatformAnthropic, Platform: PlatformAnthropic,
Type: targetType, Type: targetType,
Credentials: model.JSONB(credentials), Credentials: credentials,
Extra: model.JSONB(extra), Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
Concurrency: concurrency, Concurrency: concurrency,
Priority: priority, Priority: priority,
...@@ -288,7 +286,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -288,7 +286,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue continue
} }
// 🔄 Refresh OAuth token after creation // 🔄 Refresh OAuth token after creation
if targetType == model.AccountTypeOAuth { if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds account.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, account) _ = s.accountRepo.Update(ctx, account)
...@@ -301,11 +299,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -301,11 +299,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
// Update existing // Update existing
existing.Extra = mergeJSONB(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformAnthropic existing.Platform = PlatformAnthropic
existing.Type = targetType existing.Type = targetType
existing.Credentials = mergeJSONB(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
} }
...@@ -323,7 +321,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -323,7 +321,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
// 🔄 Refresh OAuth token after update // 🔄 Refresh OAuth token after update
if targetType == model.AccountTypeOAuth { if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds existing.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, existing) _ = s.accountRepo.Update(ctx, existing)
...@@ -385,12 +383,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -385,12 +383,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
if existing == nil { if existing == nil {
account := &model.Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: model.PlatformAnthropic, Platform: PlatformAnthropic,
Type: model.AccountTypeApiKey, Type: AccountTypeApiKey,
Credentials: model.JSONB(credentials), Credentials: credentials,
Extra: model.JSONB(extra), Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
Concurrency: concurrency, Concurrency: concurrency,
Priority: priority, Priority: priority,
...@@ -410,11 +408,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -410,11 +408,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue continue
} }
existing.Extra = mergeJSONB(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformAnthropic existing.Platform = PlatformAnthropic
existing.Type = model.AccountTypeApiKey existing.Type = AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
} }
...@@ -508,12 +506,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -508,12 +506,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
if existing == nil { if existing == nil {
account := &model.Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: model.PlatformOpenAI, Platform: PlatformOpenAI,
Type: model.AccountTypeOAuth, Type: AccountTypeOAuth,
Credentials: model.JSONB(credentials), Credentials: credentials,
Extra: model.JSONB(extra), Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
Concurrency: concurrency, Concurrency: concurrency,
Priority: priority, Priority: priority,
...@@ -538,11 +536,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -538,11 +536,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue continue
} }
existing.Extra = mergeJSONB(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformOpenAI existing.Platform = PlatformOpenAI
existing.Type = model.AccountTypeOAuth existing.Type = AccountTypeOAuth
existing.Credentials = mergeJSONB(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
} }
...@@ -629,12 +627,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -629,12 +627,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
if existing == nil { if existing == nil {
account := &model.Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: model.PlatformOpenAI, Platform: PlatformOpenAI,
Type: model.AccountTypeApiKey, Type: AccountTypeApiKey,
Credentials: model.JSONB(credentials), Credentials: credentials,
Extra: model.JSONB(extra), Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
Concurrency: concurrency, Concurrency: concurrency,
Priority: priority, Priority: priority,
...@@ -654,11 +652,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -654,11 +652,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue continue
} }
existing.Extra = mergeJSONB(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformOpenAI existing.Platform = PlatformOpenAI
existing.Type = model.AccountTypeApiKey existing.Type = AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
} }
...@@ -683,9 +681,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -683,9 +681,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
return result, nil return result, nil
} }
// mergeJSONB merges two JSONB maps without removing keys that are absent in updates. func mergeMap(existing map[string]any, updates map[string]any) map[string]any {
func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB { out := make(map[string]any, len(existing)+len(updates))
out := make(model.JSONB)
for k, v := range existing { for k, v := range existing {
out[k] = v out[k] = v
} }
...@@ -695,7 +692,7 @@ func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB { ...@@ -695,7 +692,7 @@ func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB {
return out return out
} }
func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]model.Proxy, src *crsProxy, defaultName string) (*int64, error) { func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]Proxy, src *crsProxy, defaultName string) (*int64, error) {
if !enabled || src == nil { if !enabled || src == nil {
return nil, nil return nil, nil
} }
...@@ -731,14 +728,14 @@ func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cac ...@@ -731,14 +728,14 @@ func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cac
} }
// Create new proxy // Create new proxy
proxy := &model.Proxy{ proxy := &Proxy{
Name: defaultProxyName(defaultName, protocol, host, port), Name: defaultProxyName(defaultName, protocol, host, port),
Protocol: protocol, Protocol: protocol,
Host: host, Host: host,
Port: port, Port: port,
Username: username, Username: username,
Password: password, Password: password,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.proxyRepo.Create(ctx, proxy); err != nil { if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err return nil, err
...@@ -897,8 +894,8 @@ func crsExportAccounts(ctx context.Context, client *http.Client, baseURL, adminT ...@@ -897,8 +894,8 @@ func crsExportAccounts(ctx context.Context, client *http.Client, baseURL, adminT
// refreshOAuthToken attempts to refresh OAuth token for a synced account // refreshOAuthToken attempts to refresh OAuth token for a synced account
// Returns updated credentials or nil if refresh failed/not applicable // Returns updated credentials or nil if refresh failed/not applicable
func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.Account) model.JSONB { func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account) map[string]any {
if account.Type != model.AccountTypeOAuth { if account.Type != AccountTypeOAuth {
return nil return nil
} }
...@@ -906,7 +903,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A ...@@ -906,7 +903,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
var err error var err error
switch account.Platform { switch account.Platform {
case model.PlatformAnthropic: case PlatformAnthropic:
if s.oauthService == nil { if s.oauthService == nil {
return nil return nil
} }
...@@ -931,7 +928,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A ...@@ -931,7 +928,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
newCredentials["scope"] = tokenInfo.Scope newCredentials["scope"] = tokenInfo.Scope
} }
} }
case model.PlatformOpenAI: case PlatformOpenAI:
if s.openaiOAuthService == nil { if s.openaiOAuthService == nil {
return nil return nil
} }
...@@ -956,5 +953,5 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A ...@@ -956,5 +953,5 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
return nil return nil
} }
return model.JSONB(newCredentials) return newCredentials
} }
package model package service
import ( // Status constants
"time" const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// Role constants
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// Platform constants
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
)
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
AccountTypeApiKey = "apikey" // API Key类型账号
) )
// Setting 系统设置模型(Key-Value存储) // Redeem type constants
type Setting struct { const (
ID int64 `gorm:"primaryKey" json:"id"` RedeemTypeBalance = "balance"
Key string `gorm:"uniqueIndex;size:100;not null" json:"key"` RedeemTypeConcurrency = "concurrency"
Value string `gorm:"type:text;not null" json:"value"` RedeemTypeSubscription = "subscription"
UpdatedAt time.Time `gorm:"not null" json:"updated_at"` )
}
func (Setting) TableName() string { // Admin adjustment type constants
return "settings" const (
} AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)
// 设置Key常量 // Group subscription type constants
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
// Subscription status constants
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
)
// Setting keys
const ( const (
// 注册设置 // 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
...@@ -52,53 +92,5 @@ const ( ...@@ -52,53 +92,5 @@ const (
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
) )
// 管理员 API Key 前缀(与用户 sk- 前缀区分) // Admin API Key prefix (distinct from user "sk-" keys)
const AdminApiKeyPrefix = "admin-" const AdminApiKeyPrefix = "admin-"
// SystemSettings 系统设置结构体(用于API响应)
type SystemSettings struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
// 邮件服务设置
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"` // 不返回明文密码
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
// Cloudflare Turnstile 设置
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` // 不返回明文密钥
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
// PublicSettings 公开设置(无需登录即可获取)
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
Version string `json:"version"`
}
...@@ -11,7 +11,6 @@ import ( ...@@ -11,7 +11,6 @@ import (
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
) )
var ( var (
...@@ -69,13 +68,13 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ ...@@ -69,13 +68,13 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
// GetSmtpConfig 从数据库获取SMTP配置 // GetSmtpConfig 从数据库获取SMTP配置
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) { func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
keys := []string{ keys := []string{
model.SettingKeySmtpHost, SettingKeySmtpHost,
model.SettingKeySmtpPort, SettingKeySmtpPort,
model.SettingKeySmtpUsername, SettingKeySmtpUsername,
model.SettingKeySmtpPassword, SettingKeySmtpPassword,
model.SettingKeySmtpFrom, SettingKeySmtpFrom,
model.SettingKeySmtpFromName, SettingKeySmtpFromName,
model.SettingKeySmtpUseTLS, SettingKeySmtpUseTLS,
} }
settings, err := s.settingRepo.GetMultiple(ctx, keys) settings, err := s.settingRepo.GetMultiple(ctx, keys)
...@@ -83,27 +82,27 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) { ...@@ -83,27 +82,27 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err) return nil, fmt.Errorf("get smtp settings: %w", err)
} }
host := settings[model.SettingKeySmtpHost] host := settings[SettingKeySmtpHost]
if host == "" { if host == "" {
return nil, ErrEmailNotConfigured return nil, ErrEmailNotConfigured
} }
port := 587 // 默认端口 port := 587 // 默认端口
if portStr := settings[model.SettingKeySmtpPort]; portStr != "" { if portStr := settings[SettingKeySmtpPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil { if p, err := strconv.Atoi(portStr); err == nil {
port = p port = p
} }
} }
useTLS := settings[model.SettingKeySmtpUseTLS] == "true" useTLS := settings[SettingKeySmtpUseTLS] == "true"
return &SmtpConfig{ return &SmtpConfig{
Host: host, Host: host,
Port: port, Port: port,
Username: settings[model.SettingKeySmtpUsername], Username: settings[SettingKeySmtpUsername],
Password: settings[model.SettingKeySmtpPassword], Password: settings[SettingKeySmtpPassword],
From: settings[model.SettingKeySmtpFrom], From: settings[SettingKeySmtpFrom],
FromName: settings[model.SettingKeySmtpFromName], FromName: settings[SettingKeySmtpFromName],
UseTLS: useTLS, UseTLS: useTLS,
}, nil }, nil
} }
......
...@@ -17,7 +17,6 @@ import ( ...@@ -17,7 +17,6 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
...@@ -265,12 +264,12 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte ...@@ -265,12 +264,12 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
} }
// SelectAccount 选择账号(粘性会话+优先级) // SelectAccount 选择账号(粘性会话+优先级)
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) { func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "") return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
} }
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射) // SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) { func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
...@@ -289,19 +288,19 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int ...@@ -289,19 +288,19 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
} }
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台) // 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台)
var accounts []model.Account var accounts []Account
var err error var err error
if groupID != nil { if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformAnthropic) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic)
} else { } else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformAnthropic) accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err) return nil, fmt.Errorf("query accounts failed: %w", err)
} }
// 3. 按优先级+最久未用选择(考虑模型支持) // 3. 按优先级+最久未用选择(考虑模型支持)
var selected *model.Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
// 检查模型支持 // 检查模型支持
...@@ -341,12 +340,12 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int ...@@ -341,12 +340,12 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
} }
// GetAccessToken 获取账号凭证 // GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) { func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type { switch account.Type {
case model.AccountTypeOAuth, model.AccountTypeSetupToken: case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow // Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account) return s.getOAuthToken(ctx, account)
case model.AccountTypeApiKey: case AccountTypeApiKey:
apiKey := account.GetCredential("api_key") apiKey := account.GetCredential("api_key")
if apiKey == "" { if apiKey == "" {
return "", "", errors.New("api_key not found in credentials") return "", "", errors.New("api_key not found in credentials")
...@@ -357,7 +356,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco ...@@ -357,7 +356,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco
} }
} }
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) { func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
if accessToken == "" { if accessToken == "" {
return "", "", errors.New("access_token not found in credentials") return "", "", errors.New("access_token not found in credentials")
...@@ -372,10 +371,7 @@ const ( ...@@ -372,10 +371,7 @@ const (
retryDelay = 3 * time.Second // 重试等待时间 retryDelay = 3 * time.Second // 重试等待时间
) )
// shouldRetryUpstreamError 判断是否应该重试上游错误 func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
// OAuth/Setup Token 账号:仅 403 重试
// API Key 账号:未配置的错误码重试
func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, statusCode int) bool {
// OAuth/Setup Token 账号:仅 403 重试 // OAuth/Setup Token 账号:仅 403 重试
if account.IsOAuth() { if account.IsOAuth() {
return statusCode == 403 return statusCode == 403
...@@ -386,7 +382,7 @@ func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, status ...@@ -386,7 +382,7 @@ func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, status
} }
// Forward 转发请求到Claude API // Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) { func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
// 解析请求获取model和stream // 解析请求获取model和stream
...@@ -412,7 +408,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m ...@@ -412,7 +408,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
// 应用模型映射(仅对apikey类型账号) // 应用模型映射(仅对apikey类型账号)
originalModel := req.Model originalModel := req.Model
if account.Type == model.AccountTypeApiKey { if account.Type == AccountTypeApiKey {
mappedModel := account.GetMappedModel(req.Model) mappedModel := account.GetMappedModel(req.Model)
if mappedModel != req.Model { if mappedModel != req.Model {
// 替换请求体中的模型名 // 替换请求体中的模型名
...@@ -504,10 +500,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m ...@@ -504,10 +500,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
}, nil }, nil
} }
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) { func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
// 确定目标URL // 确定目标URL
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == model.AccountTypeApiKey { if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages" targetURL = baseURL + "/v1/messages"
} }
...@@ -631,7 +627,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str ...@@ -631,7 +627,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
return claude.DefaultBetaHeader return claude.DefaultBetaHeader
} }
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) { func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
// 处理上游错误,标记账号状态 // 处理上游错误,标记账号状态
...@@ -686,7 +682,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res ...@@ -686,7 +682,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
// handleRetryExhaustedError 处理重试耗尽后的错误 // handleRetryExhaustedError 处理重试耗尽后的错误
// OAuth 403:标记账号异常 // OAuth 403:标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号 // API Key 未配置错误码:仅返回错误,不标记账号
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) { func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
statusCode := resp.StatusCode statusCode := resp.StatusCode
...@@ -717,7 +713,7 @@ type streamingResult struct { ...@@ -717,7 +713,7 @@ type streamingResult struct {
firstTokenMs *int firstTokenMs *int
} }
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) { func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
// 更新5h窗口状态 // 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
...@@ -856,7 +852,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { ...@@ -856,7 +852,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
} }
} }
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*ClaudeUsage, error) { func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
// 更新5h窗口状态 // 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
...@@ -915,10 +911,10 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo ...@@ -915,10 +911,10 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数 // RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct { type RecordUsageInput struct {
Result *ForwardResult Result *ForwardResult
ApiKey *model.ApiKey ApiKey *ApiKey
User *model.User User *User
Account *model.Account Account *Account
Subscription *model.UserSubscription // 可选:订阅信息 Subscription *UserSubscription // 可选:订阅信息
} }
// RecordUsage 记录使用量并扣费(或更新订阅用量) // RecordUsage 记录使用量并扣费(或更新订阅用量)
...@@ -952,14 +948,14 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -952,14 +948,14 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
// 判断计费方式:订阅模式 vs 余额模式 // 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := model.BillingTypeBalance billingType := BillingTypeBalance
if isSubscriptionBilling { if isSubscriptionBilling {
billingType = model.BillingTypeSubscription billingType = BillingTypeSubscription
} }
// 创建使用日志 // 创建使用日志
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
usageLog := &model.UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
...@@ -1038,9 +1034,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -1038,9 +1034,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
// ForwardCountTokens 转发 count_tokens 请求到上游 API // ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应 // 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *model.Account, body []byte) error { func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
// 应用模型映射(仅对 apikey 类型账号) // 应用模型映射(仅对 apikey 类型账号)
if account.Type == model.AccountTypeApiKey { if account.Type == AccountTypeApiKey {
var req struct { var req struct {
Model string `json:"model"` Model string `json:"model"`
} }
...@@ -1113,10 +1109,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -1113,10 +1109,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// buildCountTokensRequest 构建 count_tokens 上游请求 // buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) { func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
// 确定目标 URL // 确定目标 URL
targetURL := claudeAPICountTokensURL targetURL := claudeAPICountTokensURL
if account.Type == model.AccountTypeApiKey { if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens" targetURL = baseURL + "/v1/messages/count_tokens"
} }
......
package service
import "time"
type Group struct {
ID int64
Name string
Description string
Platform string
RateMultiplier float64
IsExclusive bool
Status string
SubscriptionType string
DailyLimitUSD *float64
WeeklyLimitUSD *float64
MonthlyLimitUSD *float64
CreatedAt time.Time
UpdatedAt time.Time
AccountGroups []AccountGroup
AccountCount int64
}
func (g *Group) IsActive() bool {
return g.Status == StatusActive
}
func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
func (g *Group) IsFreeSubscription() bool {
return g.IsSubscriptionType() && g.RateMultiplier == 0
}
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
func (g *Group) HasWeeklyLimit() bool {
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
}
func (g *Group) HasMonthlyLimit() bool {
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
}
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
...@@ -15,16 +14,16 @@ var ( ...@@ -15,16 +14,16 @@ var (
) )
type GroupRepository interface { type GroupRepository interface {
Create(ctx context.Context, group *model.Group) error Create(ctx context.Context, group *Group) error
GetByID(ctx context.Context, id int64) (*model.Group, error) GetByID(ctx context.Context, id int64) (*Group, error)
Update(ctx context.Context, group *model.Group) error Update(ctx context.Context, group *Group) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
DeleteCascade(ctx context.Context, id int64) ([]int64, error) DeleteCascade(ctx context.Context, id int64) ([]int64, error)
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Group, error) ListActive(ctx context.Context) ([]Group, error)
ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
ExistsByName(ctx context.Context, name string) (bool, error) ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error) GetAccountCount(ctx context.Context, groupID int64) (int64, error)
...@@ -61,7 +60,7 @@ func NewGroupService(groupRepo GroupRepository) *GroupService { ...@@ -61,7 +60,7 @@ func NewGroupService(groupRepo GroupRepository) *GroupService {
} }
// Create 创建分组 // Create 创建分组
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*model.Group, error) { func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) {
// 检查名称是否已存在 // 检查名称是否已存在
exists, err := s.groupRepo.ExistsByName(ctx, req.Name) exists, err := s.groupRepo.ExistsByName(ctx, req.Name)
if err != nil { if err != nil {
...@@ -72,12 +71,14 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod ...@@ -72,12 +71,14 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
} }
// 创建分组 // 创建分组
group := &model.Group{ group := &Group{
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
RateMultiplier: req.RateMultiplier, Platform: PlatformAnthropic,
IsExclusive: req.IsExclusive, RateMultiplier: req.RateMultiplier,
Status: model.StatusActive, IsExclusive: req.IsExclusive,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
...@@ -88,7 +89,7 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod ...@@ -88,7 +89,7 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
} }
// GetByID 根据ID获取分组 // GetByID 根据ID获取分组
func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) { func (s *GroupService) GetByID(ctx context.Context, id int64) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
...@@ -97,7 +98,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err ...@@ -97,7 +98,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err
} }
// List 获取分组列表 // List 获取分组列表
func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) { func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
groups, pagination, err := s.groupRepo.List(ctx, params) groups, pagination, err := s.groupRepo.List(ctx, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list groups: %w", err) return nil, nil, fmt.Errorf("list groups: %w", err)
...@@ -106,7 +107,7 @@ func (s *GroupService) List(ctx context.Context, params pagination.PaginationPar ...@@ -106,7 +107,7 @@ func (s *GroupService) List(ctx context.Context, params pagination.PaginationPar
} }
// ListActive 获取活跃分组列表 // ListActive 获取活跃分组列表
func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) { func (s *GroupService) ListActive(ctx context.Context) ([]Group, error) {
groups, err := s.groupRepo.ListActive(ctx) groups, err := s.groupRepo.ListActive(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("list active groups: %w", err) return nil, fmt.Errorf("list active groups: %w", err)
...@@ -115,7 +116,7 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) { ...@@ -115,7 +116,7 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
} }
// Update 更新分组 // Update 更新分组
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) { func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
......
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"log" "log"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
) )
...@@ -274,7 +273,7 @@ func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, pr ...@@ -274,7 +273,7 @@ func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, pr
} }
// RefreshAccountToken refreshes token for an account // RefreshAccountToken refreshes token for an account
func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*TokenInfo, error) { func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
refreshToken := account.GetCredential("refresh_token") refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" { if refreshToken == "" {
return nil, fmt.Errorf("no refresh token available") return nil, fmt.Errorf("no refresh token available")
......
...@@ -16,7 +16,6 @@ import ( ...@@ -16,7 +16,6 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -119,12 +118,12 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string { ...@@ -119,12 +118,12 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
} }
// SelectAccount selects an OpenAI account with sticky session support // SelectAccount selects an OpenAI account with sticky session support
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) { func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "") return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
} }
// SelectAccountForModel selects an account supporting the requested model // SelectAccountForModel selects an account supporting the requested model
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) { func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
// 1. Check sticky session // 1. Check sticky session
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
...@@ -139,19 +138,19 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI ...@@ -139,19 +138,19 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
} }
// 2. Get schedulable OpenAI accounts // 2. Get schedulable OpenAI accounts
var accounts []model.Account var accounts []Account
var err error var err error
if groupID != nil { if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformOpenAI) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
} else { } else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformOpenAI) accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err) return nil, fmt.Errorf("query accounts failed: %w", err)
} }
// 3. Select by priority + LRU // 3. Select by priority + LRU
var selected *model.Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
// Check model support // Check model support
...@@ -189,15 +188,15 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI ...@@ -189,15 +188,15 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
} }
// GetAccessToken gets the access token for an OpenAI account // GetAccessToken gets the access token for an OpenAI account
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) { func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type { switch account.Type {
case model.AccountTypeOAuth: case AccountTypeOAuth:
accessToken := account.GetOpenAIAccessToken() accessToken := account.GetOpenAIAccessToken()
if accessToken == "" { if accessToken == "" {
return "", "", errors.New("access_token not found in credentials") return "", "", errors.New("access_token not found in credentials")
} }
return accessToken, "oauth", nil return accessToken, "oauth", nil
case model.AccountTypeApiKey: case AccountTypeApiKey:
apiKey := account.GetOpenAIApiKey() apiKey := account.GetOpenAIApiKey()
if apiKey == "" { if apiKey == "" {
return "", "", errors.New("api_key not found in credentials") return "", "", errors.New("api_key not found in credentials")
...@@ -209,7 +208,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *mode ...@@ -209,7 +208,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *mode
} }
// Forward forwards request to OpenAI API // Forward forwards request to OpenAI API
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*OpenAIForwardResult, error) { func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
startTime := time.Now() startTime := time.Now()
// Parse request body once (avoid multiple parse/serialize cycles) // Parse request body once (avoid multiple parse/serialize cycles)
...@@ -234,7 +233,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -234,7 +233,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
// For OAuth accounts using ChatGPT internal API, add store: false // For OAuth accounts using ChatGPT internal API, add store: false
if account.Type == model.AccountTypeOAuth { if account.Type == AccountTypeOAuth {
reqBody["store"] = false reqBody["store"] = false
bodyModified = true bodyModified = true
} }
...@@ -296,7 +295,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -296,7 +295,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
// Extract and save Codex usage snapshot from response headers (for OAuth accounts) // Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == model.AccountTypeOAuth { if account.Type == AccountTypeOAuth {
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil { if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
} }
...@@ -312,14 +311,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -312,14 +311,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}, nil }, nil
} }
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token string, isStream bool) (*http.Request, error) { func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) {
// Determine target URL based on account type // Determine target URL based on account type
var targetURL string var targetURL string
switch account.Type { switch account.Type {
case model.AccountTypeOAuth: case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API // OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL targetURL = chatgptCodexURL
case model.AccountTypeApiKey: case AccountTypeApiKey:
// API Key accounts use Platform API or custom base URL // API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL() baseURL := account.GetOpenAIBaseURL()
if baseURL != "" { if baseURL != "" {
...@@ -340,7 +339,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. ...@@ -340,7 +339,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
req.Header.Set("authorization", "Bearer "+token) req.Header.Set("authorization", "Bearer "+token)
// Set headers specific to OAuth accounts (ChatGPT internal API) // Set headers specific to OAuth accounts (ChatGPT internal API)
if account.Type == model.AccountTypeOAuth { if account.Type == AccountTypeOAuth {
// Required: set Host for ChatGPT API (must use req.Host, not Header.Set) // Required: set Host for ChatGPT API (must use req.Host, not Header.Set)
req.Host = "chatgpt.com" req.Host = "chatgpt.com"
// Required: set chatgpt-account-id header // Required: set chatgpt-account-id header
...@@ -380,7 +379,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. ...@@ -380,7 +379,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
return req, nil return req, nil
} }
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*OpenAIForwardResult, error) { func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
// Check custom error codes // Check custom error codes
...@@ -436,7 +435,7 @@ type openaiStreamingResult struct { ...@@ -436,7 +435,7 @@ type openaiStreamingResult struct {
firstTokenMs *int firstTokenMs *int
} }
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
// Set SSE response headers // Set SSE response headers
c.Header("Content-Type", "text/event-stream") c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache") c.Header("Cache-Control", "no-cache")
...@@ -552,7 +551,7 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { ...@@ -552,7 +551,7 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
} }
} }
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*OpenAIUsage, error) { func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -618,10 +617,10 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel ...@@ -618,10 +617,10 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage // OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct { type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult Result *OpenAIForwardResult
ApiKey *model.ApiKey ApiKey *ApiKey
User *model.User User *User
Account *model.Account Account *Account
Subscription *model.UserSubscription Subscription *UserSubscription
} }
// RecordUsage records usage and deducts balance // RecordUsage records usage and deducts balance
...@@ -660,14 +659,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -660,14 +659,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Determine billing type // Determine billing type
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := model.BillingTypeBalance billingType := BillingTypeBalance
if isSubscriptionBilling { if isSubscriptionBilling {
billingType = model.BillingTypeSubscription billingType = BillingTypeSubscription
} }
// Create usage log // Create usage log
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
usageLog := &model.UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
......
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
) )
...@@ -200,7 +199,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri ...@@ -200,7 +199,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
} }
// RefreshAccountToken refreshes token for an OpenAI account // RefreshAccountToken refreshes token for an OpenAI account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*OpenAITokenInfo, error) { func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if !account.IsOpenAI() { if !account.IsOpenAI() {
return nil, fmt.Errorf("account is not an OpenAI account") return nil, fmt.Errorf("account is not an OpenAI account")
} }
......
package service
import (
"fmt"
"time"
)
type Proxy struct {
ID int64
Name string
Protocol string
Host string
Port int
Username string
Password string
Status string
CreatedAt time.Time
UpdatedAt time.Time
}
func (p *Proxy) IsActive() bool {
return p.Status == StatusActive
}
func (p *Proxy) URL() string {
if p.Username != "" && p.Password != "" {
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
}
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
}
type ProxyWithAccountCount struct {
Proxy
AccountCount int64
}
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
...@@ -14,15 +13,15 @@ var ( ...@@ -14,15 +13,15 @@ var (
) )
type ProxyRepository interface { type ProxyRepository interface {
Create(ctx context.Context, proxy *model.Proxy) error Create(ctx context.Context, proxy *Proxy) error
GetByID(ctx context.Context, id int64) (*model.Proxy, error) GetByID(ctx context.Context, id int64) (*Proxy, error)
Update(ctx context.Context, proxy *model.Proxy) error Update(ctx context.Context, proxy *Proxy) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Proxy, error) ListActive(ctx context.Context) ([]Proxy, error)
ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
...@@ -62,16 +61,16 @@ func NewProxyService(proxyRepo ProxyRepository) *ProxyService { ...@@ -62,16 +61,16 @@ func NewProxyService(proxyRepo ProxyRepository) *ProxyService {
} }
// Create 创建代理 // Create 创建代理
func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*model.Proxy, error) { func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*Proxy, error) {
// 创建代理 // 创建代理
proxy := &model.Proxy{ proxy := &Proxy{
Name: req.Name, Name: req.Name,
Protocol: req.Protocol, Protocol: req.Protocol,
Host: req.Host, Host: req.Host,
Port: req.Port, Port: req.Port,
Username: req.Username, Username: req.Username,
Password: req.Password, Password: req.Password,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.proxyRepo.Create(ctx, proxy); err != nil { if err := s.proxyRepo.Create(ctx, proxy); err != nil {
...@@ -82,7 +81,7 @@ func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*mod ...@@ -82,7 +81,7 @@ func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*mod
} }
// GetByID 根据ID获取代理 // GetByID 根据ID获取代理
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { func (s *ProxyService) GetByID(ctx context.Context, id int64) (*Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id) proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get proxy: %w", err) return nil, fmt.Errorf("get proxy: %w", err)
...@@ -91,7 +90,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err ...@@ -91,7 +90,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err
} }
// List 获取代理列表 // List 获取代理列表
func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) { func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) {
proxies, pagination, err := s.proxyRepo.List(ctx, params) proxies, pagination, err := s.proxyRepo.List(ctx, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list proxies: %w", err) return nil, nil, fmt.Errorf("list proxies: %w", err)
...@@ -100,7 +99,7 @@ func (s *ProxyService) List(ctx context.Context, params pagination.PaginationPar ...@@ -100,7 +99,7 @@ func (s *ProxyService) List(ctx context.Context, params pagination.PaginationPar
} }
// ListActive 获取活跃代理列表 // ListActive 获取活跃代理列表
func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) { func (s *ProxyService) ListActive(ctx context.Context) ([]Proxy, error) {
proxies, err := s.proxyRepo.ListActive(ctx) proxies, err := s.proxyRepo.ListActive(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("list active proxies: %w", err) return nil, fmt.Errorf("list active proxies: %w", err)
...@@ -109,7 +108,7 @@ func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) { ...@@ -109,7 +108,7 @@ func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
} }
// Update 更新代理 // Update 更新代理
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) { func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id) proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get proxy: %w", err) return nil, fmt.Errorf("get proxy: %w", err)
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
) )
// RateLimitService 处理限流和过载状态管理 // RateLimitService 处理限流和过载状态管理
...@@ -27,7 +26,7 @@ func NewRateLimitService(accountRepo AccountRepository, cfg *config.Config) *Rat ...@@ -27,7 +26,7 @@ func NewRateLimitService(accountRepo AccountRepository, cfg *config.Config) *Rat
// HandleUpstreamError 处理上游错误响应,标记账号状态 // HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度 // 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
// apikey 类型账号:检查自定义错误码配置 // apikey 类型账号:检查自定义错误码配置
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载) // 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
if !account.ShouldHandleErrorCode(statusCode) { if !account.ShouldHandleErrorCode(statusCode) {
...@@ -60,7 +59,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *mod ...@@ -60,7 +59,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *mod
} }
// handleAuthError 处理认证类错误(401/403),停止账号调度 // handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) { func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err) log.Printf("SetError failed for account %d: %v", account.ID, err)
return return
...@@ -70,7 +69,7 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.A ...@@ -70,7 +69,7 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.A
// handle429 处理429限流错误 // handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态 // 解析响应头获取重置时间,标记账号为限流状态
func (s *RateLimitService) handle429(ctx context.Context, account *model.Account, headers http.Header) { func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header) {
// 解析重置时间戳 // 解析重置时间戳
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset") resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
if resetTimestamp == "" { if resetTimestamp == "" {
...@@ -113,7 +112,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account ...@@ -113,7 +112,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
// handle529 处理529过载错误 // handle529 处理529过载错误
// 根据配置设置过载冷却时间 // 根据配置设置过载冷却时间
func (s *RateLimitService) handle529(ctx context.Context, account *model.Account) { func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes
if cooldownMinutes <= 0 { if cooldownMinutes <= 0 {
cooldownMinutes = 10 // 默认10分钟 cooldownMinutes = 10 // 默认10分钟
...@@ -129,7 +128,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *model.Account ...@@ -129,7 +128,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *model.Account
} }
// UpdateSessionWindow 从成功响应更新5h窗口状态 // UpdateSessionWindow 从成功响应更新5h窗口状态
func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *model.Account, headers http.Header) { func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Account, headers http.Header) {
status := headers.Get("anthropic-ratelimit-unified-5h-status") status := headers.Get("anthropic-ratelimit-unified-5h-status")
if status == "" { if status == "" {
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment