Commit 25a304c2 authored by Forest's avatar Forest
Browse files

test: 增加 repository 测试

parent 9d30ceae
//go:build integration
package repository
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"os/exec"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
redisclient "github.com/redis/go-redis/v9"
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
gormpostgres "gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
const (
redisImageTag = "redis:8.4-alpine"
postgresImageTag = "postgres:18.1-alpine3.23"
)
var (
integrationDB *gorm.DB
integrationRedis *redisclient.Client
redisNamespaceSeq uint64
)
func TestMain(m *testing.M) {
ctx := context.Background()
if err := timezone.Init("UTC"); err != nil {
log.Printf("failed to init timezone: %v", err)
os.Exit(1)
}
if !dockerIsAvailable(ctx) {
// In CI we expect Docker to be available so integration tests should fail loudly.
if os.Getenv("CI") != "" {
log.Printf("docker is not available (CI=true); failing integration tests")
os.Exit(1)
}
log.Printf("docker is not available; skipping integration tests (start Docker to enable)")
os.Exit(0)
}
postgresImage := selectDockerImage(ctx, postgresImageTag)
pgContainer, err := tcpostgres.Run(
ctx,
postgresImage,
tcpostgres.WithDatabase("sub2api_test"),
tcpostgres.WithUsername("postgres"),
tcpostgres.WithPassword("postgres"),
tcpostgres.BasicWaitStrategies(),
)
if err != nil {
log.Printf("failed to start postgres container: %v", err)
os.Exit(1)
}
defer func() { _ = pgContainer.Terminate(ctx) }()
redisContainer, err := tcredis.Run(
ctx,
redisImageTag,
)
if err != nil {
log.Printf("failed to start redis container: %v", err)
os.Exit(1)
}
defer func() { _ = redisContainer.Terminate(ctx) }()
dsn, err := pgContainer.ConnectionString(ctx, "sslmode=disable", "TimeZone=UTC")
if err != nil {
log.Printf("failed to get postgres dsn: %v", err)
os.Exit(1)
}
integrationDB, err = openGormWithRetry(ctx, dsn, 30*time.Second)
if err != nil {
log.Printf("failed to open gorm db: %v", err)
os.Exit(1)
}
if err := model.AutoMigrate(integrationDB); err != nil {
log.Printf("failed to automigrate db: %v", err)
os.Exit(1)
}
redisHost, err := redisContainer.Host(ctx)
if err != nil {
log.Printf("failed to get redis host: %v", err)
os.Exit(1)
}
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
if err != nil {
log.Printf("failed to get redis port: %v", err)
os.Exit(1)
}
integrationRedis = redisclient.NewClient(&redisclient.Options{
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
DB: 0,
})
if err := integrationRedis.Ping(ctx).Err(); err != nil {
log.Printf("failed to ping redis: %v", err)
os.Exit(1)
}
code := m.Run()
_ = integrationRedis.Close()
os.Exit(code)
}
func dockerIsAvailable(ctx context.Context) bool {
cmd := exec.CommandContext(ctx, "docker", "info")
cmd.Env = os.Environ()
return cmd.Run() == nil
}
func selectDockerImage(ctx context.Context, preferred string) string {
if dockerImageExists(ctx, preferred) {
return preferred
}
return preferred
}
func dockerImageExists(ctx context.Context, image string) bool {
cmd := exec.CommandContext(ctx, "docker", "image", "inspect", image)
cmd.Env = os.Environ()
cmd.Stdout = nil
cmd.Stderr = nil
return cmd.Run() == nil
}
func openGormWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*gorm.DB, error) {
deadline := time.Now().Add(timeout)
var lastErr error
for time.Now().Before(deadline) {
db, err := gorm.Open(gormpostgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
sqlDB, err := db.DB()
if err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
if err := pingWithTimeout(ctx, sqlDB, 2*time.Second); err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
return db, nil
}
return nil, fmt.Errorf("db not ready after %s: %w", timeout, lastErr)
}
func pingWithTimeout(ctx context.Context, db *sql.DB, timeout time.Duration) error {
pingCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return db.PingContext(pingCtx)
}
func testTx(t *testing.T) *gorm.DB {
t.Helper()
tx := integrationDB.Begin()
require.NoError(t, tx.Error, "begin tx")
t.Cleanup(func() {
_ = tx.Rollback().Error
})
return tx
}
func testRedis(t *testing.T) *redisclient.Client {
t.Helper()
prefix := fmt.Sprintf(
"it:%s:%d:%d:",
sanitizeRedisNamespace(t.Name()),
time.Now().UnixNano(),
atomic.AddUint64(&redisNamespaceSeq, 1),
)
opts := *integrationRedis.Options()
rdb := redisclient.NewClient(&opts)
rdb.AddHook(prefixHook{prefix: prefix})
t.Cleanup(func() {
ctx := context.Background()
var cursor uint64
for {
keys, nextCursor, err := integrationRedis.Scan(ctx, cursor, prefix+"*", 500).Result()
require.NoError(t, err, "scan redis keys for cleanup")
if len(keys) > 0 {
require.NoError(t, integrationRedis.Unlink(ctx, keys...).Err(), "unlink redis keys for cleanup")
}
cursor = nextCursor
if cursor == 0 {
break
}
}
_ = rdb.Close()
})
return rdb
}
func assertTTLWithin(t *testing.T, ttl time.Duration, min, max time.Duration) {
t.Helper()
require.GreaterOrEqual(t, ttl, min, "ttl should be >= min")
require.LessOrEqual(t, ttl, max, "ttl should be <= max")
}
func sanitizeRedisNamespace(name string) string {
name = strings.ReplaceAll(name, "/", "_")
name = strings.ReplaceAll(name, " ", "_")
return name
}
type prefixHook struct {
prefix string
}
func (h prefixHook) DialHook(next redisclient.DialHook) redisclient.DialHook { return next }
func (h prefixHook) ProcessHook(next redisclient.ProcessHook) redisclient.ProcessHook {
return func(ctx context.Context, cmd redisclient.Cmder) error {
h.prefixCmd(cmd)
return next(ctx, cmd)
}
}
func (h prefixHook) ProcessPipelineHook(next redisclient.ProcessPipelineHook) redisclient.ProcessPipelineHook {
return func(ctx context.Context, cmds []redisclient.Cmder) error {
for _, cmd := range cmds {
h.prefixCmd(cmd)
}
return next(ctx, cmds)
}
}
func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
args := cmd.Args()
if len(args) < 2 {
return
}
prefixOne := func(i int) {
if i < 0 || i >= len(args) {
return
}
switch v := args[i].(type) {
case string:
if v != "" && !strings.HasPrefix(v, h.prefix) {
args[i] = h.prefix + v
}
case []byte:
s := string(v)
if s != "" && !strings.HasPrefix(s, h.prefix) {
args[i] = []byte(h.prefix + s)
}
}
}
switch strings.ToLower(cmd.Name()) {
case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl",
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists":
prefixOne(1)
case "del", "unlink":
for i := 1; i < len(args); i++ {
prefixOne(i)
}
case "eval", "evalsha", "eval_ro", "evalsha_ro":
if len(args) < 3 {
return
}
numKeys, err := strconv.Atoi(fmt.Sprint(args[2]))
if err != nil || numKeys <= 0 {
return
}
for i := 0; i < numKeys && 3+i < len(args); i++ {
prefixOne(3 + i)
}
case "scan":
for i := 2; i+1 < len(args); i++ {
if strings.EqualFold(fmt.Sprint(args[i]), "match") {
prefixOne(i + 1)
break
}
}
}
}
// IntegrationRedisSuite provides a base suite for Redis integration tests.
// Embedding suites should call SetupTest to initialize ctx and rdb.
type IntegrationRedisSuite struct {
suite.Suite
ctx context.Context
rdb *redisclient.Client
}
// SetupTest initializes ctx and rdb for each test method.
func (s *IntegrationRedisSuite) SetupTest() {
s.ctx = context.Background()
s.rdb = testRedis(s.T())
}
// RequireNoError is a convenience method wrapping require.NoError with s.T().
func (s *IntegrationRedisSuite) RequireNoError(err error, msgAndArgs ...any) {
s.T().Helper()
require.NoError(s.T(), err, msgAndArgs...)
}
// AssertTTLWithin asserts that ttl is within [min, max].
func (s *IntegrationRedisSuite) AssertTTLWithin(ttl, min, max time.Duration) {
s.T().Helper()
assertTTLWithin(s.T(), ttl, min, max)
}
// IntegrationDBSuite provides a base suite for DB (Gorm) integration tests.
// Embedding suites should call SetupTest to initialize ctx and db.
type IntegrationDBSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
}
// SetupTest initializes ctx and db for each test method.
func (s *IntegrationDBSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
}
// RequireNoError is a convenience method wrapping require.NoError with s.T().
func (s *IntegrationDBSuite) RequireNoError(err error, msgAndArgs ...any) {
s.T().Helper()
require.NoError(s.T(), err, msgAndArgs...)
}
......@@ -12,11 +12,13 @@ import (
"github.com/imroc/req/v3"
)
type openaiOAuthService struct{}
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
func NewOpenAIOAuthClient() ports.OpenAIOAuthClient {
return &openaiOAuthService{}
return &openaiOAuthService{tokenURL: openai.TokenURL}
}
type openaiOAuthService struct {
tokenURL string
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
......@@ -39,7 +41,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(openai.TokenURL)
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
......@@ -67,7 +69,7 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(openai.TokenURL)
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
......
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type OpenAIOAuthServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
svc *openaiOAuthService
received chan url.Values
}
func (s *OpenAIOAuthServiceSuite) SetupTest() {
s.ctx = context.Background()
s.received = make(chan url.Values, 1)
}
func (s *OpenAIOAuthServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
errCh <- "method mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if err := r.ParseForm(); err != nil {
errCh <- "ParseForm failed"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("grant_type"); got != "authorization_code" {
errCh <- "grant_type mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
errCh <- "client_id mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("code"); got != "code" {
errCh <- "code mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("redirect_uri"); got != openai.DefaultRedirectURI {
errCh <- "redirect_uri mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("code_verifier"); got != "ver" {
errCh <- "code_verifier mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
require.Equal(s.T(), "at", resp.AccessToken)
require.Equal(s.T(), "rt", resp.RefreshToken)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
errCh <- "ParseForm failed"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("grant_type"); got != "refresh_token" {
errCh <- "grant_type mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("refresh_token"); got != "rt" {
errCh <- "refresh_token mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
errCh <- "client_id mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("scope"); got != openai.RefreshScopes {
errCh <- "scope mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at2","refresh_token":"rt2","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.NoError(s.T(), err, "RefreshToken")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
require.Equal(s.T(), "at2", resp.AccessToken)
require.Equal(s.T(), "rt2", resp.RefreshToken)
}
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, "bad")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 400")
require.ErrorContains(s.T(), err, "bad")
}
func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "request failed")
}
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
started := make(chan struct{})
block := make(chan struct{})
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-block
}))
ctx, cancel := context.WithCancel(s.ctx)
done := make(chan error, 1)
go func() {
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
done <- err
}()
<-started
cancel()
close(block)
err := <-done
require.Error(s.T(), err)
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
want := "http://localhost:9999/cb"
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
if got := r.PostForm.Get("redirect_uri"); got != want {
errCh <- "redirect_uri mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
}
func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
s.received <- r.PostForm
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
s.svc.tokenURL = s.srv.URL + "?x=1"
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case <-s.received:
default:
require.Fail(s.T(), "expected server to receive request")
}
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "not-valid-json")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err, "expected error for invalid JSON response")
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "unauthorized")
}))
_, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.Error(s.T(), err, "expected error for non-2xx status")
require.ErrorContains(s.T(), err, "status 401")
}
func TestOpenAIOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(OpenAIOAuthServiceSuite))
}
package repository
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type PricingServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
client *pricingRemoteClient
}
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
client, ok := NewPricingRemoteClient().(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}
func (s *PricingServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
}
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/ok" {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok":true}`))
return
}
w.WriteHeader(http.StatusInternalServerError)
}))
body, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/ok")
require.NoError(s.T(), err, "FetchPricingJSON")
require.Equal(s.T(), `{"ok":true}`, string(body), "body mismatch")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_NonOKStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
_, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/err")
require.Error(s.T(), err, "expected error for non-200 status")
}
func (s *PricingServiceSuite) TestFetchHashText_ParsesFields() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/hashfile":
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("abc123 model_prices.json\n"))
case "/hashonly":
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("def456\n"))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashfile")
require.NoError(s.T(), err, "FetchHashText")
require.Equal(s.T(), "abc123", hash, "hash mismatch")
hash2, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashonly")
require.NoError(s.T(), err, "FetchHashText")
require.Equal(s.T(), "def456", hash2, "hash mismatch")
}
func (s *PricingServiceSuite) TestFetchHashText_NonOKStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
_, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/nope")
require.Error(s.T(), err, "expected error for non-200 status")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_InvalidURL() {
_, err := s.client.FetchPricingJSON(s.ctx, "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL")
}
func (s *PricingServiceSuite) TestFetchHashText_EmptyBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
// empty body
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/empty")
require.NoError(s.T(), err, "FetchHashText empty body should not error")
require.Equal(s.T(), "", hash, "expected empty hash")
}
func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(" \n"))
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/ws")
require.NoError(s.T(), err, "FetchHashText whitespace body should not error")
require.Equal(s.T(), "", hash, "expected empty hash after trimming")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
started := make(chan struct{})
block := make(chan struct{})
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-block
}))
ctx, cancel := context.WithCancel(s.ctx)
done := make(chan error, 1)
go func() {
_, err := s.client.FetchPricingJSON(ctx, s.srv.URL+"/block")
done <- err
}()
<-started
cancel()
close(block)
err := <-done
require.Error(s.T(), err)
}
func TestPricingServiceSuite(t *testing.T) {
suite.Run(t, new(PricingServiceSuite))
}
......@@ -16,10 +16,14 @@ import (
"golang.org/x/net/proxy"
)
type proxyProbeService struct{}
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
return &proxyProbeService{}
return &proxyProbeService{ipInfoURL: defaultIPInfoURL}
}
const defaultIPInfoURL = "https://ipinfo.io/json"
type proxyProbeService struct {
ipInfoURL string
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
......@@ -34,7 +38,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
}
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
if err != nil {
return nil, 0, fmt.Errorf("failed to create request: %w", err)
}
......
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ProxyProbeServiceSuite struct {
suite.Suite
ctx context.Context
proxySrv *httptest.Server
prober *proxyProbeService
}
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"}
}
func (s *ProxyProbeServiceSuite) TearDownTest() {
if s.proxySrv != nil {
s.proxySrv.Close()
s.proxySrv = nil
}
}
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
s.proxySrv = httptest.NewServer(handler)
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() {
_, err := createProxyTransport("://bad")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "invalid proxy URL")
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() {
_, err := createProxyTransport("ftp://127.0.0.1:1")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "unsupported proxy protocol")
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() {
tr, err := createProxyTransport("socks5://127.0.0.1:1080")
require.NoError(s.T(), err, "createProxyTransport")
require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
seen := make(chan string, 1)
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`)
}))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.NoError(s.T(), err, "ProbeProxy")
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
require.Equal(s.T(), "1.2.3.4", info.IP)
require.Equal(s.T(), "c", info.City)
require.Equal(s.T(), "r", info.Region)
require.Equal(s.T(), "cc", info.Country)
// Verify proxy received the request
select {
case uri := <-seen:
require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy")
default:
require.Fail(s.T(), "expected proxy to receive request")
}
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status: 503")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "failed to parse response")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
s.prober.ipInfoURL = "://invalid-url"
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err, "expected error for invalid ipInfoURL")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.proxySrv.Close()
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err, "expected error when proxy server is closed")
}
func TestProxyProbeServiceSuite(t *testing.T) {
suite.Run(t, new(ProxyProbeServiceSuite))
}
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type ProxyRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *ProxyRepository
}
func (s *ProxyRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewProxyRepository(s.db)
}
func TestProxyRepoSuite(t *testing.T) {
suite.Run(t, new(ProxyRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *ProxyRepoSuite) TestCreate() {
proxy := &model.Proxy{
Name: "test-create",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: model.StatusActive,
}
err := s.repo.Create(s.ctx, proxy)
s.Require().NoError(err, "Create")
s.Require().NotZero(proxy.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, proxy.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *ProxyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *ProxyRepoSuite) TestUpdate() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"})
proxy.Name = "updated"
err := s.repo.Update(s.ctx, proxy)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, proxy.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *ProxyRepoSuite) TestDelete() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"})
err := s.repo.Delete(s.ctx, proxy.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, proxy.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *ProxyRepoSuite) TestList() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(proxies, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Equal("socks5", proxies[0].Protocol)
}
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Equal(model.StatusDisabled, proxies[0].Status)
}
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Contains(proxies[0].Name, "production")
}
// --- ListActive ---
func (s *ProxyRepoSuite) TestListActive() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled})
proxies, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(proxies, 1)
s.Require().Equal("active1", proxies[0].Name)
}
// --- ExistsByHostPortAuth ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
Port: 8080,
Username: "user",
Password: "pass",
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "user", "pass")
s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists)
notExists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "wrong", "creds")
s.Require().NoError(err)
s.Require().False(notExists)
}
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p-noauth",
Protocol: "http",
Host: "5.6.7.8",
Port: 8081,
Username: "",
Password: "",
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "5.6.7.8", 8081, "", "")
s.Require().NoError(err)
s.Require().True(exists)
}
// --- CountAccountsByProxyID ---
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err, "CountAccountsByProxyID")
s.Require().Equal(int64(2), count)
}
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"})
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err)
s.Require().Zero(count)
}
// --- GetAccountCountsForProxies ---
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies")
s.Require().Equal(int64(2), counts[p1.ID])
s.Require().Equal(int64(1), counts[p2.ID])
}
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err)
s.Require().Empty(counts)
}
// --- ListActiveWithAccountCount ---
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p1",
Status: model.StatusActive,
CreatedAt: base.Add(-1 * time.Hour),
})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p2",
Status: model.StatusActive,
CreatedAt: base,
})
mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p3-inactive",
Status: model.StatusDisabled,
})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount")
s.Require().Len(withCounts, 2, "expected 2 active proxies")
// Sorted by created_at DESC, so p2 first
s.Require().Equal(p2.ID, withCounts[0].ID)
s.Require().Equal(int64(1), withCounts[0].AccountCount)
s.Require().Equal(p1.ID, withCounts[1].ID)
s.Require().Equal(int64(2), withCounts[1].AccountCount)
}
// --- Combined original test ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
Port: 8080,
Username: "u",
Password: "p",
CreatedAt: time.Now().Add(-1 * time.Hour),
UpdatedAt: time.Now().Add(-1 * time.Hour),
})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p2",
Protocol: "http",
Host: "5.6.7.8",
Port: 8081,
Username: "",
Password: "",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "u", "p")
s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists, "expected proxy to exist")
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
s.Require().NoError(err, "CountAccountsByProxyID")
s.Require().Equal(int64(2), count1, "expected 2 accounts for p1")
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies")
s.Require().Equal(int64(2), counts[p1.ID])
s.Require().Equal(int64(1), counts[p2.ID])
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount")
s.Require().Len(withCounts, 2, "expected 2 proxies")
for _, pc := range withCounts {
switch pc.ID {
case p1.ID:
s.Require().Equal(int64(2), pc.AccountCount, "p1 count mismatch")
case p2.ID:
s.Require().Equal(int64(1), pc.AccountCount, "p2 count mismatch")
default:
s.Require().Fail("unexpected proxy id", pc.ID)
}
}
}
This diff is collapsed.
This diff is collapsed.
......@@ -16,6 +16,7 @@ const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/sitev
type turnstileVerifier struct {
httpClient *http.Client
verifyURL string
}
func NewTurnstileVerifier() service.TurnstileVerifier {
......@@ -23,6 +24,7 @@ func NewTurnstileVerifier() service.TurnstileVerifier {
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
verifyURL: turnstileVerifyURL,
}
}
......@@ -34,7 +36,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r
formData.Set("remoteip", remoteIP)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, v.verifyURL, strings.NewReader(formData.Encode()))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment