Unverified Commit bb500b7b authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #9 from NepetaLemon/refactor/add-http-service-ports

refactor(backend): service http ports
parents 5c2e7ae2 cceada7d
......@@ -10,7 +10,6 @@ import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
......@@ -34,17 +33,26 @@ const (
maxDownloadSize = 500 * 1024 * 1024
)
// GitHubReleaseClient 获取 GitHub release 信息的接口
type GitHubReleaseClient interface {
FetchLatestRelease(ctx context.Context, repo string) (*GitHubRelease, error)
DownloadFile(ctx context.Context, url, dest string, maxSize int64) error
FetchChecksumFile(ctx context.Context, url string) ([]byte, error)
}
// UpdateService handles software updates
type UpdateService struct {
cache ports.UpdateCache
githubClient GitHubReleaseClient
currentVersion string
buildType string // "source" for manual builds, "release" for CI builds
}
// NewUpdateService creates a new UpdateService
func NewUpdateService(cache ports.UpdateCache, version, buildType string) *UpdateService {
func NewUpdateService(cache ports.UpdateCache, githubClient GitHubReleaseClient, version, buildType string) *UpdateService {
return &UpdateService{
cache: cache,
githubClient: githubClient,
currentVersion: version,
buildType: buildType,
}
......@@ -260,42 +268,11 @@ func (s *UpdateService) Rollback() error {
return nil
}
func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", githubRepo)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
release, err := s.githubClient.FetchLatestRelease(ctx, githubRepo)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "Sub2API-Updater")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: s.currentVersion,
HasUpdate: false,
Warning: "No releases found",
BuildType: s.buildType,
}, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
}
var release GitHubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, err
}
latestVersion := strings.TrimPrefix(release.TagName, "v")
......@@ -325,47 +302,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
}
func (s *UpdateService) downloadFile(ctx context.Context, downloadURL, dest string) error {
req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil)
if err != nil {
return err
}
client := &http.Client{Timeout: 10 * time.Minute}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download returned %d", resp.StatusCode)
}
// SECURITY: Check Content-Length if available
if resp.ContentLength > maxDownloadSize {
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxDownloadSize)
}
out, err := os.Create(dest)
if err != nil {
return err
}
defer out.Close()
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
limited := io.LimitReader(resp.Body, maxDownloadSize+1)
written, err := io.Copy(out, limited)
if err != nil {
return err
}
// Check if we hit the limit (downloaded more than maxDownloadSize)
if written > maxDownloadSize {
os.Remove(dest) // Clean up partial file
return fmt.Errorf("download exceeded maximum size of %d bytes", maxDownloadSize)
}
return nil
return s.githubClient.DownloadFile(ctx, downloadURL, dest, maxDownloadSize)
}
func (s *UpdateService) getArchiveName() string {
......@@ -402,20 +339,9 @@ func validateDownloadURL(rawURL string) error {
func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumURL string) error {
// Download checksums file
req, err := http.NewRequestWithContext(ctx, "GET", checksumURL, nil)
checksumData, err := s.githubClient.FetchChecksumFile(ctx, checksumURL)
if err != nil {
return err
}
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to download checksums: %d", resp.StatusCode)
return fmt.Errorf("failed to download checksums: %w", err)
}
// Calculate file hash
......@@ -433,7 +359,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR
// Find expected hash in checksums file
fileName := filepath.Base(filePath)
scanner := bufio.NewScanner(resp.Body)
scanner := bufio.NewScanner(strings.NewReader(string(checksumData)))
for scanner.Scan() {
line := scanner.Text()
parts := strings.Fields(line)
......
......@@ -2,13 +2,20 @@ package service
import (
"sub2api/internal/config"
"sub2api/internal/service/ports"
"github.com/google/wire"
)
// BuildInfo contains build information
type BuildInfo struct {
Version string
BuildType string
}
// ProvidePricingService creates and initializes PricingService
func ProvidePricingService(cfg *config.Config) (*PricingService, error) {
svc := NewPricingService(cfg)
func ProvidePricingService(cfg *config.Config, remoteClient PricingRemoteClient) (*PricingService, error) {
svc := NewPricingService(cfg, remoteClient)
if err := svc.Initialize(); err != nil {
// 价格服务初始化失败不应阻止启动,使用回退价格
println("[Service] Warning: Pricing service initialization failed:", err.Error())
......@@ -16,6 +23,11 @@ func ProvidePricingService(cfg *config.Config) (*PricingService, error) {
return svc, nil
}
// ProvideUpdateService creates UpdateService with BuildInfo
func ProvideUpdateService(cache ports.UpdateCache, githubClient GitHubReleaseClient, buildInfo BuildInfo) *UpdateService {
return NewUpdateService(cache, githubClient, buildInfo.Version, buildInfo.BuildType)
}
// ProvideEmailQueueService creates EmailQueueService with default worker count
func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
return NewEmailQueueService(emailService, 3)
......@@ -48,6 +60,7 @@ var ProviderSet = wire.NewSet(
NewSubscriptionService,
NewConcurrencyService,
NewIdentityService,
ProvideUpdateService,
// Provide the Services container struct
wire.Struct(new(Services), "*"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment