Commit 642842c2 authored by shaw's avatar shaw
Browse files

First commit

parent 569f4882
package service
import (
"context"
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"gorm.io/gorm"
)
var (
ErrProxyNotFound = errors.New("proxy not found")
)
// CreateProxyRequest 创建代理请求
type CreateProxyRequest struct {
Name string `json:"name"`
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
}
// UpdateProxyRequest 更新代理请求
type UpdateProxyRequest struct {
Name *string `json:"name"`
Protocol *string `json:"protocol"`
Host *string `json:"host"`
Port *int `json:"port"`
Username *string `json:"username"`
Password *string `json:"password"`
Status *string `json:"status"`
}
// ProxyService 代理管理服务
type ProxyService struct {
proxyRepo *repository.ProxyRepository
}
// NewProxyService 创建代理服务实例
func NewProxyService(proxyRepo *repository.ProxyRepository) *ProxyService {
return &ProxyService{
proxyRepo: proxyRepo,
}
}
// Create 创建代理
func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*model.Proxy, error) {
// 创建代理
proxy := &model.Proxy{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Port: req.Port,
Username: req.Username,
Password: req.Password,
Status: model.StatusActive,
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, fmt.Errorf("create proxy: %w", err)
}
return proxy, nil
}
// GetByID 根据ID获取代理
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProxyNotFound
}
return nil, fmt.Errorf("get proxy: %w", err)
}
return proxy, nil
}
// List 获取代理列表
func (s *ProxyService) List(ctx context.Context, params repository.PaginationParams) ([]model.Proxy, *repository.PaginationResult, error) {
proxies, pagination, err := s.proxyRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list proxies: %w", err)
}
return proxies, pagination, nil
}
// ListActive 获取活跃代理列表
func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
proxies, err := s.proxyRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active proxies: %w", err)
}
return proxies, nil
}
// Update 更新代理
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProxyNotFound
}
return nil, fmt.Errorf("get proxy: %w", err)
}
// 更新字段
if req.Name != nil {
proxy.Name = *req.Name
}
if req.Protocol != nil {
proxy.Protocol = *req.Protocol
}
if req.Host != nil {
proxy.Host = *req.Host
}
if req.Port != nil {
proxy.Port = *req.Port
}
if req.Username != nil {
proxy.Username = *req.Username
}
if req.Password != nil {
proxy.Password = *req.Password
}
if req.Status != nil {
proxy.Status = *req.Status
}
if err := s.proxyRepo.Update(ctx, proxy); err != nil {
return nil, fmt.Errorf("update proxy: %w", err)
}
return proxy, nil
}
// Delete 删除代理
func (s *ProxyService) Delete(ctx context.Context, id int64) error {
// 检查代理是否存在
_, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProxyNotFound
}
return fmt.Errorf("get proxy: %w", err)
}
if err := s.proxyRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete proxy: %w", err)
}
return nil
}
// TestConnection 测试代理连接(需要实现具体测试逻辑)
func (s *ProxyService) TestConnection(ctx context.Context, id int64) error {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProxyNotFound
}
return fmt.Errorf("get proxy: %w", err)
}
// TODO: 实现代理连接测试逻辑
// 可以尝试通过代理发送测试请求
_ = proxy
return nil
}
// GetURL 获取代理URL
func (s *ProxyService) GetURL(ctx context.Context, id int64) (string, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrProxyNotFound
}
return "", fmt.Errorf("get proxy: %w", err)
}
return proxy.URL(), nil
}
package service
import (
"context"
"log"
"net/http"
"strconv"
"time"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
)
// RateLimitService 处理限流和过载状态管理
type RateLimitService struct {
repos *repository.Repositories
cfg *config.Config
}
// NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(repos *repository.Repositories, cfg *config.Config) *RateLimitService {
return &RateLimitService{
repos: repos,
cfg: cfg,
}
}
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
// apikey 类型账号:检查自定义错误码配置
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
if !account.ShouldHandleErrorCode(statusCode) {
log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode)
return false
}
switch statusCode {
case 401:
// 认证失败:停止调度,记录错误
s.handleAuthError(ctx, account, "Authentication failed (401): invalid or expired credentials")
return true
case 403:
// 禁止访问:停止调度,记录错误
s.handleAuthError(ctx, account, "Access forbidden (403): account may be suspended or lack permissions")
return true
case 429:
s.handle429(ctx, account, headers)
return false
case 529:
s.handle529(ctx, account)
return false
default:
// 其他5xx错误:记录但不停止调度
if statusCode >= 500 {
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
}
return false
}
}
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) {
if err := s.repos.Account.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err)
return
}
log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg)
}
// handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态
func (s *RateLimitService) handle429(ctx context.Context, account *model.Account, headers http.Header) {
// 解析重置时间戳
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
if resetTimestamp == "" {
// 没有重置时间,使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute)
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
}
return
}
// 解析Unix时间戳
ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
if err != nil {
log.Printf("Parse reset timestamp failed: %v", err)
resetAt := time.Now().Add(5 * time.Minute)
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
}
return
}
resetAt := time.Unix(ts, 0)
// 标记限流状态
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
return
}
// 根据重置时间反推5h窗口
windowEnd := resetAt
windowStart := resetAt.Add(-5 * time.Hour)
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
}
log.Printf("Account %d rate limited until %v", account.ID, resetAt)
}
// handle529 处理529过载错误
// 根据配置设置过载冷却时间
func (s *RateLimitService) handle529(ctx context.Context, account *model.Account) {
cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes
if cooldownMinutes <= 0 {
cooldownMinutes = 10 // 默认10分钟
}
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.repos.Account.SetOverloaded(ctx, account.ID, until); err != nil {
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
return
}
log.Printf("Account %d overloaded until %v", account.ID, until)
}
// UpdateSessionWindow 从成功响应更新5h窗口状态
func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *model.Account, headers http.Header) {
status := headers.Get("anthropic-ratelimit-unified-5h-status")
if status == "" {
return
}
// 检查是否需要初始化时间窗口
// 对于 Setup Token 账号,首次成功请求时需要预测时间窗口
var windowStart, windowEnd *time.Time
needInitWindow := account.SessionWindowEnd == nil || time.Now().After(*account.SessionWindowEnd)
if needInitWindow && (status == "allowed" || status == "allowed_warning") {
// 预测时间窗口:从当前时间的整点开始,+5小时为结束
// 例如:现在是 14:30,窗口为 14:00 ~ 19:00
now := time.Now()
start := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
end := start.Add(5 * time.Hour)
windowStart = &start
windowEnd = &end
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
}
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
}
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
if status == "allowed" && account.IsRateLimited() {
if err := s.repos.Account.ClearRateLimit(ctx, account.ID); err != nil {
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
}
}
}
// ClearRateLimit 清除账号的限流状态
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
return s.repos.Account.ClearRateLimit(ctx, accountID)
}
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"strings"
"sub2api/internal/model"
"sub2api/internal/repository"
"time"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
var (
ErrRedeemCodeNotFound = errors.New("redeem code not found")
ErrRedeemCodeUsed = errors.New("redeem code already used")
ErrRedeemCodeInvalid = errors.New("invalid redeem code")
ErrInsufficientBalance = errors.New("insufficient balance")
ErrRedeemRateLimited = errors.New("too many failed attempts, please try again later")
ErrRedeemCodeLocked = errors.New("redeem code is being processed, please try again")
)
const (
redeemRateLimitKeyPrefix = "redeem:rate_limit:"
redeemLockKeyPrefix = "redeem:lock:"
redeemMaxErrorsPerHour = 20
redeemRateLimitDuration = time.Hour
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
)
// GenerateCodesRequest 生成兑换码请求
type GenerateCodesRequest struct {
Count int `json:"count"`
Value float64 `json:"value"`
Type string `json:"type"`
}
// RedeemCodeResponse 兑换码响应
type RedeemCodeResponse struct {
Code string `json:"code"`
Value float64 `json:"value"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
}
// RedeemService 兑换码服务
type RedeemService struct {
redeemRepo *repository.RedeemCodeRepository
userRepo *repository.UserRepository
subscriptionService *SubscriptionService
rdb *redis.Client
billingCacheService *BillingCacheService
}
// NewRedeemService 创建兑换码服务实例
func NewRedeemService(redeemRepo *repository.RedeemCodeRepository, userRepo *repository.UserRepository, subscriptionService *SubscriptionService, rdb *redis.Client) *RedeemService {
return &RedeemService{
redeemRepo: redeemRepo,
userRepo: userRepo,
subscriptionService: subscriptionService,
rdb: rdb,
}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
func (s *RedeemService) SetBillingCacheService(billingCacheService *BillingCacheService) {
s.billingCacheService = billingCacheService
}
// GenerateRandomCode 生成随机兑换码
func (s *RedeemService) GenerateRandomCode() (string, error) {
// 生成16字节随机数据
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
// 转换为十六进制字符串
code := hex.EncodeToString(bytes)
// 格式化为 XXXX-XXXX-XXXX-XXXX 格式
parts := []string{
strings.ToUpper(code[0:8]),
strings.ToUpper(code[8:16]),
strings.ToUpper(code[16:24]),
strings.ToUpper(code[24:32]),
}
return strings.Join(parts, "-"), nil
}
// GenerateCodes 批量生成兑换码
func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]model.RedeemCode, error) {
if req.Count <= 0 {
return nil, errors.New("count must be greater than 0")
}
if req.Value <= 0 {
return nil, errors.New("value must be greater than 0")
}
if req.Count > 1000 {
return nil, errors.New("cannot generate more than 1000 codes at once")
}
codeType := req.Type
if codeType == "" {
codeType = model.RedeemTypeBalance
}
codes := make([]model.RedeemCode, 0, req.Count)
for i := 0; i < req.Count; i++ {
code, err := s.GenerateRandomCode()
if err != nil {
return nil, fmt.Errorf("generate code: %w", err)
}
codes = append(codes, model.RedeemCode{
Code: code,
Type: codeType,
Value: req.Value,
Status: model.StatusUnused,
})
}
// 批量插入
if err := s.redeemRepo.CreateBatch(ctx, codes); err != nil {
return nil, fmt.Errorf("create batch codes: %w", err)
}
return codes, nil
}
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
if s.rdb == nil {
return nil
}
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
count, err := s.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) {
// Redis 出错时不阻止用户操作
return nil
}
if count >= redeemMaxErrorsPerHour {
return ErrRedeemRateLimited
}
return nil
}
// incrementRedeemErrorCount 增加用户兑换错误计数
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
if s.rdb == nil {
return
}
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
pipe := s.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, redeemRateLimitDuration)
_, _ = pipe.Exec(ctx)
}
// acquireRedeemLock 尝试获取兑换码的分布式锁
// 返回 true 表示获取成功,false 表示锁已被占用
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
if s.rdb == nil {
return true // 无 Redis 时降级为不加锁
}
key := redeemLockKeyPrefix + code
ok, err := s.rdb.SetNX(ctx, key, "1", redeemLockDuration).Result()
if err != nil {
// Redis 出错时不阻止操作,依赖数据库层面的状态检查
return true
}
return ok
}
// releaseRedeemLock 释放兑换码的分布式锁
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
if s.rdb == nil {
return
}
key := redeemLockKeyPrefix + code
s.rdb.Del(ctx, key)
}
// Redeem 使用兑换码
func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*model.RedeemCode, error) {
// 检查限流
if err := s.checkRedeemRateLimit(ctx, userID); err != nil {
return nil, err
}
// 获取分布式锁,防止同一兑换码并发使用
if !s.acquireRedeemLock(ctx, code) {
return nil, ErrRedeemCodeLocked
}
defer s.releaseRedeemLock(ctx, code)
// 查找兑换码
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
s.incrementRedeemErrorCount(ctx, userID)
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err)
}
// 检查兑换码状态
if !redeemCode.CanUse() {
s.incrementRedeemErrorCount(ctx, userID)
return nil, ErrRedeemCodeUsed
}
// 验证兑换码类型的前置条件
if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil {
return nil, errors.New("invalid subscription redeem code: missing group_id")
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
_ = user // 使用变量避免未使用错误
// 【关键】先标记兑换码为已使用,确保并发安全
// 利用数据库乐观锁(WHERE status = 'unused')保证原子性
if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// 兑换码已被其他请求使用
return nil, ErrRedeemCodeUsed
}
return nil, fmt.Errorf("mark code as used: %w", err)
}
// 执行兑换逻辑(兑换码已被锁定,此时可安全操作)
switch redeemCode.Type {
case model.RedeemTypeBalance:
// 增加用户余额
if err := s.userRepo.UpdateBalance(ctx, userID, redeemCode.Value); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
}
// 失效余额缓存
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
}()
}
case model.RedeemTypeConcurrency:
// 增加用户并发数
if err := s.userRepo.UpdateConcurrency(ctx, userID, int(redeemCode.Value)); err != nil {
return nil, fmt.Errorf("update user concurrency: %w", err)
}
case model.RedeemTypeSubscription:
validityDays := redeemCode.ValidityDays
if validityDays <= 0 {
validityDays = 30
}
_, _, err := s.subscriptionService.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: *redeemCode.GroupID,
ValidityDays: validityDays,
AssignedBy: 0, // 系统分配
Notes: fmt.Sprintf("通过兑换码 %s 兑换", redeemCode.Code),
})
if err != nil {
return nil, fmt.Errorf("assign or extend subscription: %w", err)
}
// 失效订阅缓存
if s.billingCacheService != nil {
groupID := *redeemCode.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
default:
return nil, fmt.Errorf("unsupported redeem type: %s", redeemCode.Type)
}
// 重新获取更新后的兑换码
redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID)
if err != nil {
return nil, fmt.Errorf("get updated redeem code: %w", err)
}
return redeemCode, nil
}
// GetByID 根据ID获取兑换码
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err)
}
return code, nil
}
// GetByCode 根据Code获取兑换码
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err)
}
return redeemCode, nil
}
// List 获取兑换码列表(管理员功能)
func (s *RedeemService) List(ctx context.Context, params repository.PaginationParams) ([]model.RedeemCode, *repository.PaginationResult, error) {
codes, pagination, err := s.redeemRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list redeem codes: %w", err)
}
return codes, pagination, nil
}
// Delete 删除兑换码(管理员功能)
func (s *RedeemService) Delete(ctx context.Context, id int64) error {
// 检查兑换码是否存在
code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrRedeemCodeNotFound
}
return fmt.Errorf("get redeem code: %w", err)
}
// 不允许删除已使用的兑换码
if code.IsUsed() {
return errors.New("cannot delete used redeem code")
}
if err := s.redeemRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete redeem code: %w", err)
}
return nil
}
// GetStats 获取兑换码统计信息
func (s *RedeemService) GetStats(ctx context.Context) (map[string]interface{}, error) {
// TODO: 实现统计逻辑
// 统计未使用、已使用的兑换码数量
// 统计总面值等
stats := map[string]interface{}{
"total_codes": 0,
"unused_codes": 0,
"used_codes": 0,
"total_value": 0.0,
}
return stats, nil
}
// GetUserHistory 获取用户的兑换历史
func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
codes, err := s.redeemRepo.ListByUser(ctx, userID, limit)
if err != nil {
return nil, fmt.Errorf("get user redeem history: %w", err)
}
return codes, nil
}
package service
import (
"sub2api/internal/config"
"sub2api/internal/repository"
"github.com/redis/go-redis/v9"
)
// Services 服务集合容器
type Services struct {
Auth *AuthService
User *UserService
ApiKey *ApiKeyService
Group *GroupService
Account *AccountService
Proxy *ProxyService
Redeem *RedeemService
Usage *UsageService
Pricing *PricingService
Billing *BillingService
BillingCache *BillingCacheService
Admin AdminService
Gateway *GatewayService
OAuth *OAuthService
RateLimit *RateLimitService
AccountUsage *AccountUsageService
AccountTest *AccountTestService
Setting *SettingService
Email *EmailService
EmailQueue *EmailQueueService
Turnstile *TurnstileService
Subscription *SubscriptionService
Concurrency *ConcurrencyService
Identity *IdentityService
}
// NewServices 创建所有服务实例
func NewServices(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config) *Services {
// 初始化价格服务
pricingService := NewPricingService(cfg)
if err := pricingService.Initialize(); err != nil {
// 价格服务初始化失败不应阻止启动,使用回退价格
println("[Service] Warning: Pricing service initialization failed:", err.Error())
}
// 初始化计费服务(依赖价格服务)
billingService := NewBillingService(cfg, pricingService)
// 初始化其他服务
authService := NewAuthService(repos.User, cfg)
userService := NewUserService(repos.User, cfg)
apiKeyService := NewApiKeyService(repos.ApiKey, repos.User, repos.Group, repos.UserSubscription, rdb, cfg)
groupService := NewGroupService(repos.Group)
accountService := NewAccountService(repos.Account, repos.Group)
proxyService := NewProxyService(repos.Proxy)
usageService := NewUsageService(repos.UsageLog, repos.User)
// 初始化订阅服务 (RedeemService 依赖)
subscriptionService := NewSubscriptionService(repos)
// 初始化兑换服务 (依赖订阅服务)
redeemService := NewRedeemService(repos.RedeemCode, repos.User, subscriptionService, rdb)
// 初始化Admin服务
adminService := NewAdminService(repos)
// 初始化OAuth服务(GatewayService依赖)
oauthService := NewOAuthService(repos.Proxy)
// 初始化限流服务
rateLimitService := NewRateLimitService(repos, cfg)
// 初始化计费缓存服务
billingCacheService := NewBillingCacheService(rdb, repos.User, repos.UserSubscription)
// 初始化账号使用量服务
accountUsageService := NewAccountUsageService(repos, oauthService)
// 初始化账号测试服务
accountTestService := NewAccountTestService(repos, oauthService)
// 初始化身份指纹服务
identityService := NewIdentityService(rdb)
// 初始化Gateway服务
gatewayService := NewGatewayService(repos, rdb, cfg, oauthService, billingService, rateLimitService, billingCacheService, identityService)
// 初始化设置服务
settingService := NewSettingService(repos.Setting, cfg)
emailService := NewEmailService(repos.Setting, rdb)
// 初始化邮件队列服务
emailQueueService := NewEmailQueueService(emailService, 3)
// 初始化Turnstile服务
turnstileService := NewTurnstileService(settingService)
// 设置Auth服务的依赖(用于注册开关和邮件验证)
authService.SetSettingService(settingService)
authService.SetEmailService(emailService)
authService.SetTurnstileService(turnstileService)
authService.SetEmailQueueService(emailQueueService)
// 初始化并发控制服务
concurrencyService := NewConcurrencyService(rdb)
// 注入计费缓存服务到需要失效缓存的服务
redeemService.SetBillingCacheService(billingCacheService)
subscriptionService.SetBillingCacheService(billingCacheService)
SetAdminServiceBillingCache(adminService, billingCacheService)
return &Services{
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OAuth: oauthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
}
}
package service
import (
"context"
"errors"
"fmt"
"strconv"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"gorm.io/gorm"
)
var (
ErrRegistrationDisabled = errors.New("registration is currently disabled")
)
// SettingService 系统设置服务
type SettingService struct {
settingRepo *repository.SettingRepository
cfg *config.Config
}
// NewSettingService 创建系统设置服务实例
func NewSettingService(settingRepo *repository.SettingRepository, cfg *config.Config) *SettingService {
return &SettingService{
settingRepo: settingRepo,
cfg: cfg,
}
}
// GetAllSettings 获取所有系统设置
func (s *SettingService) GetAllSettings(ctx context.Context) (*model.SystemSettings, error) {
settings, err := s.settingRepo.GetAll(ctx)
if err != nil {
return nil, fmt.Errorf("get all settings: %w", err)
}
return s.parseSettings(settings), nil
}
// GetPublicSettings 获取公开设置(无需登录)
func (s *SettingService) GetPublicSettings(ctx context.Context) (*model.PublicSettings, error) {
keys := []string{
model.SettingKeyRegistrationEnabled,
model.SettingKeyEmailVerifyEnabled,
model.SettingKeyTurnstileEnabled,
model.SettingKeyTurnstileSiteKey,
model.SettingKeySiteName,
model.SettingKeySiteLogo,
model.SettingKeySiteSubtitle,
model.SettingKeyApiBaseUrl,
model.SettingKeyContactInfo,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get public settings: %w", err)
}
return &model.PublicSettings{
RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true",
TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"),
SiteLogo: settings[model.SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[model.SettingKeyApiBaseUrl],
ContactInfo: settings[model.SettingKeyContactInfo],
}, nil
}
// UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *model.SystemSettings) error {
updates := make(map[string]string)
// 注册设置
updates[model.SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
updates[model.SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[model.SettingKeySmtpHost] = settings.SmtpHost
updates[model.SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
updates[model.SettingKeySmtpUsername] = settings.SmtpUsername
if settings.SmtpPassword != "" {
updates[model.SettingKeySmtpPassword] = settings.SmtpPassword
}
updates[model.SettingKeySmtpFrom] = settings.SmtpFrom
updates[model.SettingKeySmtpFromName] = settings.SmtpFromName
updates[model.SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
updates[model.SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
updates[model.SettingKeyTurnstileSiteKey] = settings.TurnstileSiteKey
if settings.TurnstileSecretKey != "" {
updates[model.SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
}
// OEM设置
updates[model.SettingKeySiteName] = settings.SiteName
updates[model.SettingKeySiteLogo] = settings.SiteLogo
updates[model.SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[model.SettingKeyApiBaseUrl] = settings.ApiBaseUrl
updates[model.SettingKeyContactInfo] = settings.ContactInfo
// 默认配置
updates[model.SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[model.SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
return s.settingRepo.SetMultiple(ctx, updates)
}
// IsRegistrationEnabled 检查是否开放注册
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled)
if err != nil {
// 默认开放注册
return true
}
return value == "true"
}
// IsEmailVerifyEnabled 检查是否开启邮件验证
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyEmailVerifyEnabled)
if err != nil {
return false
}
return value == "true"
}
// GetSiteName 获取网站名称
func (s *SettingService) GetSiteName(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeySiteName)
if err != nil || value == "" {
return "Sub2API"
}
return value
}
// GetDefaultConcurrency 获取默认并发量
func (s *SettingService) GetDefaultConcurrency(ctx context.Context) int {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultConcurrency)
if err != nil {
return s.cfg.Default.UserConcurrency
}
if v, err := strconv.Atoi(value); err == nil && v > 0 {
return v
}
return s.cfg.Default.UserConcurrency
}
// GetDefaultBalance 获取默认余额
func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultBalance)
if err != nil {
return s.cfg.Default.UserBalance
}
if v, err := strconv.ParseFloat(value, 64); err == nil && v >= 0 {
return v
}
return s.cfg.Default.UserBalance
}
// InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置
_, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled)
if err == nil {
// 已有设置,不需要初始化
return nil
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("check existing settings: %w", err)
}
// 初始化默认设置
defaults := map[string]string{
model.SettingKeyRegistrationEnabled: "true",
model.SettingKeyEmailVerifyEnabled: "false",
model.SettingKeySiteName: "Sub2API",
model.SettingKeySiteLogo: "",
model.SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
model.SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
model.SettingKeySmtpPort: "587",
model.SettingKeySmtpUseTLS: "false",
}
return s.settingRepo.SetMultiple(ctx, defaults)
}
// parseSettings 解析设置到结构体
func (s *SettingService) parseSettings(settings map[string]string) *model.SystemSettings {
result := &model.SystemSettings{
RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true",
SmtpHost: settings[model.SettingKeySmtpHost],
SmtpUsername: settings[model.SettingKeySmtpUsername],
SmtpFrom: settings[model.SettingKeySmtpFrom],
SmtpFromName: settings[model.SettingKeySmtpFromName],
SmtpUseTLS: settings[model.SettingKeySmtpUseTLS] == "true",
TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"),
SiteLogo: settings[model.SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[model.SettingKeyApiBaseUrl],
ContactInfo: settings[model.SettingKeyContactInfo],
}
// 解析整数类型
if port, err := strconv.Atoi(settings[model.SettingKeySmtpPort]); err == nil {
result.SmtpPort = port
} else {
result.SmtpPort = 587
}
if concurrency, err := strconv.Atoi(settings[model.SettingKeyDefaultConcurrency]); err == nil {
result.DefaultConcurrency = concurrency
} else {
result.DefaultConcurrency = s.cfg.Default.UserConcurrency
}
// 解析浮点数类型
if balance, err := strconv.ParseFloat(settings[model.SettingKeyDefaultBalance], 64); err == nil {
result.DefaultBalance = balance
} else {
result.DefaultBalance = s.cfg.Default.UserBalance
}
// 敏感信息直接返回,方便测试连接时使用
result.SmtpPassword = settings[model.SettingKeySmtpPassword]
result.TurnstileSecretKey = settings[model.SettingKeyTurnstileSecretKey]
return result
}
// getStringOrDefault 获取字符串值或默认值
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
if value, ok := settings[key]; ok && value != "" {
return value
}
return defaultValue
}
// IsTurnstileEnabled 检查是否启用 Turnstile 验证
func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileEnabled)
if err != nil {
return false
}
return value == "true"
}
// GetTurnstileSecretKey 获取 Turnstile Secret Key
func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileSecretKey)
if err != nil {
return ""
}
return value
}
package service
import (
"context"
"errors"
"fmt"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
)
var (
ErrSubscriptionNotFound = errors.New("subscription not found")
ErrSubscriptionExpired = errors.New("subscription has expired")
ErrSubscriptionSuspended = errors.New("subscription is suspended")
ErrSubscriptionAlreadyExists = errors.New("subscription already exists for this user and group")
ErrGroupNotSubscriptionType = errors.New("group is not a subscription type")
ErrDailyLimitExceeded = errors.New("daily usage limit exceeded")
ErrWeeklyLimitExceeded = errors.New("weekly usage limit exceeded")
ErrMonthlyLimitExceeded = errors.New("monthly usage limit exceeded")
)
// SubscriptionService 订阅服务
type SubscriptionService struct {
repos *repository.Repositories
billingCacheService *BillingCacheService
}
// NewSubscriptionService 创建订阅服务
func NewSubscriptionService(repos *repository.Repositories) *SubscriptionService {
return &SubscriptionService{repos: repos}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
func (s *SubscriptionService) SetBillingCacheService(billingCacheService *BillingCacheService) {
s.billingCacheService = billingCacheService
}
// AssignSubscriptionInput 分配订阅输入
type AssignSubscriptionInput struct {
UserID int64
GroupID int64
ValidityDays int
AssignedBy int64
Notes string
}
// AssignSubscription 分配订阅给用户(不允许重复分配)
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
// 检查分组是否存在且为订阅类型
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
if err != nil {
return nil, fmt.Errorf("group not found: %w", err)
}
if !group.IsSubscriptionType() {
return nil, ErrGroupNotSubscriptionType
}
// 检查是否已存在订阅
exists, err := s.repos.UserSubscription.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if err != nil {
return nil, err
}
if exists {
return nil, ErrSubscriptionAlreadyExists
}
sub, err := s.createSubscription(ctx, input)
if err != nil {
return nil, err
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return sub, nil
}
// AssignOrExtendSubscription 分配或续期订阅(用于兑换码等场景)
// 如果用户已有同分组的订阅:
// - 未过期:从当前过期时间累加天数
// - 已过期:从当前时间开始计算新的过期时间,并激活订阅
// 如果没有订阅:创建新订阅
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
// 检查分组是否存在且为订阅类型
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
if err != nil {
return nil, false, fmt.Errorf("group not found: %w", err)
}
if !group.IsSubscriptionType() {
return nil, false, ErrGroupNotSubscriptionType
}
// 查询是否已有订阅
existingSub, err := s.repos.UserSubscription.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if err != nil {
// 不存在记录是正常情况,其他错误需要返回
existingSub = nil
}
validityDays := input.ValidityDays
if validityDays <= 0 {
validityDays = 30
}
// 已有订阅,执行续期
if existingSub != nil {
now := time.Now()
var newExpiresAt time.Time
if existingSub.ExpiresAt.After(now) {
// 未过期:从当前过期时间累加
newExpiresAt = existingSub.ExpiresAt.AddDate(0, 0, validityDays)
} else {
// 已过期:从当前时间开始计算
newExpiresAt = now.AddDate(0, 0, validityDays)
}
// 更新过期时间
if err := s.repos.UserSubscription.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
return nil, false, fmt.Errorf("extend subscription: %w", err)
}
// 如果订阅已过期或被暂停,恢复为active状态
if existingSub.Status != model.SubscriptionStatusActive {
if err := s.repos.UserSubscription.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
return nil, false, fmt.Errorf("update subscription status: %w", err)
}
}
// 追加备注
if input.Notes != "" {
newNotes := existingSub.Notes
if newNotes != "" {
newNotes += "\n"
}
newNotes += input.Notes
if err := s.repos.UserSubscription.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
// 备注更新失败不影响主流程
}
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
// 返回更新后的订阅
sub, err := s.repos.UserSubscription.GetByID(ctx, existingSub.ID)
return sub, true, err // true 表示是续期
}
// 没有订阅,创建新订阅
sub, err := s.createSubscription(ctx, input)
if err != nil {
return nil, false, err
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return sub, false, nil // false 表示是新建
}
// createSubscription 创建新订阅(内部方法)
func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
validityDays := input.ValidityDays
if validityDays <= 0 {
validityDays = 30
}
now := time.Now()
sub := &model.UserSubscription{
UserID: input.UserID,
GroupID: input.GroupID,
StartsAt: now,
ExpiresAt: now.AddDate(0, 0, validityDays),
Status: model.SubscriptionStatusActive,
AssignedAt: now,
Notes: input.Notes,
CreatedAt: now,
UpdatedAt: now,
}
// 只有当 AssignedBy > 0 时才设置(0 表示系统分配,如兑换码)
if input.AssignedBy > 0 {
sub.AssignedBy = &input.AssignedBy
}
if err := s.repos.UserSubscription.Create(ctx, sub); err != nil {
return nil, err
}
// 重新获取完整订阅信息(包含关联)
return s.repos.UserSubscription.GetByID(ctx, sub.ID)
}
// BulkAssignSubscriptionInput 批量分配订阅输入
type BulkAssignSubscriptionInput struct {
UserIDs []int64
GroupID int64
ValidityDays int
AssignedBy int64
Notes string
}
// BulkAssignResult 批量分配结果
type BulkAssignResult struct {
SuccessCount int
FailedCount int
Subscriptions []model.UserSubscription
Errors []string
}
// BulkAssignSubscription 批量分配订阅
func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) {
result := &BulkAssignResult{
Subscriptions: make([]model.UserSubscription, 0),
Errors: make([]string, 0),
}
for _, userID := range input.UserIDs {
sub, err := s.AssignSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: input.GroupID,
ValidityDays: input.ValidityDays,
AssignedBy: input.AssignedBy,
Notes: input.Notes,
})
if err != nil {
result.FailedCount++
result.Errors = append(result.Errors, fmt.Sprintf("user %d: %v", userID, err))
} else {
result.SuccessCount++
result.Subscriptions = append(result.Subscriptions, *sub)
}
}
return result, nil
}
// RevokeSubscription 撤销订阅
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
// 先获取订阅信息用于失效缓存
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
if err != nil {
return err
}
if err := s.repos.UserSubscription.Delete(ctx, subscriptionID); err != nil {
return err
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return nil
}
// ExtendSubscription 延长订阅
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) {
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
// 计算新的过期时间
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
if err := s.repos.UserSubscription.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
return nil, err
}
// 如果订阅已过期,恢复为active状态
if sub.Status == model.SubscriptionStatusExpired {
if err := s.repos.UserSubscription.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
return nil, err
}
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return s.repos.UserSubscription.GetByID(ctx, subscriptionID)
}
// GetByID 根据ID获取订阅
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
return s.repos.UserSubscription.GetByID(ctx, id)
}
// GetActiveSubscription 获取用户对特定分组的有效订阅
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
sub, err := s.repos.UserSubscription.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
return sub, nil
}
// ListUserSubscriptions 获取用户的所有订阅
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
return s.repos.UserSubscription.ListByUserID(ctx, userID)
}
// ListActiveUserSubscriptions 获取用户的所有有效订阅
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
return s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
}
// ListGroupSubscriptions 获取分组的所有订阅
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *repository.PaginationResult, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
return s.repos.UserSubscription.ListByGroupID(ctx, groupID, params)
}
// List 获取所有订阅(分页,支持筛选)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *repository.PaginationResult, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
return s.repos.UserSubscription.List(ctx, params, userID, groupID, status)
}
// CheckAndActivateWindow 检查并激活窗口(首次使用时)
func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *model.UserSubscription) error {
if sub.IsWindowActivated() {
return nil
}
now := time.Now()
return s.repos.UserSubscription.ActivateWindows(ctx, sub.ID, now)
}
// CheckAndResetWindows 检查并重置过期的窗口
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error {
now := time.Now()
// 日窗口重置(24小时)
if sub.NeedsDailyReset() {
if err := s.repos.UserSubscription.ResetDailyUsage(ctx, sub.ID, now); err != nil {
return err
}
sub.DailyWindowStart = &now
sub.DailyUsageUSD = 0
}
// 周窗口重置(7天)
if sub.NeedsWeeklyReset() {
if err := s.repos.UserSubscription.ResetWeeklyUsage(ctx, sub.ID, now); err != nil {
return err
}
sub.WeeklyWindowStart = &now
sub.WeeklyUsageUSD = 0
}
// 月窗口重置(30天)
if sub.NeedsMonthlyReset() {
if err := s.repos.UserSubscription.ResetMonthlyUsage(ctx, sub.ID, now); err != nil {
return err
}
sub.MonthlyWindowStart = &now
sub.MonthlyUsageUSD = 0
}
return nil
}
// CheckUsageLimits 检查使用限额(返回错误如果超限)
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.UserSubscription, group *model.Group, additionalCost float64) error {
if !sub.CheckDailyLimit(group, additionalCost) {
return ErrDailyLimitExceeded
}
if !sub.CheckWeeklyLimit(group, additionalCost) {
return ErrWeeklyLimitExceeded
}
if !sub.CheckMonthlyLimit(group, additionalCost) {
return ErrMonthlyLimitExceeded
}
return nil
}
// RecordUsage 记录使用量到订阅
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
return s.repos.UserSubscription.IncrementUsage(ctx, subscriptionID, costUSD)
}
// SubscriptionProgress 订阅进度
type SubscriptionProgress struct {
ID int64 `json:"id"`
GroupName string `json:"group_name"`
ExpiresAt time.Time `json:"expires_at"`
ExpiresInDays int `json:"expires_in_days"`
Daily *UsageWindowProgress `json:"daily,omitempty"`
Weekly *UsageWindowProgress `json:"weekly,omitempty"`
Monthly *UsageWindowProgress `json:"monthly,omitempty"`
}
// UsageWindowProgress 使用窗口进度
type UsageWindowProgress struct {
LimitUSD float64 `json:"limit_usd"`
UsedUSD float64 `json:"used_usd"`
RemainingUSD float64 `json:"remaining_usd"`
Percentage float64 `json:"percentage"`
WindowStart time.Time `json:"window_start"`
ResetsAt time.Time `json:"resets_at"`
ResetsInSeconds int64 `json:"resets_in_seconds"`
}
// GetSubscriptionProgress 获取订阅使用进度
func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subscriptionID int64) (*SubscriptionProgress, error) {
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
group := sub.Group
if group == nil {
group, err = s.repos.Group.GetByID(ctx, sub.GroupID)
if err != nil {
return nil, err
}
}
progress := &SubscriptionProgress{
ID: sub.ID,
GroupName: group.Name,
ExpiresAt: sub.ExpiresAt,
ExpiresInDays: sub.DaysRemaining(),
}
// 日进度
if group.HasDailyLimit() && sub.DailyWindowStart != nil {
limit := *group.DailyLimitUSD
resetsAt := sub.DailyWindowStart.Add(24 * time.Hour)
progress.Daily = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.DailyUsageUSD,
RemainingUSD: limit - sub.DailyUsageUSD,
Percentage: (sub.DailyUsageUSD / limit) * 100,
WindowStart: *sub.DailyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Daily.RemainingUSD < 0 {
progress.Daily.RemainingUSD = 0
}
if progress.Daily.Percentage > 100 {
progress.Daily.Percentage = 100
}
if progress.Daily.ResetsInSeconds < 0 {
progress.Daily.ResetsInSeconds = 0
}
}
// 周进度
if group.HasWeeklyLimit() && sub.WeeklyWindowStart != nil {
limit := *group.WeeklyLimitUSD
resetsAt := sub.WeeklyWindowStart.Add(7 * 24 * time.Hour)
progress.Weekly = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.WeeklyUsageUSD,
RemainingUSD: limit - sub.WeeklyUsageUSD,
Percentage: (sub.WeeklyUsageUSD / limit) * 100,
WindowStart: *sub.WeeklyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Weekly.RemainingUSD < 0 {
progress.Weekly.RemainingUSD = 0
}
if progress.Weekly.Percentage > 100 {
progress.Weekly.Percentage = 100
}
if progress.Weekly.ResetsInSeconds < 0 {
progress.Weekly.ResetsInSeconds = 0
}
}
// 月进度
if group.HasMonthlyLimit() && sub.MonthlyWindowStart != nil {
limit := *group.MonthlyLimitUSD
resetsAt := sub.MonthlyWindowStart.Add(30 * 24 * time.Hour)
progress.Monthly = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.MonthlyUsageUSD,
RemainingUSD: limit - sub.MonthlyUsageUSD,
Percentage: (sub.MonthlyUsageUSD / limit) * 100,
WindowStart: *sub.MonthlyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Monthly.RemainingUSD < 0 {
progress.Monthly.RemainingUSD = 0
}
if progress.Monthly.Percentage > 100 {
progress.Monthly.Percentage = 100
}
if progress.Monthly.ResetsInSeconds < 0 {
progress.Monthly.ResetsInSeconds = 0
}
}
return progress, nil
}
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
subs, err := s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, err
}
progresses := make([]SubscriptionProgress, 0, len(subs))
for _, sub := range subs {
progress, err := s.GetSubscriptionProgress(ctx, sub.ID)
if err != nil {
continue
}
progresses = append(progresses, *progress)
}
return progresses, nil
}
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) {
return s.repos.UserSubscription.BatchUpdateExpiredStatus(ctx)
}
// ValidateSubscription 验证订阅是否有效
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *model.UserSubscription) error {
if sub.Status == model.SubscriptionStatusExpired {
return ErrSubscriptionExpired
}
if sub.Status == model.SubscriptionStatusSuspended {
return ErrSubscriptionSuspended
}
if sub.IsExpired() {
// 更新状态
_ = s.repos.UserSubscription.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
return ErrSubscriptionExpired
}
return nil
}
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
)
var (
ErrTurnstileVerificationFailed = errors.New("turnstile verification failed")
ErrTurnstileNotConfigured = errors.New("turnstile not configured")
)
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
// TurnstileService Turnstile 验证服务
type TurnstileService struct {
settingService *SettingService
httpClient *http.Client
}
// TurnstileVerifyResponse Cloudflare Turnstile 验证响应
type TurnstileVerifyResponse struct {
Success bool `json:"success"`
ChallengeTS string `json:"challenge_ts"`
Hostname string `json:"hostname"`
ErrorCodes []string `json:"error-codes"`
Action string `json:"action"`
CData string `json:"cdata"`
}
// NewTurnstileService 创建 Turnstile 服务实例
func NewTurnstileService(settingService *SettingService) *TurnstileService {
return &TurnstileService{
settingService: settingService,
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// VerifyToken 验证 Turnstile token
func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remoteIP string) error {
// 检查是否启用 Turnstile
if !s.settingService.IsTurnstileEnabled(ctx) {
log.Println("[Turnstile] Disabled, skipping verification")
return nil
}
// 获取 Secret Key
secretKey := s.settingService.GetTurnstileSecretKey(ctx)
if secretKey == "" {
log.Println("[Turnstile] Secret key not configured")
return ErrTurnstileNotConfigured
}
// 如果 token 为空,返回错误
if token == "" {
log.Println("[Turnstile] Token is empty")
return ErrTurnstileVerificationFailed
}
// 构建请求
formData := url.Values{}
formData.Set("secret", secretKey)
formData.Set("response", token)
if remoteIP != "" {
formData.Set("remoteip", remoteIP)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
if err != nil {
return fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// 发送请求
log.Printf("[Turnstile] Verifying token for IP: %s", remoteIP)
resp, err := s.httpClient.Do(req)
if err != nil {
log.Printf("[Turnstile] Request failed: %v", err)
return fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
// 解析响应
var result TurnstileVerifyResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
log.Printf("[Turnstile] Failed to decode response: %v", err)
return fmt.Errorf("decode response: %w", err)
}
if !result.Success {
log.Printf("[Turnstile] Verification failed, error codes: %v", result.ErrorCodes)
return ErrTurnstileVerificationFailed
}
log.Println("[Turnstile] Verification successful")
return nil
}
// IsEnabled 检查 Turnstile 是否启用
func (s *TurnstileService) IsEnabled(ctx context.Context) bool {
return s.settingService.IsTurnstileEnabled(ctx)
}
package service
import (
"archive/tar"
"bufio"
"compress/gzip"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/redis/go-redis/v9"
)
const (
updateCacheKey = "update_check_cache"
updateCacheTTL = 1200 // 20 minutes
githubRepo = "Wei-Shaw/sub2api"
// Security: allowed download domains for updates
allowedDownloadHost = "github.com"
allowedAssetHost = "objects.githubusercontent.com"
// Security: max download size (500MB)
maxDownloadSize = 500 * 1024 * 1024
)
// UpdateService handles software updates
type UpdateService struct {
rdb *redis.Client
currentVersion string
buildType string // "source" for manual builds, "release" for CI builds
}
// NewUpdateService creates a new UpdateService
func NewUpdateService(rdb *redis.Client, version, buildType string) *UpdateService {
return &UpdateService{
rdb: rdb,
currentVersion: version,
buildType: buildType,
}
}
// UpdateInfo contains update information
type UpdateInfo struct {
CurrentVersion string `json:"current_version"`
LatestVersion string `json:"latest_version"`
HasUpdate bool `json:"has_update"`
ReleaseInfo *ReleaseInfo `json:"release_info,omitempty"`
Cached bool `json:"cached"`
Warning string `json:"warning,omitempty"`
BuildType string `json:"build_type"` // "source" or "release"
}
// ReleaseInfo contains GitHub release details
type ReleaseInfo struct {
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HtmlURL string `json:"html_url"`
Assets []Asset `json:"assets,omitempty"`
}
// Asset represents a release asset
type Asset struct {
Name string `json:"name"`
DownloadURL string `json:"download_url"`
Size int64 `json:"size"`
}
// GitHubRelease represents GitHub API response
type GitHubRelease struct {
TagName string `json:"tag_name"`
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HtmlUrl string `json:"html_url"`
Assets []GitHubAsset `json:"assets"`
}
type GitHubAsset struct {
Name string `json:"name"`
BrowserDownloadUrl string `json:"browser_download_url"`
Size int64 `json:"size"`
}
// CheckUpdate checks for available updates
func (s *UpdateService) CheckUpdate(ctx context.Context, force bool) (*UpdateInfo, error) {
// Try cache first
if !force {
if cached, err := s.getFromCache(ctx); err == nil && cached != nil {
return cached, nil
}
}
// Fetch from GitHub
info, err := s.fetchLatestRelease(ctx)
if err != nil {
// Return cached on error
if cached, cacheErr := s.getFromCache(ctx); cacheErr == nil && cached != nil {
cached.Warning = "Using cached data: " + err.Error()
return cached, nil
}
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: s.currentVersion,
HasUpdate: false,
Warning: err.Error(),
BuildType: s.buildType,
}, nil
}
// Cache result
s.saveToCache(ctx, info)
return info, nil
}
// PerformUpdate downloads and applies the update
func (s *UpdateService) PerformUpdate(ctx context.Context) error {
info, err := s.CheckUpdate(ctx, true)
if err != nil {
return err
}
if !info.HasUpdate {
return fmt.Errorf("no update available")
}
// Find matching archive and checksum for current platform
archiveName := s.getArchiveName()
var downloadURL string
var checksumURL string
for _, asset := range info.ReleaseInfo.Assets {
if strings.Contains(asset.Name, archiveName) && !strings.HasSuffix(asset.Name, ".txt") {
downloadURL = asset.DownloadURL
}
if asset.Name == "checksums.txt" {
checksumURL = asset.DownloadURL
}
}
if downloadURL == "" {
return fmt.Errorf("no compatible release found for %s/%s", runtime.GOOS, runtime.GOARCH)
}
// SECURITY: Validate download URL is from trusted domain
if err := validateDownloadURL(downloadURL); err != nil {
return fmt.Errorf("invalid download URL: %w", err)
}
if checksumURL != "" {
if err := validateDownloadURL(checksumURL); err != nil {
return fmt.Errorf("invalid checksum URL: %w", err)
}
}
// Get current executable path
exePath, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to get executable path: %w", err)
}
exePath, err = filepath.EvalSymlinks(exePath)
if err != nil {
return fmt.Errorf("failed to resolve symlinks: %w", err)
}
// Create temp directory for extraction
tempDir, err := os.MkdirTemp("", "sub2api-update-*")
if err != nil {
return fmt.Errorf("failed to create temp dir: %w", err)
}
defer os.RemoveAll(tempDir)
// Download archive
archivePath := filepath.Join(tempDir, filepath.Base(downloadURL))
if err := s.downloadFile(ctx, downloadURL, archivePath); err != nil {
return fmt.Errorf("download failed: %w", err)
}
// Verify checksum if available
if checksumURL != "" {
if err := s.verifyChecksum(ctx, archivePath, checksumURL); err != nil {
return fmt.Errorf("checksum verification failed: %w", err)
}
}
// Extract binary from archive
newBinaryPath := filepath.Join(tempDir, "sub2api")
if err := s.extractBinary(archivePath, newBinaryPath); err != nil {
return fmt.Errorf("extraction failed: %w", err)
}
// Backup current binary
backupFile := exePath + ".backup"
if err := os.Rename(exePath, backupFile); err != nil {
return fmt.Errorf("backup failed: %w", err)
}
// Replace with new binary
if err := copyFile(newBinaryPath, exePath); err != nil {
os.Rename(backupFile, exePath)
return fmt.Errorf("replace failed: %w", err)
}
// Make executable
if err := os.Chmod(exePath, 0755); err != nil {
return fmt.Errorf("chmod failed: %w", err)
}
return nil
}
// Rollback restores the previous version
func (s *UpdateService) Rollback() error {
exePath, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to get executable path: %w", err)
}
exePath, err = filepath.EvalSymlinks(exePath)
if err != nil {
return fmt.Errorf("failed to resolve symlinks: %w", err)
}
backupFile := exePath + ".backup"
if _, err := os.Stat(backupFile); os.IsNotExist(err) {
return fmt.Errorf("no backup found")
}
// Replace current with backup
if err := os.Rename(backupFile, exePath); err != nil {
return fmt.Errorf("rollback failed: %w", err)
}
return nil
}
// RestartService triggers a service restart via systemd
func (s *UpdateService) RestartService() error {
if runtime.GOOS != "linux" {
return fmt.Errorf("systemd restart only available on Linux")
}
// Try direct systemctl first (works if running as root or with proper permissions)
cmd := exec.Command("systemctl", "restart", "sub2api")
if err := cmd.Run(); err != nil {
// Try with sudo (requires NOPASSWD sudoers entry)
sudoCmd := exec.Command("sudo", "systemctl", "restart", "sub2api")
if sudoErr := sudoCmd.Run(); sudoErr != nil {
return fmt.Errorf("systemctl restart failed: %w (sudo also failed: %v)", err, sudoErr)
}
}
return nil
}
func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", githubRepo)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "Sub2API-Updater")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: s.currentVersion,
HasUpdate: false,
Warning: "No releases found",
BuildType: s.buildType,
}, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
}
var release GitHubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, err
}
latestVersion := strings.TrimPrefix(release.TagName, "v")
assets := make([]Asset, len(release.Assets))
for i, a := range release.Assets {
assets[i] = Asset{
Name: a.Name,
DownloadURL: a.BrowserDownloadUrl,
Size: a.Size,
}
}
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: latestVersion,
HasUpdate: compareVersions(s.currentVersion, latestVersion) < 0,
ReleaseInfo: &ReleaseInfo{
Name: release.Name,
Body: release.Body,
PublishedAt: release.PublishedAt,
HtmlURL: release.HtmlUrl,
Assets: assets,
},
Cached: false,
BuildType: s.buildType,
}, nil
}
func (s *UpdateService) downloadFile(ctx context.Context, downloadURL, dest string) error {
req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil)
if err != nil {
return err
}
client := &http.Client{Timeout: 10 * time.Minute}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download returned %d", resp.StatusCode)
}
// SECURITY: Check Content-Length if available
if resp.ContentLength > maxDownloadSize {
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxDownloadSize)
}
out, err := os.Create(dest)
if err != nil {
return err
}
defer out.Close()
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
limited := io.LimitReader(resp.Body, maxDownloadSize+1)
written, err := io.Copy(out, limited)
if err != nil {
return err
}
// Check if we hit the limit (downloaded more than maxDownloadSize)
if written > maxDownloadSize {
os.Remove(dest) // Clean up partial file
return fmt.Errorf("download exceeded maximum size of %d bytes", maxDownloadSize)
}
return nil
}
func (s *UpdateService) getArchiveName() string {
osName := runtime.GOOS
arch := runtime.GOARCH
return fmt.Sprintf("%s_%s", osName, arch)
}
// validateDownloadURL checks if the URL is from an allowed domain
// SECURITY: This prevents SSRF and ensures downloads only come from trusted GitHub domains
func validateDownloadURL(rawURL string) error {
parsedURL, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("invalid URL: %w", err)
}
// Must be HTTPS
if parsedURL.Scheme != "https" {
return fmt.Errorf("only HTTPS URLs are allowed")
}
// Check against allowed hosts
host := parsedURL.Host
// GitHub release URLs can be from github.com or objects.githubusercontent.com
if host != allowedDownloadHost &&
!strings.HasSuffix(host, "."+allowedDownloadHost) &&
host != allowedAssetHost &&
!strings.HasSuffix(host, "."+allowedAssetHost) {
return fmt.Errorf("download from untrusted host: %s", host)
}
return nil
}
func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumURL string) error {
// Download checksums file
req, err := http.NewRequestWithContext(ctx, "GET", checksumURL, nil)
if err != nil {
return err
}
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to download checksums: %d", resp.StatusCode)
}
// Calculate file hash
f, err := os.Open(filePath)
if err != nil {
return err
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return err
}
actualHash := hex.EncodeToString(h.Sum(nil))
// Find expected hash in checksums file
fileName := filepath.Base(filePath)
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
parts := strings.Fields(line)
if len(parts) == 2 && parts[1] == fileName {
if parts[0] == actualHash {
return nil
}
return fmt.Errorf("checksum mismatch: expected %s, got %s", parts[0], actualHash)
}
}
return fmt.Errorf("checksum not found for %s", fileName)
}
func (s *UpdateService) extractBinary(archivePath, destPath string) error {
f, err := os.Open(archivePath)
if err != nil {
return err
}
defer f.Close()
var reader io.Reader = f
// Handle gzip compression
if strings.HasSuffix(archivePath, ".gz") || strings.HasSuffix(archivePath, ".tar.gz") || strings.HasSuffix(archivePath, ".tgz") {
gzr, err := gzip.NewReader(f)
if err != nil {
return err
}
defer gzr.Close()
reader = gzr
}
// Handle tar archive
if strings.Contains(archivePath, ".tar") {
tr := tar.NewReader(reader)
for {
hdr, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return err
}
// SECURITY: Prevent Zip Slip / Path Traversal attack
// Only allow files with safe base names, no directory traversal
baseName := filepath.Base(hdr.Name)
// Check for path traversal attempts
if strings.Contains(hdr.Name, "..") {
return fmt.Errorf("path traversal attempt detected: %s", hdr.Name)
}
// Validate the entry is a regular file
if hdr.Typeflag != tar.TypeReg {
continue // Skip directories and special files
}
// Only extract the specific binary we need
if baseName == "sub2api" || baseName == "sub2api.exe" {
// Additional security: limit file size (max 500MB)
const maxBinarySize = 500 * 1024 * 1024
if hdr.Size > maxBinarySize {
return fmt.Errorf("binary too large: %d bytes (max %d)", hdr.Size, maxBinarySize)
}
out, err := os.Create(destPath)
if err != nil {
return err
}
// Use LimitReader to prevent decompression bombs
limited := io.LimitReader(tr, maxBinarySize)
if _, err := io.Copy(out, limited); err != nil {
out.Close()
return err
}
out.Close()
return nil
}
}
return fmt.Errorf("binary not found in archive")
}
// Direct copy for non-tar files (with size limit)
const maxBinarySize = 500 * 1024 * 1024
out, err := os.Create(destPath)
if err != nil {
return err
}
defer out.Close()
limited := io.LimitReader(reader, maxBinarySize)
_, err = io.Copy(out, limited)
return err
}
func copyFile(src, dst string) error {
in, err := os.Open(src)
if err != nil {
return err
}
defer in.Close()
out, err := os.Create(dst)
if err != nil {
return err
}
defer out.Close()
_, err = io.Copy(out, in)
return err
}
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
data, err := s.rdb.Get(ctx, updateCacheKey).Result()
if err != nil {
return nil, err
}
var cached struct {
Latest string `json:"latest"`
ReleaseInfo *ReleaseInfo `json:"release_info"`
Timestamp int64 `json:"timestamp"`
}
if err := json.Unmarshal([]byte(data), &cached); err != nil {
return nil, err
}
if time.Now().Unix()-cached.Timestamp > updateCacheTTL {
return nil, fmt.Errorf("cache expired")
}
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: cached.Latest,
HasUpdate: compareVersions(s.currentVersion, cached.Latest) < 0,
ReleaseInfo: cached.ReleaseInfo,
Cached: true,
BuildType: s.buildType,
}, nil
}
func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
cacheData := struct {
Latest string `json:"latest"`
ReleaseInfo *ReleaseInfo `json:"release_info"`
Timestamp int64 `json:"timestamp"`
}{
Latest: info.LatestVersion,
ReleaseInfo: info.ReleaseInfo,
Timestamp: time.Now().Unix(),
}
data, _ := json.Marshal(cacheData)
s.rdb.Set(ctx, updateCacheKey, data, time.Duration(updateCacheTTL)*time.Second)
}
// compareVersions compares two semantic versions
func compareVersions(current, latest string) int {
currentParts := parseVersion(current)
latestParts := parseVersion(latest)
for i := 0; i < 3; i++ {
if currentParts[i] < latestParts[i] {
return -1
}
if currentParts[i] > latestParts[i] {
return 1
}
}
return 0
}
func parseVersion(v string) [3]int {
v = strings.TrimPrefix(v, "v")
parts := strings.Split(v, ".")
result := [3]int{0, 0, 0}
for i := 0; i < len(parts) && i < 3; i++ {
fmt.Sscanf(parts[i], "%d", &result[i])
}
return result
}
package service
import (
"context"
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"time"
"gorm.io/gorm"
)
var (
ErrUsageLogNotFound = errors.New("usage log not found")
)
// CreateUsageLogRequest 创建使用日志请求
type CreateUsageLogRequest struct {
UserID int64 `json:"user_id"`
ApiKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationTokens int `json:"cache_creation_tokens"`
CacheReadTokens int `json:"cache_read_tokens"`
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
Stream bool `json:"stream"`
DurationMs *int `json:"duration_ms"`
}
// UsageStats 使用统计
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
// UsageService 使用统计服务
type UsageService struct {
usageRepo *repository.UsageLogRepository
userRepo *repository.UserRepository
}
// NewUsageService 创建使用统计服务实例
func NewUsageService(usageRepo *repository.UsageLogRepository, userRepo *repository.UserRepository) *UsageService {
return &UsageService{
usageRepo: usageRepo,
userRepo: userRepo,
}
}
// Create 创建使用日志
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*model.UsageLog, error) {
// 验证用户存在
_, err := s.userRepo.GetByID(ctx, req.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
// 创建使用日志
usageLog := &model.UsageLog{
UserID: req.UserID,
ApiKeyID: req.ApiKeyID,
AccountID: req.AccountID,
RequestID: req.RequestID,
Model: req.Model,
InputTokens: req.InputTokens,
OutputTokens: req.OutputTokens,
CacheCreationTokens: req.CacheCreationTokens,
CacheReadTokens: req.CacheReadTokens,
CacheCreation5mTokens: req.CacheCreation5mTokens,
CacheCreation1hTokens: req.CacheCreation1hTokens,
InputCost: req.InputCost,
OutputCost: req.OutputCost,
CacheCreationCost: req.CacheCreationCost,
CacheReadCost: req.CacheReadCost,
TotalCost: req.TotalCost,
ActualCost: req.ActualCost,
RateMultiplier: req.RateMultiplier,
Stream: req.Stream,
DurationMs: req.DurationMs,
}
if err := s.usageRepo.Create(ctx, usageLog); err != nil {
return nil, fmt.Errorf("create usage log: %w", err)
}
// 扣除用户余额
if req.ActualCost > 0 {
if err := s.userRepo.UpdateBalance(ctx, req.UserID, -req.ActualCost); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
}
}
return usageLog, nil
}
// GetByID 根据ID获取使用日志
func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
log, err := s.usageRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUsageLogNotFound
}
return nil, fmt.Errorf("get usage log: %w", err)
}
return log, nil
}
// ListByUser 获取用户的使用日志列表
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
}
return logs, pagination, nil
}
// ListByApiKey 获取API Key的使用日志列表
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
}
return logs, pagination, nil
}
// ListByAccount 获取账号的使用日志列表
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
}
return logs, pagination, nil
}
// GetStatsByUser 获取用户的使用统计
func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
return s.calculateStats(logs), nil
}
// GetStatsByApiKey 获取API Key的使用统计
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByApiKeyAndTimeRange(ctx, apiKeyID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
return s.calculateStats(logs), nil
}
// GetStatsByAccount 获取账号的使用统计
func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByAccountAndTimeRange(ctx, accountID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
return s.calculateStats(logs), nil
}
// GetStatsByModel 获取模型的使用统计
func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByModelAndTimeRange(ctx, modelName, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
return s.calculateStats(logs), nil
}
// GetDailyStats 获取每日使用统计(最近N天)
func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int) ([]map[string]interface{}, error) {
endTime := time.Now()
startTime := endTime.AddDate(0, 0, -days)
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
// 按日期分组统计
dailyStats := make(map[string]*UsageStats)
for _, log := range logs {
dateKey := log.CreatedAt.Format("2006-01-02")
if _, exists := dailyStats[dateKey]; !exists {
dailyStats[dateKey] = &UsageStats{}
}
stats := dailyStats[dateKey]
stats.TotalRequests++
stats.TotalInputTokens += int64(log.InputTokens)
stats.TotalOutputTokens += int64(log.OutputTokens)
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
stats.TotalTokens += int64(log.TotalTokens())
stats.TotalCost += log.TotalCost
stats.TotalActualCost += log.ActualCost
if log.DurationMs != nil {
stats.AverageDurationMs += float64(*log.DurationMs)
}
}
// 计算平均值并转换为数组
result := make([]map[string]interface{}, 0, len(dailyStats))
for date, stats := range dailyStats {
if stats.TotalRequests > 0 {
stats.AverageDurationMs /= float64(stats.TotalRequests)
}
result = append(result, map[string]interface{}{
"date": date,
"total_requests": stats.TotalRequests,
"total_input_tokens": stats.TotalInputTokens,
"total_output_tokens": stats.TotalOutputTokens,
"total_cache_tokens": stats.TotalCacheTokens,
"total_tokens": stats.TotalTokens,
"total_cost": stats.TotalCost,
"total_actual_cost": stats.TotalActualCost,
"average_duration_ms": stats.AverageDurationMs,
})
}
return result, nil
}
// calculateStats 计算统计数据
func (s *UsageService) calculateStats(logs []model.UsageLog) *UsageStats {
stats := &UsageStats{}
for _, log := range logs {
stats.TotalRequests++
stats.TotalInputTokens += int64(log.InputTokens)
stats.TotalOutputTokens += int64(log.OutputTokens)
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
stats.TotalTokens += int64(log.TotalTokens())
stats.TotalCost += log.TotalCost
stats.TotalActualCost += log.ActualCost
if log.DurationMs != nil {
stats.AverageDurationMs += float64(*log.DurationMs)
}
}
// 计算平均持续时间
if stats.TotalRequests > 0 {
stats.AverageDurationMs /= float64(stats.TotalRequests)
}
return stats
}
// Delete 删除使用日志(管理员功能,谨慎使用)
func (s *UsageService) Delete(ctx context.Context, id int64) error {
if err := s.usageRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete usage log: %w", err)
}
return nil
}
package service
import (
"context"
"errors"
"fmt"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var (
ErrUserNotFound = errors.New("user not found")
ErrPasswordIncorrect = errors.New("current password is incorrect")
ErrInsufficientPerms = errors.New("insufficient permissions")
)
// UpdateProfileRequest 更新用户资料请求
type UpdateProfileRequest struct {
Email *string `json:"email"`
Concurrency *int `json:"concurrency"`
}
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
}
// UserService 用户服务
type UserService struct {
userRepo *repository.UserRepository
cfg *config.Config
}
// NewUserService 创建用户服务实例
func NewUserService(userRepo *repository.UserRepository, cfg *config.Config) *UserService {
return &UserService{
userRepo: userRepo,
cfg: cfg,
}
}
// GetProfile 获取用户资料
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
return user, nil
}
// UpdateProfile 更新用户资料
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
// 更新字段
if req.Email != nil {
// 检查新邮箱是否已被使用
exists, err := s.userRepo.ExistsByEmail(ctx, *req.Email)
if err != nil {
return nil, fmt.Errorf("check email exists: %w", err)
}
if exists && *req.Email != user.Email {
return nil, ErrEmailExists
}
user.Email = *req.Email
}
if req.Concurrency != nil {
user.Concurrency = *req.Concurrency
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err)
}
return user, nil
}
// ChangePassword 修改密码
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return fmt.Errorf("get user: %w", err)
}
// 验证当前密码
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
return ErrPasswordIncorrect
}
// 生成新密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
user.PasswordHash = string(hashedPassword)
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}
return nil
}
// GetByID 根据ID获取用户(管理员功能)
func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
return user, nil
}
// List 获取用户列表(管理员功能)
func (s *UserService) List(ctx context.Context, params repository.PaginationParams) ([]model.User, *repository.PaginationResult, error) {
users, pagination, err := s.userRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list users: %w", err)
}
return users, pagination, nil
}
// UpdateBalance 更新用户余额(管理员功能)
func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount float64) error {
if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
return fmt.Errorf("update balance: %w", err)
}
return nil
}
// UpdateStatus 更新用户状态(管理员功能)
func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return fmt.Errorf("get user: %w", err)
}
user.Status = status
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}
return nil
}
// Delete 删除用户(管理员功能)
func (s *UserService) Delete(ctx context.Context, userID int64) error {
if err := s.userRepo.Delete(ctx, userID); err != nil {
return fmt.Errorf("delete user: %w", err)
}
return nil
}
package setup
import (
"bufio"
"fmt"
"net/mail"
"os"
"regexp"
"strconv"
"strings"
"golang.org/x/term"
)
// CLI input validation functions (matching Web API validation)
func cliValidateHostname(host string) bool {
validHost := regexp.MustCompile(`^[a-zA-Z0-9.\-:]+$`)
return validHost.MatchString(host) && len(host) <= 253
}
func cliValidateDBName(name string) bool {
validName := regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_]*$`)
return validName.MatchString(name) && len(name) <= 63
}
func cliValidateUsername(name string) bool {
validName := regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
return validName.MatchString(name) && len(name) <= 63
}
func cliValidateEmail(email string) bool {
_, err := mail.ParseAddress(email)
return err == nil && len(email) <= 254
}
func cliValidatePort(port int) bool {
return port > 0 && port <= 65535
}
func cliValidateSSLMode(mode string) bool {
validModes := map[string]bool{
"disable": true, "require": true, "verify-ca": true, "verify-full": true,
}
return validModes[mode]
}
// RunCLI runs the CLI setup wizard
func RunCLI() error {
reader := bufio.NewReader(os.Stdin)
fmt.Println()
fmt.Println("╔═══════════════════════════════════════════╗")
fmt.Println("║ Sub2API Installation Wizard ║")
fmt.Println("╚═══════════════════════════════════════════╝")
fmt.Println()
cfg := &SetupConfig{
Server: ServerConfig{
Host: "0.0.0.0",
Port: 8080,
Mode: "release",
},
JWT: JWTConfig{
ExpireHour: 24,
},
}
// Database configuration with validation
fmt.Println("── Database Configuration ──")
for {
cfg.Database.Host = promptString(reader, "PostgreSQL Host", "localhost")
if cliValidateHostname(cfg.Database.Host) {
break
}
fmt.Println(" Invalid hostname format. Use alphanumeric, dots, hyphens only.")
}
for {
cfg.Database.Port = promptInt(reader, "PostgreSQL Port", 5432)
if cliValidatePort(cfg.Database.Port) {
break
}
fmt.Println(" Invalid port. Must be between 1 and 65535.")
}
for {
cfg.Database.User = promptString(reader, "PostgreSQL User", "postgres")
if cliValidateUsername(cfg.Database.User) {
break
}
fmt.Println(" Invalid username. Use alphanumeric and underscores only.")
}
cfg.Database.Password = promptPassword("PostgreSQL Password")
for {
cfg.Database.DBName = promptString(reader, "Database Name", "sub2api")
if cliValidateDBName(cfg.Database.DBName) {
break
}
fmt.Println(" Invalid database name. Start with letter, use alphanumeric and underscores.")
}
for {
cfg.Database.SSLMode = promptString(reader, "SSL Mode", "disable")
if cliValidateSSLMode(cfg.Database.SSLMode) {
break
}
fmt.Println(" Invalid SSL mode. Use: disable, require, verify-ca, or verify-full.")
}
fmt.Println()
fmt.Print("Testing database connection... ")
if err := TestDatabaseConnection(&cfg.Database); err != nil {
fmt.Println("FAILED")
return fmt.Errorf("database connection failed: %w", err)
}
fmt.Println("OK")
// Redis configuration with validation
fmt.Println()
fmt.Println("── Redis Configuration ──")
for {
cfg.Redis.Host = promptString(reader, "Redis Host", "localhost")
if cliValidateHostname(cfg.Redis.Host) {
break
}
fmt.Println(" Invalid hostname format. Use alphanumeric, dots, hyphens only.")
}
for {
cfg.Redis.Port = promptInt(reader, "Redis Port", 6379)
if cliValidatePort(cfg.Redis.Port) {
break
}
fmt.Println(" Invalid port. Must be between 1 and 65535.")
}
cfg.Redis.Password = promptPassword("Redis Password (optional)")
for {
cfg.Redis.DB = promptInt(reader, "Redis DB", 0)
if cfg.Redis.DB >= 0 && cfg.Redis.DB <= 15 {
break
}
fmt.Println(" Invalid Redis DB. Must be between 0 and 15.")
}
fmt.Println()
fmt.Print("Testing Redis connection... ")
if err := TestRedisConnection(&cfg.Redis); err != nil {
fmt.Println("FAILED")
return fmt.Errorf("redis connection failed: %w", err)
}
fmt.Println("OK")
// Admin configuration with validation
fmt.Println()
fmt.Println("── Admin Account ──")
for {
cfg.Admin.Email = promptString(reader, "Admin Email", "admin@example.com")
if cliValidateEmail(cfg.Admin.Email) {
break
}
fmt.Println(" Invalid email format.")
}
for {
cfg.Admin.Password = promptPassword("Admin Password")
// SECURITY: Match Web API requirement of 8 characters minimum
if len(cfg.Admin.Password) < 8 {
fmt.Println(" Password must be at least 8 characters")
continue
}
if len(cfg.Admin.Password) > 128 {
fmt.Println(" Password must be at most 128 characters")
continue
}
confirm := promptPassword("Confirm Password")
if cfg.Admin.Password != confirm {
fmt.Println(" Passwords do not match")
continue
}
break
}
// Server configuration with validation
fmt.Println()
fmt.Println("── Server Configuration ──")
for {
cfg.Server.Port = promptInt(reader, "Server Port", 8080)
if cliValidatePort(cfg.Server.Port) {
break
}
fmt.Println(" Invalid port. Must be between 1 and 65535.")
}
// Confirm and install
fmt.Println()
fmt.Println("── Configuration Summary ──")
fmt.Printf("Database: %s@%s:%d/%s\n", cfg.Database.User, cfg.Database.Host, cfg.Database.Port, cfg.Database.DBName)
fmt.Printf("Redis: %s:%d\n", cfg.Redis.Host, cfg.Redis.Port)
fmt.Printf("Admin: %s\n", cfg.Admin.Email)
fmt.Printf("Server: :%d\n", cfg.Server.Port)
fmt.Println()
if !promptConfirm(reader, "Proceed with installation?") {
fmt.Println("Installation cancelled")
return nil
}
fmt.Println()
fmt.Print("Installing... ")
if err := Install(cfg); err != nil {
fmt.Println("FAILED")
return err
}
fmt.Println("OK")
fmt.Println()
fmt.Println("╔═══════════════════════════════════════════╗")
fmt.Println("║ Installation Complete! ║")
fmt.Println("╚═══════════════════════════════════════════╝")
fmt.Println()
fmt.Println("Start the server with:")
fmt.Println(" ./sub2api")
fmt.Println()
fmt.Printf("Admin panel: http://localhost:%d\n", cfg.Server.Port)
fmt.Println()
return nil
}
func promptString(reader *bufio.Reader, prompt, defaultVal string) string {
if defaultVal != "" {
fmt.Printf(" %s [%s]: ", prompt, defaultVal)
} else {
fmt.Printf(" %s: ", prompt)
}
input, _ := reader.ReadString('\n')
input = strings.TrimSpace(input)
if input == "" {
return defaultVal
}
return input
}
func promptInt(reader *bufio.Reader, prompt string, defaultVal int) int {
fmt.Printf(" %s [%d]: ", prompt, defaultVal)
input, _ := reader.ReadString('\n')
input = strings.TrimSpace(input)
if input == "" {
return defaultVal
}
val, err := strconv.Atoi(input)
if err != nil {
return defaultVal
}
return val
}
func promptPassword(prompt string) string {
fmt.Printf(" %s: ", prompt)
// Try to read password without echo
if term.IsTerminal(int(os.Stdin.Fd())) {
password, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Println()
if err == nil {
return string(password)
}
}
// Fallback to regular input
reader := bufio.NewReader(os.Stdin)
input, _ := reader.ReadString('\n')
return strings.TrimSpace(input)
}
func promptConfirm(reader *bufio.Reader, prompt string) bool {
fmt.Printf("%s [y/N]: ", prompt)
input, _ := reader.ReadString('\n')
input = strings.TrimSpace(strings.ToLower(input))
return input == "y" || input == "yes"
}
package setup
import (
"fmt"
"net/http"
"net/mail"
"regexp"
"strings"
"sync"
"sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
)
// installMutex prevents concurrent installation attempts (TOCTOU protection)
var installMutex sync.Mutex
// RegisterRoutes registers setup wizard routes
func RegisterRoutes(r *gin.Engine) {
setup := r.Group("/setup")
{
// Status endpoint is always accessible (read-only)
setup.GET("/status", getStatus)
// All modification endpoints are protected by setupGuard
protected := setup.Group("")
protected.Use(setupGuard())
{
protected.POST("/test-db", testDatabase)
protected.POST("/test-redis", testRedis)
protected.POST("/install", install)
}
}
}
// SetupStatus represents the current setup state
type SetupStatus struct {
NeedsSetup bool `json:"needs_setup"`
Step string `json:"step"`
}
// getStatus returns the current setup status
func getStatus(c *gin.Context) {
response.Success(c, SetupStatus{
NeedsSetup: NeedsSetup(),
Step: "welcome",
})
}
// setupGuard middleware ensures setup endpoints are only accessible during setup mode
func setupGuard() gin.HandlerFunc {
return func(c *gin.Context) {
if !NeedsSetup() {
response.Error(c, http.StatusForbidden, "Setup is not allowed: system is already installed")
c.Abort()
return
}
c.Next()
}
}
// validateHostname checks if a hostname/IP is safe (no injection characters)
func validateHostname(host string) bool {
// Allow only alphanumeric, dots, hyphens, and colons (for IPv6)
validHost := regexp.MustCompile(`^[a-zA-Z0-9.\-:]+$`)
return validHost.MatchString(host) && len(host) <= 253
}
// validateDBName checks if database name is safe
func validateDBName(name string) bool {
// Allow only alphanumeric and underscores, starting with letter
validName := regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_]*$`)
return validName.MatchString(name) && len(name) <= 63
}
// validateUsername checks if username is safe
func validateUsername(name string) bool {
// Allow only alphanumeric and underscores
validName := regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
return validName.MatchString(name) && len(name) <= 63
}
// validateEmail checks if email format is valid
func validateEmail(email string) bool {
_, err := mail.ParseAddress(email)
return err == nil && len(email) <= 254
}
// validatePassword checks password strength
func validatePassword(password string) error {
if len(password) < 8 {
return fmt.Errorf("password must be at least 8 characters")
}
if len(password) > 128 {
return fmt.Errorf("password must be at most 128 characters")
}
return nil
}
// validatePort checks if port is in valid range
func validatePort(port int) bool {
return port > 0 && port <= 65535
}
// validateSSLMode checks if SSL mode is valid
func validateSSLMode(mode string) bool {
validModes := map[string]bool{
"disable": true, "require": true, "verify-ca": true, "verify-full": true,
}
return validModes[mode]
}
// TestDatabaseRequest represents database test request
type TestDatabaseRequest struct {
Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required"`
User string `json:"user" binding:"required"`
Password string `json:"password"`
DBName string `json:"dbname" binding:"required"`
SSLMode string `json:"sslmode"`
}
// testDatabase tests database connection
func testDatabase(c *gin.Context) {
var req TestDatabaseRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, http.StatusBadRequest, "Invalid request: "+err.Error())
return
}
// Security: Validate all inputs to prevent injection attacks
if !validateHostname(req.Host) {
response.Error(c, http.StatusBadRequest, "Invalid hostname format")
return
}
if !validatePort(req.Port) {
response.Error(c, http.StatusBadRequest, "Invalid port number")
return
}
if !validateUsername(req.User) {
response.Error(c, http.StatusBadRequest, "Invalid username format")
return
}
if !validateDBName(req.DBName) {
response.Error(c, http.StatusBadRequest, "Invalid database name format")
return
}
if req.SSLMode == "" {
req.SSLMode = "disable"
}
if !validateSSLMode(req.SSLMode) {
response.Error(c, http.StatusBadRequest, "Invalid SSL mode")
return
}
cfg := &DatabaseConfig{
Host: req.Host,
Port: req.Port,
User: req.User,
Password: req.Password,
DBName: req.DBName,
SSLMode: req.SSLMode,
}
if err := TestDatabaseConnection(cfg); err != nil {
response.Error(c, http.StatusBadRequest, "Connection failed: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Connection successful"})
}
// TestRedisRequest represents Redis test request
type TestRedisRequest struct {
Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required"`
Password string `json:"password"`
DB int `json:"db"`
}
// testRedis tests Redis connection
func testRedis(c *gin.Context) {
var req TestRedisRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, http.StatusBadRequest, "Invalid request: "+err.Error())
return
}
// Security: Validate inputs
if !validateHostname(req.Host) {
response.Error(c, http.StatusBadRequest, "Invalid hostname format")
return
}
if !validatePort(req.Port) {
response.Error(c, http.StatusBadRequest, "Invalid port number")
return
}
if req.DB < 0 || req.DB > 15 {
response.Error(c, http.StatusBadRequest, "Invalid Redis database number (0-15)")
return
}
cfg := &RedisConfig{
Host: req.Host,
Port: req.Port,
Password: req.Password,
DB: req.DB,
}
if err := TestRedisConnection(cfg); err != nil {
response.Error(c, http.StatusBadRequest, "Connection failed: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Connection successful"})
}
// InstallRequest represents installation request
type InstallRequest struct {
Database DatabaseConfig `json:"database" binding:"required"`
Redis RedisConfig `json:"redis" binding:"required"`
Admin AdminConfig `json:"admin" binding:"required"`
Server ServerConfig `json:"server"`
}
// install performs the installation
func install(c *gin.Context) {
// TOCTOU Protection: Acquire mutex to prevent concurrent installation
installMutex.Lock()
defer installMutex.Unlock()
// Double-check after acquiring lock
if !NeedsSetup() {
response.Error(c, http.StatusForbidden, "Setup is not allowed: system is already installed")
return
}
var req InstallRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, http.StatusBadRequest, "Invalid request: "+err.Error())
return
}
// ========== COMPREHENSIVE INPUT VALIDATION ==========
// Database validation
if !validateHostname(req.Database.Host) {
response.Error(c, http.StatusBadRequest, "Invalid database hostname")
return
}
if !validatePort(req.Database.Port) {
response.Error(c, http.StatusBadRequest, "Invalid database port")
return
}
if !validateUsername(req.Database.User) {
response.Error(c, http.StatusBadRequest, "Invalid database username")
return
}
if !validateDBName(req.Database.DBName) {
response.Error(c, http.StatusBadRequest, "Invalid database name")
return
}
// Redis validation
if !validateHostname(req.Redis.Host) {
response.Error(c, http.StatusBadRequest, "Invalid Redis hostname")
return
}
if !validatePort(req.Redis.Port) {
response.Error(c, http.StatusBadRequest, "Invalid Redis port")
return
}
if req.Redis.DB < 0 || req.Redis.DB > 15 {
response.Error(c, http.StatusBadRequest, "Invalid Redis database number")
return
}
// Admin validation
if !validateEmail(req.Admin.Email) {
response.Error(c, http.StatusBadRequest, "Invalid admin email format")
return
}
if err := validatePassword(req.Admin.Password); err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
// Server validation
if req.Server.Port != 0 && !validatePort(req.Server.Port) {
response.Error(c, http.StatusBadRequest, "Invalid server port")
return
}
// ========== SET DEFAULTS ==========
if req.Database.SSLMode == "" {
req.Database.SSLMode = "disable"
}
if !validateSSLMode(req.Database.SSLMode) {
response.Error(c, http.StatusBadRequest, "Invalid SSL mode")
return
}
if req.Server.Host == "" {
req.Server.Host = "0.0.0.0"
}
if req.Server.Port == 0 {
req.Server.Port = 8080
}
if req.Server.Mode == "" {
req.Server.Mode = "release"
}
// Validate server mode
if req.Server.Mode != "release" && req.Server.Mode != "debug" {
response.Error(c, http.StatusBadRequest, "Invalid server mode (must be 'release' or 'debug')")
return
}
// Trim whitespace from string inputs
req.Admin.Email = strings.TrimSpace(req.Admin.Email)
req.Database.Host = strings.TrimSpace(req.Database.Host)
req.Database.User = strings.TrimSpace(req.Database.User)
req.Database.DBName = strings.TrimSpace(req.Database.DBName)
req.Redis.Host = strings.TrimSpace(req.Redis.Host)
cfg := &SetupConfig{
Database: req.Database,
Redis: req.Redis,
Admin: req.Admin,
Server: req.Server,
JWT: JWTConfig{
ExpireHour: 24,
},
}
if err := Install(cfg); err != nil {
response.Error(c, http.StatusInternalServerError, "Installation failed: "+err.Error())
return
}
response.Success(c, gin.H{
"message": "Installation completed successfully",
"restart": true,
})
}
package setup
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"log"
"os"
"strconv"
"time"
"github.com/redis/go-redis/v9"
"golang.org/x/crypto/bcrypt"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gopkg.in/yaml.v3"
)
// Config paths
const (
ConfigFile = "config.yaml"
EnvFile = ".env"
)
// SetupConfig holds the setup configuration
type SetupConfig struct {
Database DatabaseConfig `json:"database" yaml:"database"`
Redis RedisConfig `json:"redis" yaml:"redis"`
Admin AdminConfig `json:"admin" yaml:"-"` // Not stored in config file
Server ServerConfig `json:"server" yaml:"server"`
JWT JWTConfig `json:"jwt" yaml:"jwt"`
Timezone string `json:"timezone" yaml:"timezone"` // e.g. "Asia/Shanghai", "UTC"
}
type DatabaseConfig struct {
Host string `json:"host" yaml:"host"`
Port int `json:"port" yaml:"port"`
User string `json:"user" yaml:"user"`
Password string `json:"password" yaml:"password"`
DBName string `json:"dbname" yaml:"dbname"`
SSLMode string `json:"sslmode" yaml:"sslmode"`
}
type RedisConfig struct {
Host string `json:"host" yaml:"host"`
Port int `json:"port" yaml:"port"`
Password string `json:"password" yaml:"password"`
DB int `json:"db" yaml:"db"`
}
type AdminConfig struct {
Email string `json:"email"`
Password string `json:"password"`
}
type ServerConfig struct {
Host string `json:"host" yaml:"host"`
Port int `json:"port" yaml:"port"`
Mode string `json:"mode" yaml:"mode"`
}
type JWTConfig struct {
Secret string `json:"secret" yaml:"secret"`
ExpireHour int `json:"expire_hour" yaml:"expire_hour"`
}
// NeedsSetup checks if the system needs initial setup
// Uses multiple checks to prevent attackers from forcing re-setup by deleting config
func NeedsSetup() bool {
// Check 1: Config file must not exist
if _, err := os.Stat(ConfigFile); !os.IsNotExist(err) {
return false // Config exists, no setup needed
}
// Check 2: Installation lock file (harder to bypass)
lockFile := ".installed"
if _, err := os.Stat(lockFile); !os.IsNotExist(err) {
return false // Lock file exists, already installed
}
return true
}
// TestDatabaseConnection tests the database connection
func TestDatabaseConnection(cfg *DatabaseConfig) error {
dsn := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode,
)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return fmt.Errorf("failed to connect: %w", err)
}
sqlDB, err := db.DB()
if err != nil {
return fmt.Errorf("failed to get db instance: %w", err)
}
defer sqlDB.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := sqlDB.PingContext(ctx); err != nil {
return fmt.Errorf("ping failed: %w", err)
}
return nil
}
// TestRedisConnection tests the Redis connection
func TestRedisConnection(cfg *RedisConfig) error {
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Password: cfg.Password,
DB: cfg.DB,
})
defer rdb.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := rdb.Ping(ctx).Err(); err != nil {
return fmt.Errorf("ping failed: %w", err)
}
return nil
}
// Install performs the installation with the given configuration
func Install(cfg *SetupConfig) error {
// Security check: prevent re-installation if already installed
if !NeedsSetup() {
return fmt.Errorf("system is already installed, re-installation is not allowed")
}
// Generate JWT secret if not provided
if cfg.JWT.Secret == "" {
cfg.JWT.Secret = generateSecret(32)
}
// Test connections
if err := TestDatabaseConnection(&cfg.Database); err != nil {
return fmt.Errorf("database connection failed: %w", err)
}
if err := TestRedisConnection(&cfg.Redis); err != nil {
return fmt.Errorf("redis connection failed: %w", err)
}
// Initialize database
if err := initializeDatabase(cfg); err != nil {
return fmt.Errorf("database initialization failed: %w", err)
}
// Create admin user
if err := createAdminUser(cfg); err != nil {
return fmt.Errorf("admin user creation failed: %w", err)
}
// Write config file
if err := writeConfigFile(cfg); err != nil {
return fmt.Errorf("config file creation failed: %w", err)
}
// Create installation lock file to prevent re-setup attacks
if err := createInstallLock(); err != nil {
return fmt.Errorf("failed to create install lock: %w", err)
}
return nil
}
// createInstallLock creates a lock file to prevent re-installation attacks
func createInstallLock() error {
lockFile := ".installed"
content := fmt.Sprintf("installed_at=%s\n", time.Now().UTC().Format(time.RFC3339))
return os.WriteFile(lockFile, []byte(content), 0400) // Read-only for owner
}
func initializeDatabase(cfg *SetupConfig) error {
dsn := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
cfg.Database.Host, cfg.Database.Port, cfg.Database.User,
cfg.Database.Password, cfg.Database.DBName, cfg.Database.SSLMode,
)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return err
}
sqlDB, err := db.DB()
if err != nil {
return err
}
defer sqlDB.Close()
// Run auto-migration for all models
return db.AutoMigrate(
&User{},
&Group{},
&APIKey{},
&Account{},
&Proxy{},
&RedeemCode{},
&UsageLog{},
&UserSubscription{},
&Setting{},
)
}
func createAdminUser(cfg *SetupConfig) error {
dsn := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
cfg.Database.Host, cfg.Database.Port, cfg.Database.User,
cfg.Database.Password, cfg.Database.DBName, cfg.Database.SSLMode,
)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return err
}
sqlDB, err := db.DB()
if err != nil {
return err
}
defer sqlDB.Close()
// Check if admin already exists
var count int64
db.Model(&User{}).Where("role = ?", "admin").Count(&count)
if count > 0 {
return nil // Admin already exists
}
// Hash password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(cfg.Admin.Password), bcrypt.DefaultCost)
if err != nil {
return err
}
// Create admin user
admin := &User{
Email: cfg.Admin.Email,
PasswordHash: string(hashedPassword),
Role: "admin",
Status: "active",
Balance: 0,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
return db.Create(admin).Error
}
func writeConfigFile(cfg *SetupConfig) error {
// Ensure timezone has a default value
tz := cfg.Timezone
if tz == "" {
tz = "Asia/Shanghai"
}
// Prepare config for YAML (exclude sensitive data and admin config)
yamlConfig := struct {
Server ServerConfig `yaml:"server"`
Database DatabaseConfig `yaml:"database"`
Redis RedisConfig `yaml:"redis"`
JWT struct {
Secret string `yaml:"secret"`
ExpireHour int `yaml:"expire_hour"`
} `yaml:"jwt"`
Default struct {
GroupID uint `yaml:"group_id"`
} `yaml:"default"`
RateLimit struct {
RequestsPerMinute int `yaml:"requests_per_minute"`
BurstSize int `yaml:"burst_size"`
} `yaml:"rate_limit"`
Timezone string `yaml:"timezone"`
}{
Server: cfg.Server,
Database: cfg.Database,
Redis: cfg.Redis,
JWT: struct {
Secret string `yaml:"secret"`
ExpireHour int `yaml:"expire_hour"`
}{
Secret: cfg.JWT.Secret,
ExpireHour: cfg.JWT.ExpireHour,
},
Default: struct {
GroupID uint `yaml:"group_id"`
}{
GroupID: 1,
},
RateLimit: struct {
RequestsPerMinute int `yaml:"requests_per_minute"`
BurstSize int `yaml:"burst_size"`
}{
RequestsPerMinute: 60,
BurstSize: 10,
},
Timezone: tz,
}
data, err := yaml.Marshal(&yamlConfig)
if err != nil {
return err
}
return os.WriteFile(ConfigFile, data, 0600)
}
func generateSecret(length int) string {
bytes := make([]byte, length)
rand.Read(bytes)
return hex.EncodeToString(bytes)
}
// Minimal model definitions for migration (to avoid circular import)
type User struct {
ID uint `gorm:"primaryKey"`
Email string `gorm:"uniqueIndex;not null"`
PasswordHash string `gorm:"not null"`
Role string `gorm:"default:user"`
Status string `gorm:"default:active"`
Balance float64 `gorm:"default:0"`
CreatedAt time.Time
UpdatedAt time.Time
}
type Group struct {
ID uint `gorm:"primaryKey"`
Name string `gorm:"uniqueIndex;not null"`
Description string `gorm:"type:text"`
RateMultiplier float64 `gorm:"default:1.0"`
IsExclusive bool `gorm:"default:false"`
Priority int `gorm:"default:0"`
Status string `gorm:"default:active"`
CreatedAt time.Time
UpdatedAt time.Time
}
type APIKey struct {
ID uint `gorm:"primaryKey"`
UserID uint `gorm:"index;not null"`
Key string `gorm:"uniqueIndex;not null"`
Name string
GroupID *uint
Status string `gorm:"default:active"`
CreatedAt time.Time
UpdatedAt time.Time
}
type Account struct {
ID uint `gorm:"primaryKey"`
Platform string `gorm:"not null"`
Type string `gorm:"not null"`
Credentials string `gorm:"type:text"`
Status string `gorm:"default:active"`
Priority int `gorm:"default:0"`
ProxyID *uint
CreatedAt time.Time
UpdatedAt time.Time
}
type Proxy struct {
ID uint `gorm:"primaryKey"`
Name string `gorm:"not null"`
Protocol string `gorm:"not null"`
Host string `gorm:"not null"`
Port int `gorm:"not null"`
Username string
Password string
Status string `gorm:"default:active"`
CreatedAt time.Time
UpdatedAt time.Time
}
type RedeemCode struct {
ID uint `gorm:"primaryKey"`
Code string `gorm:"uniqueIndex;not null"`
Value float64 `gorm:"not null"`
Status string `gorm:"default:unused"`
UsedBy *uint
UsedAt *time.Time
ExpiresAt *time.Time
CreatedAt time.Time
}
type UsageLog struct {
ID uint `gorm:"primaryKey"`
UserID uint `gorm:"index"`
APIKeyID uint `gorm:"index"`
AccountID *uint `gorm:"index"`
Model string `gorm:"index"`
InputTokens int
OutputTokens int
Cost float64
CreatedAt time.Time
}
type UserSubscription struct {
ID uint `gorm:"primaryKey"`
UserID uint `gorm:"index;not null"`
GroupID uint `gorm:"index;not null"`
Quota int64
Used int64 `gorm:"default:0"`
Status string
ExpiresAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
}
type Setting struct {
ID uint `gorm:"primaryKey"`
Key string `gorm:"uniqueIndex;not null"`
Value string `gorm:"type:text"`
CreatedAt time.Time
UpdatedAt time.Time
}
func (User) TableName() string { return "users" }
func (Group) TableName() string { return "groups" }
func (APIKey) TableName() string { return "api_keys" }
func (Account) TableName() string { return "accounts" }
func (Proxy) TableName() string { return "proxies" }
func (RedeemCode) TableName() string { return "redeem_codes" }
func (UsageLog) TableName() string { return "usage_logs" }
func (UserSubscription) TableName() string { return "user_subscriptions" }
func (Setting) TableName() string { return "settings" }
// =============================================================================
// Auto Setup for Docker Deployment
// =============================================================================
// AutoSetupEnabled checks if auto setup is enabled via environment variable
func AutoSetupEnabled() bool {
val := os.Getenv("AUTO_SETUP")
return val == "true" || val == "1" || val == "yes"
}
// getEnvOrDefault gets environment variable or returns default value
func getEnvOrDefault(key, defaultValue string) string {
if val := os.Getenv(key); val != "" {
return val
}
return defaultValue
}
// getEnvIntOrDefault gets environment variable as int or returns default value
func getEnvIntOrDefault(key string, defaultValue int) int {
if val := os.Getenv(key); val != "" {
if i, err := strconv.Atoi(val); err == nil {
return i
}
}
return defaultValue
}
// AutoSetupFromEnv performs automatic setup using environment variables
// This is designed for Docker deployment where all config is passed via env vars
func AutoSetupFromEnv() error {
log.Println("Auto setup enabled, configuring from environment variables...")
// Get timezone from TZ or TIMEZONE env var (TZ is standard for Docker)
tz := getEnvOrDefault("TZ", "")
if tz == "" {
tz = getEnvOrDefault("TIMEZONE", "Asia/Shanghai")
}
// Build config from environment variables
cfg := &SetupConfig{
Database: DatabaseConfig{
Host: getEnvOrDefault("DATABASE_HOST", "localhost"),
Port: getEnvIntOrDefault("DATABASE_PORT", 5432),
User: getEnvOrDefault("DATABASE_USER", "postgres"),
Password: getEnvOrDefault("DATABASE_PASSWORD", ""),
DBName: getEnvOrDefault("DATABASE_DBNAME", "sub2api"),
SSLMode: getEnvOrDefault("DATABASE_SSLMODE", "disable"),
},
Redis: RedisConfig{
Host: getEnvOrDefault("REDIS_HOST", "localhost"),
Port: getEnvIntOrDefault("REDIS_PORT", 6379),
Password: getEnvOrDefault("REDIS_PASSWORD", ""),
DB: getEnvIntOrDefault("REDIS_DB", 0),
},
Admin: AdminConfig{
Email: getEnvOrDefault("ADMIN_EMAIL", "admin@sub2api.local"),
Password: getEnvOrDefault("ADMIN_PASSWORD", ""),
},
Server: ServerConfig{
Host: getEnvOrDefault("SERVER_HOST", "0.0.0.0"),
Port: getEnvIntOrDefault("SERVER_PORT", 8080),
Mode: getEnvOrDefault("SERVER_MODE", "release"),
},
JWT: JWTConfig{
Secret: getEnvOrDefault("JWT_SECRET", ""),
ExpireHour: getEnvIntOrDefault("JWT_EXPIRE_HOUR", 24),
},
Timezone: tz,
}
// Generate JWT secret if not provided
if cfg.JWT.Secret == "" {
cfg.JWT.Secret = generateSecret(32)
log.Println("Generated JWT secret automatically")
}
// Generate admin password if not provided
if cfg.Admin.Password == "" {
cfg.Admin.Password = generateSecret(16)
log.Printf("Generated admin password: %s", cfg.Admin.Password)
log.Println("IMPORTANT: Save this password! It will not be shown again.")
}
// Test database connection
log.Println("Testing database connection...")
if err := TestDatabaseConnection(&cfg.Database); err != nil {
return fmt.Errorf("database connection failed: %w", err)
}
log.Println("Database connection successful")
// Test Redis connection
log.Println("Testing Redis connection...")
if err := TestRedisConnection(&cfg.Redis); err != nil {
return fmt.Errorf("redis connection failed: %w", err)
}
log.Println("Redis connection successful")
// Initialize database
log.Println("Initializing database...")
if err := initializeDatabase(cfg); err != nil {
return fmt.Errorf("database initialization failed: %w", err)
}
log.Println("Database initialized successfully")
// Create admin user
log.Println("Creating admin user...")
if err := createAdminUser(cfg); err != nil {
return fmt.Errorf("admin user creation failed: %w", err)
}
log.Printf("Admin user created: %s", cfg.Admin.Email)
// Write config file
log.Println("Writing configuration file...")
if err := writeConfigFile(cfg); err != nil {
return fmt.Errorf("config file creation failed: %w", err)
}
log.Println("Configuration file created")
// Create installation lock file
if err := createInstallLock(); err != nil {
return fmt.Errorf("failed to create install lock: %w", err)
}
log.Println("Installation lock created")
log.Println("Auto setup completed successfully!")
return nil
}
package web
import (
"embed"
"io"
"io/fs"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
//go:embed dist/*
var frontendFS embed.FS
// ServeEmbeddedFrontend returns a Gin handler that serves embedded frontend assets
// and handles SPA routing by falling back to index.html for non-API routes.
func ServeEmbeddedFrontend() gin.HandlerFunc {
distFS, err := fs.Sub(frontendFS, "dist")
if err != nil {
panic("failed to get dist subdirectory: " + err.Error())
}
fileServer := http.FileServer(http.FS(distFS))
return func(c *gin.Context) {
path := c.Request.URL.Path
// Skip API and gateway routes
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/setup/") ||
path == "/health" {
c.Next()
return
}
// Try to serve static file
cleanPath := strings.TrimPrefix(path, "/")
if cleanPath == "" {
cleanPath = "index.html"
}
if file, err := distFS.Open(cleanPath); err == nil {
file.Close()
fileServer.ServeHTTP(c.Writer, c.Request)
c.Abort()
return
}
// SPA fallback: serve index.html for all other routes
serveIndexHTML(c, distFS)
}
}
func serveIndexHTML(c *gin.Context, fsys fs.FS) {
file, err := fsys.Open("index.html")
if err != nil {
c.String(http.StatusNotFound, "Frontend not found")
c.Abort()
return
}
defer file.Close()
content, err := io.ReadAll(file)
if err != nil {
c.String(http.StatusInternalServerError, "Failed to read index.html")
c.Abort()
return
}
c.Data(http.StatusOK, "text/html; charset=utf-8", content)
c.Abort()
}
// HasEmbeddedFrontend checks if frontend assets are embedded
func HasEmbeddedFrontend() bool {
_, err := frontendFS.ReadFile("dist/index.html")
return err == nil
}
-- Sub2API 初始化数据库迁移脚本
-- PostgreSQL 15+
-- 1. proxies 代理IP表(无外键依赖)
CREATE TABLE IF NOT EXISTS proxies (
id BIGSERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
protocol VARCHAR(20) NOT NULL, -- http/https/socks5
host VARCHAR(255) NOT NULL,
port INT NOT NULL,
username VARCHAR(100),
password VARCHAR(100),
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
CREATE INDEX IF NOT EXISTS idx_proxies_status ON proxies(status);
CREATE INDEX IF NOT EXISTS idx_proxies_deleted_at ON proxies(deleted_at);
-- 2. groups 分组表(无外键依赖)
CREATE TABLE IF NOT EXISTS groups (
id BIGSERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL UNIQUE,
description TEXT,
rate_multiplier DECIMAL(10, 4) NOT NULL DEFAULT 1.0, -- 费率倍率
is_exclusive BOOLEAN NOT NULL DEFAULT FALSE, -- 是否专属分组
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
CREATE INDEX IF NOT EXISTS idx_groups_name ON groups(name);
CREATE INDEX IF NOT EXISTS idx_groups_status ON groups(status);
CREATE INDEX IF NOT EXISTS idx_groups_is_exclusive ON groups(is_exclusive);
CREATE INDEX IF NOT EXISTS idx_groups_deleted_at ON groups(deleted_at);
-- 3. users 用户表(无外键依赖)
CREATE TABLE IF NOT EXISTS users (
id BIGSERIAL PRIMARY KEY,
email VARCHAR(255) NOT NULL UNIQUE,
password_hash VARCHAR(255) NOT NULL,
role VARCHAR(20) NOT NULL DEFAULT 'user', -- admin/user
balance DECIMAL(20, 8) NOT NULL DEFAULT 0, -- 余额(可为负数)
concurrency INT NOT NULL DEFAULT 5, -- 并发数限制
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled
allowed_groups BIGINT[] DEFAULT NULL, -- 允许绑定的分组ID列表
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
CREATE INDEX IF NOT EXISTS idx_users_status ON users(status);
CREATE INDEX IF NOT EXISTS idx_users_deleted_at ON users(deleted_at);
-- 4. accounts 上游账号表(依赖proxies)
CREATE TABLE IF NOT EXISTS accounts (
id BIGSERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
platform VARCHAR(50) NOT NULL, -- anthropic/openai/gemini
type VARCHAR(20) NOT NULL, -- oauth/apikey
credentials JSONB NOT NULL DEFAULT '{}', -- 凭证信息(加密存储)
extra JSONB NOT NULL DEFAULT '{}', -- 扩展信息
proxy_id BIGINT REFERENCES proxies(id) ON DELETE SET NULL,
concurrency INT NOT NULL DEFAULT 3, -- 账号并发限制
priority INT NOT NULL DEFAULT 50, -- 调度优先级(1-100,越小越高)
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled/error
error_message TEXT,
last_used_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
CREATE INDEX IF NOT EXISTS idx_accounts_platform ON accounts(platform);
CREATE INDEX IF NOT EXISTS idx_accounts_type ON accounts(type);
CREATE INDEX IF NOT EXISTS idx_accounts_status ON accounts(status);
CREATE INDEX IF NOT EXISTS idx_accounts_proxy_id ON accounts(proxy_id);
CREATE INDEX IF NOT EXISTS idx_accounts_priority ON accounts(priority);
CREATE INDEX IF NOT EXISTS idx_accounts_last_used_at ON accounts(last_used_at);
CREATE INDEX IF NOT EXISTS idx_accounts_deleted_at ON accounts(deleted_at);
-- 5. api_keys API密钥表(依赖users, groups)
CREATE TABLE IF NOT EXISTS api_keys (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
key VARCHAR(64) NOT NULL UNIQUE, -- sk-xxx格式
name VARCHAR(100) NOT NULL,
group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL,
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/disabled
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
CREATE INDEX IF NOT EXISTS idx_api_keys_key ON api_keys(key);
CREATE INDEX IF NOT EXISTS idx_api_keys_user_id ON api_keys(user_id);
CREATE INDEX IF NOT EXISTS idx_api_keys_group_id ON api_keys(group_id);
CREATE INDEX IF NOT EXISTS idx_api_keys_status ON api_keys(status);
CREATE INDEX IF NOT EXISTS idx_api_keys_deleted_at ON api_keys(deleted_at);
-- 6. account_groups 账号-分组关联表(依赖accounts, groups)
CREATE TABLE IF NOT EXISTS account_groups (
account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
priority INT NOT NULL DEFAULT 50, -- 分组内优先级
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (account_id, group_id)
);
CREATE INDEX IF NOT EXISTS idx_account_groups_group_id ON account_groups(group_id);
CREATE INDEX IF NOT EXISTS idx_account_groups_priority ON account_groups(priority);
-- 7. redeem_codes 卡密表(依赖users)
CREATE TABLE IF NOT EXISTS redeem_codes (
id BIGSERIAL PRIMARY KEY,
code VARCHAR(32) NOT NULL UNIQUE, -- 兑换码
type VARCHAR(20) NOT NULL DEFAULT 'balance', -- balance
value DECIMAL(20, 8) NOT NULL, -- 面值(USD)
status VARCHAR(20) NOT NULL DEFAULT 'unused', -- unused/used
used_by BIGINT REFERENCES users(id) ON DELETE SET NULL,
used_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_redeem_codes_code ON redeem_codes(code);
CREATE INDEX IF NOT EXISTS idx_redeem_codes_status ON redeem_codes(status);
CREATE INDEX IF NOT EXISTS idx_redeem_codes_used_by ON redeem_codes(used_by);
-- 8. usage_logs 使用记录表(依赖users, api_keys, accounts)
CREATE TABLE IF NOT EXISTS usage_logs (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
api_key_id BIGINT NOT NULL REFERENCES api_keys(id) ON DELETE CASCADE,
account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
request_id VARCHAR(64),
model VARCHAR(100) NOT NULL,
-- Token使用量(4类)
input_tokens INT NOT NULL DEFAULT 0,
output_tokens INT NOT NULL DEFAULT 0,
cache_creation_tokens INT NOT NULL DEFAULT 0,
cache_read_tokens INT NOT NULL DEFAULT 0,
-- 详细的缓存创建分类
cache_creation_5m_tokens INT NOT NULL DEFAULT 0,
cache_creation_1h_tokens INT NOT NULL DEFAULT 0,
-- 费用(USD)
input_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
output_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
cache_creation_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
cache_read_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
total_cost DECIMAL(20, 10) NOT NULL DEFAULT 0, -- 原始总费用
actual_cost DECIMAL(20, 10) NOT NULL DEFAULT 0, -- 实际扣除费用
-- 元数据
stream BOOLEAN NOT NULL DEFAULT FALSE,
duration_ms INT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_usage_logs_user_id ON usage_logs(user_id);
CREATE INDEX IF NOT EXISTS idx_usage_logs_api_key_id ON usage_logs(api_key_id);
CREATE INDEX IF NOT EXISTS idx_usage_logs_account_id ON usage_logs(account_id);
CREATE INDEX IF NOT EXISTS idx_usage_logs_model ON usage_logs(model);
CREATE INDEX IF NOT EXISTS idx_usage_logs_created_at ON usage_logs(created_at);
CREATE INDEX IF NOT EXISTS idx_usage_logs_user_created ON usage_logs(user_id, created_at);
-- 插入默认管理员用户
-- 密码: admin123 (bcrypt hash)
INSERT INTO users (email, password_hash, role, balance, concurrency, status)
VALUES ('admin@sub2api.com', '$2a$10$N9qo8uLOickgx2ZMRZoMye.IjJbDdJeCo0U2bBPJj9lS/5LqD.C.C', 'admin', 0, 10, 'active')
ON CONFLICT (email) DO NOTHING;
-- 插入默认分组
INSERT INTO groups (name, description, rate_multiplier, is_exclusive, status)
VALUES ('default', '默认分组', 1.0, false, 'active')
ON CONFLICT (name) DO NOTHING;
-- Sub2API 账号类型迁移脚本
-- 将 'official' 类型账号迁移为 'oauth' 或 'setup-token'
-- 根据 credentials->>'scope' 字段判断:
-- - 包含 'user:profile' 的是 'oauth' 类型
-- - 只有 'user:inference' 的是 'setup-token' 类型
-- 1. 将包含 profile scope 的 official 账号迁移为 oauth
UPDATE accounts
SET type = 'oauth',
updated_at = NOW()
WHERE type = 'official'
AND credentials->>'scope' LIKE '%user:profile%';
-- 2. 将只有 inference scope 的 official 账号迁移为 setup-token
UPDATE accounts
SET type = 'setup-token',
updated_at = NOW()
WHERE type = 'official'
AND (
credentials->>'scope' = 'user:inference'
OR credentials->>'scope' NOT LIKE '%user:profile%'
);
-- 3. 处理没有 scope 字段的旧账号(默认为 oauth)
UPDATE accounts
SET type = 'oauth',
updated_at = NOW()
WHERE type = 'official'
AND (credentials->>'scope' IS NULL OR credentials->>'scope' = '');
-- 4. 验证迁移结果(查询是否还有 official 类型账号)
-- SELECT COUNT(*) FROM accounts WHERE type = 'official';
-- 如果结果为 0,说明迁移成功
-- Sub2API 订阅功能迁移脚本
-- 添加订阅分组和用户订阅功能
-- 1. 扩展 groups 表添加订阅相关字段
ALTER TABLE groups ADD COLUMN IF NOT EXISTS platform VARCHAR(50) NOT NULL DEFAULT 'anthropic';
ALTER TABLE groups ADD COLUMN IF NOT EXISTS subscription_type VARCHAR(20) NOT NULL DEFAULT 'standard';
ALTER TABLE groups ADD COLUMN IF NOT EXISTS daily_limit_usd DECIMAL(20, 8) DEFAULT NULL;
ALTER TABLE groups ADD COLUMN IF NOT EXISTS weekly_limit_usd DECIMAL(20, 8) DEFAULT NULL;
ALTER TABLE groups ADD COLUMN IF NOT EXISTS monthly_limit_usd DECIMAL(20, 8) DEFAULT NULL;
ALTER TABLE groups ADD COLUMN IF NOT EXISTS default_validity_days INT NOT NULL DEFAULT 30;
-- 添加索引
CREATE INDEX IF NOT EXISTS idx_groups_platform ON groups(platform);
CREATE INDEX IF NOT EXISTS idx_groups_subscription_type ON groups(subscription_type);
-- 2. 创建 user_subscriptions 用户订阅表
CREATE TABLE IF NOT EXISTS user_subscriptions (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
-- 订阅有效期
starts_at TIMESTAMPTZ NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
status VARCHAR(20) NOT NULL DEFAULT 'active', -- active/expired/suspended
-- 滑动窗口起始时间(NULL=未激活)
daily_window_start TIMESTAMPTZ,
weekly_window_start TIMESTAMPTZ,
monthly_window_start TIMESTAMPTZ,
-- 当前窗口已用额度(USD,基于 total_cost 计算)
daily_usage_usd DECIMAL(20, 10) NOT NULL DEFAULT 0,
weekly_usage_usd DECIMAL(20, 10) NOT NULL DEFAULT 0,
monthly_usage_usd DECIMAL(20, 10) NOT NULL DEFAULT 0,
-- 管理员分配信息
assigned_by BIGINT REFERENCES users(id) ON DELETE SET NULL,
assigned_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
notes TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
-- 唯一约束:每个用户对每个分组只能有一个订阅
UNIQUE(user_id, group_id)
);
-- user_subscriptions 索引
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_user_id ON user_subscriptions(user_id);
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_group_id ON user_subscriptions(group_id);
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_status ON user_subscriptions(status);
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_expires_at ON user_subscriptions(expires_at);
CREATE INDEX IF NOT EXISTS idx_user_subscriptions_assigned_by ON user_subscriptions(assigned_by);
-- 3. 扩展 usage_logs 表添加分组和订阅关联
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS subscription_id BIGINT REFERENCES user_subscriptions(id) ON DELETE SET NULL;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS rate_multiplier DECIMAL(10, 4) NOT NULL DEFAULT 1;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS first_token_ms INT;
-- usage_logs 新索引
CREATE INDEX IF NOT EXISTS idx_usage_logs_group_id ON usage_logs(group_id);
CREATE INDEX IF NOT EXISTS idx_usage_logs_subscription_id ON usage_logs(subscription_id);
CREATE INDEX IF NOT EXISTS idx_usage_logs_sub_created ON usage_logs(subscription_id, created_at);
# Model Pricing Data
This directory contains a local copy of the mirrored model pricing data as a fallback mechanism.
## Source
The original file is maintained by the LiteLLM project and mirrored into the `price-mirror` branch of this repository via GitHub Actions:
- Mirror branch (configurable via `PRICE_MIRROR_REPO`): https://raw.githubusercontent.com/<your-repo>/price-mirror/model_prices_and_context_window.json
- Upstream source: https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json
## Purpose
This local copy serves as a fallback when the remote file cannot be downloaded due to:
- Network restrictions
- Firewall rules
- DNS resolution issues
- GitHub being blocked in certain regions
- Docker container network limitations
## Update Process
The pricingService will:
1. First attempt to download the latest version from GitHub
2. If download fails, use this local copy as fallback
3. Log a warning when using the fallback file
## Manual Update
To manually update this file with the latest pricing data (if automation is unavailable):
```bash
curl -s https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json -o model_prices_and_context_window.json
```
## File Format
The file contains JSON data with model pricing information including:
- Model names and identifiers
- Input/output token costs
- Context window sizes
- Model capabilities
Last updated: 2025-08-10
This source diff could not be displayed because it is too large. You can view the blob instead.
# =============================================================================
# Sub2API Docker Environment Configuration
# =============================================================================
# Copy this file to .env and modify as needed:
# cp .env.example .env
# nano .env
#
# Then start with: docker-compose up -d
# =============================================================================
# -----------------------------------------------------------------------------
# Server Configuration
# -----------------------------------------------------------------------------
# Bind address for host port mapping
BIND_HOST=0.0.0.0
# Server port (exposed on host)
SERVER_PORT=8080
# Server mode: release or debug
SERVER_MODE=release
# Timezone
TZ=Asia/Shanghai
# -----------------------------------------------------------------------------
# PostgreSQL Configuration (REQUIRED)
# -----------------------------------------------------------------------------
POSTGRES_USER=sub2api
POSTGRES_PASSWORD=change_this_secure_password
POSTGRES_DB=sub2api
# -----------------------------------------------------------------------------
# Redis Configuration
# -----------------------------------------------------------------------------
# Leave empty for no password (default for local development)
REDIS_PASSWORD=
REDIS_DB=0
# -----------------------------------------------------------------------------
# Admin Account
# -----------------------------------------------------------------------------
# Email for the admin account
ADMIN_EMAIL=admin@sub2api.local
# Password for admin account
# Leave empty to auto-generate (will be shown in logs on first run)
ADMIN_PASSWORD=
# -----------------------------------------------------------------------------
# JWT Configuration
# -----------------------------------------------------------------------------
# Leave empty to auto-generate (recommended)
JWT_SECRET=
JWT_EXPIRE_HOUR=24
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment