Unverified Commit c6b3de11 authored by NepetaLemon's avatar NepetaLemon Committed by GitHub
Browse files

ci(backend): 添加 github actions (#10)

## 变更内容

### CI/CD
- 添加 GitHub Actions 工作流(test + golangci-lint)
- 添加 golangci-lint 配置,启用 errcheck/govet/staticcheck/unused/depguard
- 通过 depguard 强制 service 层不能直接导入 repository

### 错误处理修复
- 修复 CSV 写入、SSE 流式输出、随机数生成等未处理的错误
- GenerateRedeemCode() 现在返回 error

### 资源泄露修复
- 统一使用 defer func() { _ = xxx.Close() }() 模式

### 代码清理
- 移除未使用的常量
- 简化 nil map 检查
- 统一代码格式
parent f1325e9a
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log"
"time" "time"
"sub2api/internal/model" "sub2api/internal/model"
...@@ -309,7 +310,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -309,7 +310,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
go func() { go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
s.billingCacheService.InvalidateUserBalance(cacheCtx, id) if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, id); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", id, err)
}
}() }()
} }
} }
...@@ -317,8 +320,13 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -317,8 +320,13 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
// Create adjustment records for balance/concurrency changes // Create adjustment records for balance/concurrency changes
balanceDiff := user.Balance - oldBalance balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 { if balanceDiff != 0 {
code, err := model.GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &model.RedeemCode{ adjustmentRecord := &model.RedeemCode{
Code: model.GenerateRedeemCode(), Code: code,
Type: model.AdjustmentTypeAdminBalance, Type: model.AdjustmentTypeAdminBalance,
Value: balanceDiff, Value: balanceDiff,
Status: model.StatusUsed, Status: model.StatusUsed,
...@@ -327,15 +335,19 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -327,15 +335,19 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
now := time.Now() now := time.Now()
adjustmentRecord.UsedAt = &now adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
// Log error but don't fail the update log.Printf("failed to create balance adjustment redeem code: %v", err)
// The user update has already succeeded
} }
} }
concurrencyDiff := user.Concurrency - oldConcurrency concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 { if concurrencyDiff != 0 {
code, err := model.GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &model.RedeemCode{ adjustmentRecord := &model.RedeemCode{
Code: model.GenerateRedeemCode(), Code: code,
Type: model.AdjustmentTypeAdminConcurrency, Type: model.AdjustmentTypeAdminConcurrency,
Value: float64(concurrencyDiff), Value: float64(concurrencyDiff),
Status: model.StatusUsed, Status: model.StatusUsed,
...@@ -344,8 +356,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda ...@@ -344,8 +356,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
now := time.Now() now := time.Now()
adjustmentRecord.UsedAt = &now adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
// Log error but don't fail the update log.Printf("failed to create concurrency adjustment redeem code: %v", err)
// The user update has already succeeded
} }
} }
...@@ -388,7 +399,9 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, ...@@ -388,7 +399,9 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
go func() { go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
}
}() }()
} }
...@@ -579,7 +592,9 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { ...@@ -579,7 +592,9 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
for _, userID := range affectedUserIDs { for _, userID := range affectedUserIDs {
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil {
log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
}
} }
}() }()
} }
...@@ -646,10 +661,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -646,10 +661,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Type != "" { if input.Type != "" {
account.Type = input.Type account.Type = input.Type
} }
if input.Credentials != nil && len(input.Credentials) > 0 { if len(input.Credentials) > 0 {
account.Credentials = model.JSONB(input.Credentials) account.Credentials = model.JSONB(input.Credentials)
} }
if input.Extra != nil && len(input.Extra) > 0 { if len(input.Extra) > 0 {
account.Extra = model.JSONB(input.Extra) account.Extra = model.JSONB(input.Extra)
} }
if input.ProxyID != nil { if input.ProxyID != nil {
...@@ -831,8 +846,12 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener ...@@ -831,8 +846,12 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
codes := make([]model.RedeemCode, 0, input.Count) codes := make([]model.RedeemCode, 0, input.Count)
for i := 0; i < input.Count; i++ { for i := 0; i < input.Count; i++ {
codeValue, err := model.GenerateRedeemCode()
if err != nil {
return nil, err
}
code := model.RedeemCode{ code := model.RedeemCode{
Code: model.GenerateRedeemCode(), Code: codeValue,
Type: input.Type, Type: input.Type,
Value: input.Value, Value: input.Value,
Status: model.StatusUnused, Status: model.StatusUnused,
......
...@@ -100,10 +100,13 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error { ...@@ -100,10 +100,13 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
// 检查字符:只允许字母、数字、下划线、连字符 // 检查字符:只允许字母、数字、下划线、连字符
for _, c := range key { for _, c := range key {
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || if (c >= 'a' && c <= 'z') ||
(c >= '0' && c <= '9') || c == '_' || c == '-') { (c >= 'A' && c <= 'Z') ||
return ErrApiKeyInvalidChars (c >= '0' && c <= '9') ||
c == '_' || c == '-' {
continue
} }
return ErrApiKeyInvalidChars
} }
return nil return nil
......
...@@ -9,12 +9,6 @@ import ( ...@@ -9,12 +9,6 @@ import (
) )
const ( const (
// Wait polling interval
waitPollInterval = 100 * time.Millisecond
// Default max wait time
defaultMaxWait = 60 * time.Second
// Default extra wait slots beyond concurrency limit // Default extra wait slots beyond concurrency limit
defaultExtraWaitSlots = 20 defaultExtraWaitSlots = 20
) )
...@@ -31,7 +25,7 @@ func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService { ...@@ -31,7 +25,7 @@ func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
// AcquireResult represents the result of acquiring a concurrency slot // AcquireResult represents the result of acquiring a concurrency slot
type AcquireResult struct { type AcquireResult struct {
Acquired bool Acquired bool
ReleaseFunc func() // Must be called when done (typically via defer) ReleaseFunc func() // Must be called when done (typically via defer)
} }
...@@ -54,7 +48,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i ...@@ -54,7 +48,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
if acquired { if acquired {
return &AcquireResult{ return &AcquireResult{
Acquired: true, Acquired: true,
ReleaseFunc: func() { ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
...@@ -90,7 +84,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, ...@@ -90,7 +84,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
if acquired { if acquired {
return &AcquireResult{ return &AcquireResult{
Acquired: true, Acquired: true,
ReleaseFunc: func() { ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
......
...@@ -133,13 +133,13 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string, ...@@ -133,13 +133,13 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string,
if err != nil { if err != nil {
return fmt.Errorf("tls dial: %w", err) return fmt.Errorf("tls dial: %w", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
client, err := smtp.NewClient(conn, host) client, err := smtp.NewClient(conn, host)
if err != nil { if err != nil {
return fmt.Errorf("new smtp client: %w", err) return fmt.Errorf("new smtp client: %w", err)
} }
defer client.Close() defer func() { _ = client.Close() }()
if err = client.Auth(auth); err != nil { if err = client.Auth(auth); err != nil {
return fmt.Errorf("smtp auth: %w", err) return fmt.Errorf("smtp auth: %w", err)
...@@ -303,13 +303,13 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error { ...@@ -303,13 +303,13 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
if err != nil { if err != nil {
return fmt.Errorf("tls connection failed: %w", err) return fmt.Errorf("tls connection failed: %w", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
client, err := smtp.NewClient(conn, config.Host) client, err := smtp.NewClient(conn, config.Host)
if err != nil { if err != nil {
return fmt.Errorf("smtp client creation failed: %w", err) return fmt.Errorf("smtp client creation failed: %w", err)
} }
defer client.Close() defer func() { _ = client.Close() }()
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host) auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
if err = client.Auth(auth); err != nil { if err = client.Auth(auth); err != nil {
...@@ -324,7 +324,7 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error { ...@@ -324,7 +324,7 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
if err != nil { if err != nil {
return fmt.Errorf("smtp connection failed: %w", err) return fmt.Errorf("smtp connection failed: %w", err)
} }
defer client.Close() defer func() { _ = client.Close() }()
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host) auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
if err = client.Auth(auth); err != nil { if err = client.Auth(auth); err != nil {
......
...@@ -281,7 +281,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int ...@@ -281,7 +281,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// 同时检查模型支持 // 同时检查模型支持
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// 续期粘性会话 // 续期粘性会话
s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL) if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil return account, nil
} }
} }
...@@ -331,7 +333,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int ...@@ -331,7 +333,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// 4. 建立粘性绑定 // 4. 建立粘性绑定
if sessionHash != "" { if sessionHash != "" {
s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL) if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
} }
return selected, nil return selected, nil
...@@ -411,7 +415,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m ...@@ -411,7 +415,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
if err != nil { if err != nil {
return nil, fmt.Errorf("upstream request failed: %w", err) return nil, fmt.Errorf("upstream request failed: %w", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
// 处理错误响应(包括401,由后台TokenRefreshService维护token有效性) // 处理错误响应(包括401,由后台TokenRefreshService维护token有效性)
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
...@@ -678,7 +682,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -678,7 +682,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
} }
// 转发行 // 转发行
fmt.Fprintf(w, "%s\n", line) if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush() flusher.Flush()
// 解析usage数据 // 解析usage数据
...@@ -985,7 +991,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -985,7 +991,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
return fmt.Errorf("upstream request failed: %w", err) return fmt.Errorf("upstream request failed: %w", err)
} }
defer resp.Body.Close() defer func() {
_ = resp.Body.Close()
}()
// 读取响应体 // 读取响应体
respBody, err := io.ReadAll(resp.Body) respBody, err := io.ReadAll(resp.Body)
......
...@@ -15,7 +15,6 @@ import ( ...@@ -15,7 +15,6 @@ import (
"time" "time"
) )
// 预编译正则表达式(避免每次调用重新编译) // 预编译正则表达式(避免每次调用重新编译)
var ( var (
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid} // 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
......
...@@ -254,7 +254,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( ...@@ -254,7 +254,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
go func() { go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
}() }()
} }
...@@ -285,7 +285,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( ...@@ -285,7 +285,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
go func() { go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}() }()
} }
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log"
"time" "time"
"sub2api/internal/model" "sub2api/internal/model"
...@@ -78,7 +79,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass ...@@ -78,7 +79,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
go func() { go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}() }()
} }
...@@ -146,7 +147,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in ...@@ -146,7 +147,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
} }
newNotes += input.Notes newNotes += input.Notes
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil { if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
// 备注更新失败不影响主流程 log.Printf("update subscription notes failed: sub_id=%d err=%v", existingSub.ID, err)
} }
} }
...@@ -156,7 +157,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in ...@@ -156,7 +157,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
go func() { go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}() }()
} }
...@@ -177,7 +178,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in ...@@ -177,7 +178,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
go func() { go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}() }()
} }
...@@ -278,7 +279,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti ...@@ -278,7 +279,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
go func() { go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}() }()
} }
...@@ -311,7 +312,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti ...@@ -311,7 +312,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
go func() { go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}() }()
} }
......
...@@ -12,8 +12,6 @@ var ( ...@@ -12,8 +12,6 @@ var (
ErrTurnstileNotConfigured = errors.New("turnstile not configured") ErrTurnstileNotConfigured = errors.New("turnstile not configured")
) )
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
// TurnstileVerifier 验证 Turnstile token 的接口 // TurnstileVerifier 验证 Turnstile token 的接口
type TurnstileVerifier interface { type TurnstileVerifier interface {
VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error) VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error)
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"strings" "strings"
"time" "time"
...@@ -190,7 +191,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error { ...@@ -190,7 +191,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to create temp dir: %w", err) return fmt.Errorf("failed to create temp dir: %w", err)
} }
defer os.RemoveAll(tempDir) defer func() { _ = os.RemoveAll(tempDir) }()
// Download archive // Download archive
archivePath := filepath.Join(tempDir, filepath.Base(downloadURL)) archivePath := filepath.Join(tempDir, filepath.Base(downloadURL))
...@@ -223,7 +224,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error { ...@@ -223,7 +224,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
backupPath := exePath + ".backup" backupPath := exePath + ".backup"
// Remove old backup if exists // Remove old backup if exists
os.Remove(backupPath) _ = os.Remove(backupPath)
// Step 1: Move current binary to backup // Step 1: Move current binary to backup
if err := os.Rename(exePath, backupPath); err != nil { if err := os.Rename(exePath, backupPath); err != nil {
...@@ -349,7 +350,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR ...@@ -349,7 +350,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR
if err != nil { if err != nil {
return err return err
} }
defer f.Close() defer func() { _ = f.Close() }()
h := sha256.New() h := sha256.New()
if _, err := io.Copy(h, f); err != nil { if _, err := io.Copy(h, f); err != nil {
...@@ -379,7 +380,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error { ...@@ -379,7 +380,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
if err != nil { if err != nil {
return err return err
} }
defer f.Close() defer func() { _ = f.Close() }()
var reader io.Reader = f var reader io.Reader = f
...@@ -389,7 +390,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error { ...@@ -389,7 +390,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
if err != nil { if err != nil {
return err return err
} }
defer gzr.Close() defer func() { _ = gzr.Close() }()
reader = gzr reader = gzr
} }
...@@ -435,10 +436,12 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error { ...@@ -435,10 +436,12 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
// Use LimitReader to prevent decompression bombs // Use LimitReader to prevent decompression bombs
limited := io.LimitReader(tr, maxBinarySize) limited := io.LimitReader(tr, maxBinarySize)
if _, err := io.Copy(out, limited); err != nil { if _, err := io.Copy(out, limited); err != nil {
out.Close() _ = out.Close()
return err
}
if err := out.Close(); err != nil {
return err return err
} }
out.Close()
return nil return nil
} }
} }
...@@ -451,11 +454,13 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error { ...@@ -451,11 +454,13 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
if err != nil { if err != nil {
return err return err
} }
defer out.Close()
limited := io.LimitReader(reader, maxBinarySize) limited := io.LimitReader(reader, maxBinarySize)
_, err = io.Copy(out, limited) if _, err := io.Copy(out, limited); err != nil {
return err _ = out.Close()
return err
}
return out.Close()
} }
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) { func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
...@@ -499,7 +504,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) { ...@@ -499,7 +504,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
} }
data, _ := json.Marshal(cacheData) data, _ := json.Marshal(cacheData)
s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second) _ = s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
} }
// compareVersions compares two semantic versions // compareVersions compares two semantic versions
...@@ -523,7 +528,9 @@ func parseVersion(v string) [3]int { ...@@ -523,7 +528,9 @@ func parseVersion(v string) [3]int {
parts := strings.Split(v, ".") parts := strings.Split(v, ".")
result := [3]int{0, 0, 0} result := [3]int{0, 0, 0}
for i := 0; i < len(parts) && i < 3; i++ { for i := 0; i < len(parts) && i < 3; i++ {
fmt.Sscanf(parts[i], "%d", &result[i]) if parsed, err := strconv.Atoi(parts[i]); err == nil {
result[i] = parsed
}
} }
return result return result
} }
...@@ -352,4 +352,3 @@ func install(c *gin.Context) { ...@@ -352,4 +352,3 @@ func install(c *gin.Context) {
"restart": true, "restart": true,
}) })
} }
...@@ -14,9 +14,9 @@ import ( ...@@ -14,9 +14,9 @@ import (
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gopkg.in/yaml.v3"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"
"gopkg.in/yaml.v3"
) )
// Config paths // Config paths
...@@ -101,7 +101,14 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error { ...@@ -101,7 +101,14 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to get db instance: %w", err) return fmt.Errorf("failed to get db instance: %w", err)
} }
defer sqlDB.Close() defer func() {
if sqlDB == nil {
return
}
if err := sqlDB.Close(); err != nil {
log.Printf("failed to close postgres connection: %v", err)
}
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
...@@ -129,7 +136,10 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error { ...@@ -129,7 +136,10 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
} }
// Now connect to the target database to verify // Now connect to the target database to verify
sqlDB.Close() if err := sqlDB.Close(); err != nil {
log.Printf("failed to close postgres connection: %v", err)
}
sqlDB = nil
targetDSN := fmt.Sprintf( targetDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
...@@ -145,7 +155,11 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error { ...@@ -145,7 +155,11 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to get target db instance: %w", err) return fmt.Errorf("failed to get target db instance: %w", err)
} }
defer targetSqlDB.Close() defer func() {
if err := targetSqlDB.Close(); err != nil {
log.Printf("failed to close postgres connection: %v", err)
}
}()
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel2() defer cancel2()
...@@ -164,7 +178,11 @@ func TestRedisConnection(cfg *RedisConfig) error { ...@@ -164,7 +178,11 @@ func TestRedisConnection(cfg *RedisConfig) error {
Password: cfg.Password, Password: cfg.Password,
DB: cfg.DB, DB: cfg.DB,
}) })
defer rdb.Close() defer func() {
if err := rdb.Close(); err != nil {
log.Printf("failed to close redis client: %v", err)
}
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
...@@ -185,7 +203,11 @@ func Install(cfg *SetupConfig) error { ...@@ -185,7 +203,11 @@ func Install(cfg *SetupConfig) error {
// Generate JWT secret if not provided // Generate JWT secret if not provided
if cfg.JWT.Secret == "" { if cfg.JWT.Secret == "" {
cfg.JWT.Secret = generateSecret(32) secret, err := generateSecret(32)
if err != nil {
return fmt.Errorf("failed to generate jwt secret: %w", err)
}
cfg.JWT.Secret = secret
} }
// Test connections // Test connections
...@@ -243,7 +265,11 @@ func initializeDatabase(cfg *SetupConfig) error { ...@@ -243,7 +265,11 @@ func initializeDatabase(cfg *SetupConfig) error {
if err != nil { if err != nil {
return err return err
} }
defer sqlDB.Close() defer func() {
if err := sqlDB.Close(); err != nil {
log.Printf("failed to close postgres connection: %v", err)
}
}()
// 使用 model 包的 AutoMigrate,确保模型定义统一 // 使用 model 包的 AutoMigrate,确保模型定义统一
return model.AutoMigrate(db) return model.AutoMigrate(db)
...@@ -265,7 +291,11 @@ func createAdminUser(cfg *SetupConfig) error { ...@@ -265,7 +291,11 @@ func createAdminUser(cfg *SetupConfig) error {
if err != nil { if err != nil {
return err return err
} }
defer sqlDB.Close() defer func() {
if err := sqlDB.Close(); err != nil {
log.Printf("failed to close postgres connection: %v", err)
}
}()
// Check if admin already exists // Check if admin already exists
var count int64 var count int64
...@@ -352,10 +382,12 @@ func writeConfigFile(cfg *SetupConfig) error { ...@@ -352,10 +382,12 @@ func writeConfigFile(cfg *SetupConfig) error {
return os.WriteFile(ConfigFile, data, 0600) return os.WriteFile(ConfigFile, data, 0600)
} }
func generateSecret(length int) string { func generateSecret(length int) (string, error) {
bytes := make([]byte, length) bytes := make([]byte, length)
rand.Read(bytes) if _, err := rand.Read(bytes); err != nil {
return hex.EncodeToString(bytes) return "", err
}
return hex.EncodeToString(bytes), nil
} }
// ============================================================================= // =============================================================================
...@@ -431,13 +463,21 @@ func AutoSetupFromEnv() error { ...@@ -431,13 +463,21 @@ func AutoSetupFromEnv() error {
// Generate JWT secret if not provided // Generate JWT secret if not provided
if cfg.JWT.Secret == "" { if cfg.JWT.Secret == "" {
cfg.JWT.Secret = generateSecret(32) secret, err := generateSecret(32)
if err != nil {
return fmt.Errorf("failed to generate jwt secret: %w", err)
}
cfg.JWT.Secret = secret
log.Println("Generated JWT secret automatically") log.Println("Generated JWT secret automatically")
} }
// Generate admin password if not provided // Generate admin password if not provided
if cfg.Admin.Password == "" { if cfg.Admin.Password == "" {
cfg.Admin.Password = generateSecret(16) password, err := generateSecret(16)
if err != nil {
return fmt.Errorf("failed to generate admin password: %w", err)
}
cfg.Admin.Password = password
log.Printf("Generated admin password: %s", cfg.Admin.Password) log.Printf("Generated admin password: %s", cfg.Admin.Password)
log.Println("IMPORTANT: Save this password! It will not be shown again.") log.Println("IMPORTANT: Save this password! It will not be shown again.")
} }
......
...@@ -41,7 +41,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { ...@@ -41,7 +41,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
} }
if file, err := distFS.Open(cleanPath); err == nil { if file, err := distFS.Open(cleanPath); err == nil {
file.Close() _ = file.Close()
fileServer.ServeHTTP(c.Writer, c.Request) fileServer.ServeHTTP(c.Writer, c.Request)
c.Abort() c.Abort()
return return
...@@ -59,7 +59,7 @@ func serveIndexHTML(c *gin.Context, fsys fs.FS) { ...@@ -59,7 +59,7 @@ func serveIndexHTML(c *gin.Context, fsys fs.FS) {
c.Abort() c.Abort()
return return
} }
defer file.Close() defer func() { _ = file.Close() }()
content, err := io.ReadAll(file) content, err := io.ReadAll(file)
if err != nil { if err != nil {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment