Unverified Commit bb500b7b authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #9 from NepetaLemon/refactor/add-http-service-ports

refactor(backend): service http ports
parents 5c2e7ae2 cceada7d
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
...@@ -34,17 +33,26 @@ const ( ...@@ -34,17 +33,26 @@ const (
maxDownloadSize = 500 * 1024 * 1024 maxDownloadSize = 500 * 1024 * 1024
) )
// GitHubReleaseClient 获取 GitHub release 信息的接口
type GitHubReleaseClient interface {
FetchLatestRelease(ctx context.Context, repo string) (*GitHubRelease, error)
DownloadFile(ctx context.Context, url, dest string, maxSize int64) error
FetchChecksumFile(ctx context.Context, url string) ([]byte, error)
}
// UpdateService handles software updates // UpdateService handles software updates
type UpdateService struct { type UpdateService struct {
cache ports.UpdateCache cache ports.UpdateCache
githubClient GitHubReleaseClient
currentVersion string currentVersion string
buildType string // "source" for manual builds, "release" for CI builds buildType string // "source" for manual builds, "release" for CI builds
} }
// NewUpdateService creates a new UpdateService // NewUpdateService creates a new UpdateService
func NewUpdateService(cache ports.UpdateCache, version, buildType string) *UpdateService { func NewUpdateService(cache ports.UpdateCache, githubClient GitHubReleaseClient, version, buildType string) *UpdateService {
return &UpdateService{ return &UpdateService{
cache: cache, cache: cache,
githubClient: githubClient,
currentVersion: version, currentVersion: version,
buildType: buildType, buildType: buildType,
} }
...@@ -260,42 +268,11 @@ func (s *UpdateService) Rollback() error { ...@@ -260,42 +268,11 @@ func (s *UpdateService) Rollback() error {
return nil return nil
} }
func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, error) { func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", githubRepo) release, err := s.githubClient.FetchLatestRelease(ctx, githubRepo)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "Sub2API-Updater")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: s.currentVersion,
HasUpdate: false,
Warning: "No releases found",
BuildType: s.buildType,
}, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
}
var release GitHubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, err
}
latestVersion := strings.TrimPrefix(release.TagName, "v") latestVersion := strings.TrimPrefix(release.TagName, "v")
...@@ -325,47 +302,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er ...@@ -325,47 +302,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
} }
func (s *UpdateService) downloadFile(ctx context.Context, downloadURL, dest string) error { func (s *UpdateService) downloadFile(ctx context.Context, downloadURL, dest string) error {
req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil) return s.githubClient.DownloadFile(ctx, downloadURL, dest, maxDownloadSize)
if err != nil {
return err
}
client := &http.Client{Timeout: 10 * time.Minute}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download returned %d", resp.StatusCode)
}
// SECURITY: Check Content-Length if available
if resp.ContentLength > maxDownloadSize {
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxDownloadSize)
}
out, err := os.Create(dest)
if err != nil {
return err
}
defer out.Close()
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
limited := io.LimitReader(resp.Body, maxDownloadSize+1)
written, err := io.Copy(out, limited)
if err != nil {
return err
}
// Check if we hit the limit (downloaded more than maxDownloadSize)
if written > maxDownloadSize {
os.Remove(dest) // Clean up partial file
return fmt.Errorf("download exceeded maximum size of %d bytes", maxDownloadSize)
}
return nil
} }
func (s *UpdateService) getArchiveName() string { func (s *UpdateService) getArchiveName() string {
...@@ -402,20 +339,9 @@ func validateDownloadURL(rawURL string) error { ...@@ -402,20 +339,9 @@ func validateDownloadURL(rawURL string) error {
func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumURL string) error { func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumURL string) error {
// Download checksums file // Download checksums file
req, err := http.NewRequestWithContext(ctx, "GET", checksumURL, nil) checksumData, err := s.githubClient.FetchChecksumFile(ctx, checksumURL)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to download checksums: %w", err)
}
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to download checksums: %d", resp.StatusCode)
} }
// Calculate file hash // Calculate file hash
...@@ -433,7 +359,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR ...@@ -433,7 +359,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR
// Find expected hash in checksums file // Find expected hash in checksums file
fileName := filepath.Base(filePath) fileName := filepath.Base(filePath)
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(strings.NewReader(string(checksumData)))
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
parts := strings.Fields(line) parts := strings.Fields(line)
......
...@@ -2,13 +2,20 @@ package service ...@@ -2,13 +2,20 @@ package service
import ( import (
"sub2api/internal/config" "sub2api/internal/config"
"sub2api/internal/service/ports"
"github.com/google/wire" "github.com/google/wire"
) )
// BuildInfo contains build information
type BuildInfo struct {
Version string
BuildType string
}
// ProvidePricingService creates and initializes PricingService // ProvidePricingService creates and initializes PricingService
func ProvidePricingService(cfg *config.Config) (*PricingService, error) { func ProvidePricingService(cfg *config.Config, remoteClient PricingRemoteClient) (*PricingService, error) {
svc := NewPricingService(cfg) svc := NewPricingService(cfg, remoteClient)
if err := svc.Initialize(); err != nil { if err := svc.Initialize(); err != nil {
// 价格服务初始化失败不应阻止启动,使用回退价格 // 价格服务初始化失败不应阻止启动,使用回退价格
println("[Service] Warning: Pricing service initialization failed:", err.Error()) println("[Service] Warning: Pricing service initialization failed:", err.Error())
...@@ -16,6 +23,11 @@ func ProvidePricingService(cfg *config.Config) (*PricingService, error) { ...@@ -16,6 +23,11 @@ func ProvidePricingService(cfg *config.Config) (*PricingService, error) {
return svc, nil return svc, nil
} }
// ProvideUpdateService creates UpdateService with BuildInfo
func ProvideUpdateService(cache ports.UpdateCache, githubClient GitHubReleaseClient, buildInfo BuildInfo) *UpdateService {
return NewUpdateService(cache, githubClient, buildInfo.Version, buildInfo.BuildType)
}
// ProvideEmailQueueService creates EmailQueueService with default worker count // ProvideEmailQueueService creates EmailQueueService with default worker count
func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService { func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
return NewEmailQueueService(emailService, 3) return NewEmailQueueService(emailService, 3)
...@@ -48,6 +60,7 @@ var ProviderSet = wire.NewSet( ...@@ -48,6 +60,7 @@ var ProviderSet = wire.NewSet(
NewSubscriptionService, NewSubscriptionService,
NewConcurrencyService, NewConcurrencyService,
NewIdentityService, NewIdentityService,
ProvideUpdateService,
// Provide the Services container struct // Provide the Services container struct
wire.Struct(new(Services), "*"), wire.Struct(new(Services), "*"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment