Commit 7331220e authored by Edric Li's avatar Edric Li
Browse files

Merge remote-tracking branch 'upstream/main'

# Conflicts:
#	frontend/src/components/account/CreateAccountModal.vue
parents fb86002e 4f13c8de
package repository
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/redis/go-redis/v9"
)
// 基准测试用 TTL 配置
const benchSlotTTLMinutes = 15
var benchSlotTTL = time.Duration(benchSlotTTLMinutes) * time.Minute
// BenchmarkAccountConcurrency 用于对比 SCAN 与有序集合的计数性能。
func BenchmarkAccountConcurrency(b *testing.B) {
rdb := newBenchmarkRedisClient(b)
defer func() {
_ = rdb.Close()
}()
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache)
ctx := context.Background()
for _, size := range []int{10, 100, 1000} {
size := size
b.Run(fmt.Sprintf("zset/slots=%d", size), func(b *testing.B) {
accountID := time.Now().UnixNano()
key := accountSlotKey(accountID)
b.StopTimer()
members := make([]redis.Z, 0, size)
now := float64(time.Now().Unix())
for i := 0; i < size; i++ {
members = append(members, redis.Z{
Score: now,
Member: fmt.Sprintf("req_%d", i),
})
}
if err := rdb.ZAdd(ctx, key, members...).Err(); err != nil {
b.Fatalf("初始化有序集合失败: %v", err)
}
if err := rdb.Expire(ctx, key, benchSlotTTL).Err(); err != nil {
b.Fatalf("设置有序集合 TTL 失败: %v", err)
}
b.StartTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := cache.GetAccountConcurrency(ctx, accountID); err != nil {
b.Fatalf("获取并发数量失败: %v", err)
}
}
b.StopTimer()
if err := rdb.Del(ctx, key).Err(); err != nil {
b.Fatalf("清理有序集合失败: %v", err)
}
})
b.Run(fmt.Sprintf("scan/slots=%d", size), func(b *testing.B) {
accountID := time.Now().UnixNano()
pattern := fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
keys := make([]string, 0, size)
b.StopTimer()
pipe := rdb.Pipeline()
for i := 0; i < size; i++ {
key := fmt.Sprintf("%s%d:req_%d", accountSlotKeyPrefix, accountID, i)
keys = append(keys, key)
pipe.Set(ctx, key, "1", benchSlotTTL)
}
if _, err := pipe.Exec(ctx); err != nil {
b.Fatalf("初始化扫描键失败: %v", err)
}
b.StartTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := scanSlotCount(ctx, rdb, pattern); err != nil {
b.Fatalf("SCAN 计数失败: %v", err)
}
}
b.StopTimer()
if err := rdb.Del(ctx, keys...).Err(); err != nil {
b.Fatalf("清理扫描键失败: %v", err)
}
})
}
}
func scanSlotCount(ctx context.Context, rdb *redis.Client, pattern string) (int, error) {
var cursor uint64
count := 0
for {
keys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result()
if err != nil {
return 0, err
}
count += len(keys)
if nextCursor == 0 {
break
}
cursor = nextCursor
}
return count, nil
}
func newBenchmarkRedisClient(b *testing.B) *redis.Client {
b.Helper()
redisURL := os.Getenv("TEST_REDIS_URL")
if redisURL == "" {
b.Skip("未设置 TEST_REDIS_URL,跳过 Redis 基准测试")
}
opt, err := redis.ParseURL(redisURL)
if err != nil {
b.Fatalf("解析 TEST_REDIS_URL 失败: %v", err)
}
client := redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
b.Fatalf("Redis 连接失败: %v", err)
}
return client
}
......@@ -14,6 +14,12 @@ import (
"github.com/stretchr/testify/suite"
)
// 测试用 TTL 配置(15 分钟,与默认值一致)
const testSlotTTLMinutes = 15
// 测试用 TTL Duration,用于 TTL 断言
var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute
type ConcurrencyCacheSuite struct {
IntegrationRedisSuite
cache service.ConcurrencyCache
......@@ -21,7 +27,7 @@ type ConcurrencyCacheSuite struct {
func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb)
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
......@@ -54,7 +60,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
accountID := int64(11)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID)
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
require.NoError(s.T(), err, "AcquireAccountSlot")
......@@ -62,7 +68,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
......@@ -139,7 +145,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
userID := int64(200)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID)
slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
require.NoError(s.T(), err, "AcquireUserSlot")
......@@ -147,7 +153,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
......@@ -168,7 +174,7 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
......@@ -212,6 +218,48 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
accountID := int64(30)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
require.False(s.T(), ok, "expected account wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL account waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected account wait count 1")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
accountID := int64(301)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
// When no slots exist, GetAccountConcurrency should return 0
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
......@@ -226,6 +274,139 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
// Setup: Create accounts with different load states
account1 := int64(100)
account2 := int64(101)
account3 := int64(102)
// Account 1: 2/3 slots used, 1 waiting
ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 2: 1/2 slots used, 0 waiting
ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 3: 0/1 slots used, 0 waiting (idle)
// Query batch load
accounts := []service.AccountWithConcurrency{
{ID: account1, MaxConcurrency: 3},
{ID: account2, MaxConcurrency: 2},
{ID: account3, MaxConcurrency: 1},
}
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
require.NoError(s.T(), err)
require.Len(s.T(), loadMap, 3)
// Verify account1: (2 + 1) / 3 = 100%
load1 := loadMap[account1]
require.NotNil(s.T(), load1)
require.Equal(s.T(), account1, load1.AccountID)
require.Equal(s.T(), 2, load1.CurrentConcurrency)
require.Equal(s.T(), 1, load1.WaitingCount)
require.Equal(s.T(), 100, load1.LoadRate)
// Verify account2: (1 + 0) / 2 = 50%
load2 := loadMap[account2]
require.NotNil(s.T(), load2)
require.Equal(s.T(), account2, load2.AccountID)
require.Equal(s.T(), 1, load2.CurrentConcurrency)
require.Equal(s.T(), 0, load2.WaitingCount)
require.Equal(s.T(), 50, load2.LoadRate)
// Verify account3: (0 + 0) / 1 = 0%
load3 := loadMap[account3]
require.NotNil(s.T(), load3)
require.Equal(s.T(), account3, load3.AccountID)
require.Equal(s.T(), 0, load3.CurrentConcurrency)
require.Equal(s.T(), 0, load3.WaitingCount)
require.Equal(s.T(), 0, load3.LoadRate)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
// Test with empty account list
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
require.NoError(s.T(), err)
require.Empty(s.T(), loadMap)
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
accountID := int64(200)
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
// Acquire 3 slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Verify 3 slots exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, cur)
// Manually set old timestamps for req1 and req2 (simulate expired slots)
now := time.Now().Unix()
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
require.NoError(s.T(), err)
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
require.NoError(s.T(), err)
// Run cleanup
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify only 1 slot remains (req3)
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur)
// Verify req3 still exists
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Len(s.T(), members, 1)
require.Equal(s.T(), "req3", members[0])
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
accountID := int64(201)
// Acquire 2 fresh slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Run cleanup (should not remove anything)
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify both slots still exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 2, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
}
package repository
import (
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
type dbPoolSettings struct {
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
ConnMaxIdleTime time.Duration
}
func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
return dbPoolSettings{
MaxOpenConns: cfg.Database.MaxOpenConns,
MaxIdleConns: cfg.Database.MaxIdleConns,
ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
}
}
func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
settings := buildDBPoolSettings(cfg)
db.SetMaxOpenConns(settings.MaxOpenConns)
db.SetMaxIdleConns(settings.MaxIdleConns)
db.SetConnMaxLifetime(settings.ConnMaxLifetime)
db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
}
package repository
import (
"database/sql"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
_ "github.com/lib/pq"
)
func TestBuildDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 50,
MaxIdleConns: 10,
ConnMaxLifetimeMinutes: 30,
ConnMaxIdleTimeMinutes: 5,
},
}
settings := buildDBPoolSettings(cfg)
require.Equal(t, 50, settings.MaxOpenConns)
require.Equal(t, 10, settings.MaxIdleConns)
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
}
func TestApplyDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 40,
MaxIdleConns: 8,
ConnMaxLifetimeMinutes: 15,
ConnMaxIdleTimeMinutes: 3,
},
}
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
applyDBPoolSettings(db, cfg)
stats := db.Stats()
require.Equal(t, 40, stats.MaxOpenConnections)
}
// Package infrastructure 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package infrastructure
package repository
import (
"context"
......@@ -51,6 +51,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
if err != nil {
return nil, nil, err
}
applyDBPoolSettings(drv.DB(), cfg)
// 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源(source of truth)。
......
package repository
import (
"context"
"database/sql"
"errors"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/lib/pq"
)
// clientFromContext 从 context 中获取事务 client,如果不存在则返回默认 client。
//
// 这个辅助函数支持 repository 方法在事务上下文中工作:
// - 如果 context 中存在事务(通过 ent.NewTxContext 设置),返回事务的 client
// - 否则返回传入的默认 client
//
// 使用示例:
//
// func (r *someRepo) SomeMethod(ctx context.Context) error {
// client := clientFromContext(ctx, r.client)
// return client.SomeEntity.Create().Save(ctx)
// }
func clientFromContext(ctx context.Context, defaultClient *dbent.Client) *dbent.Client {
if tx := dbent.TxFromContext(ctx); tx != nil {
return tx.Client()
}
return defaultClient
}
// translatePersistenceError 将数据库层错误翻译为业务层错误。
//
// 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。
......
......@@ -109,9 +109,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
}
func createGeminiReqClient(proxyURL string) *req.Client {
client := req.C().SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
})
}
......@@ -76,11 +76,10 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
}
func createGeminiCliReqClient(proxyURL string) *req.Client {
client := req.C().SetTimeout(30 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
})
}
func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest {
......
......@@ -9,6 +9,7 @@ import (
"os"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
......@@ -17,10 +18,14 @@ type githubReleaseClient struct {
}
func NewGitHubReleaseClient() service.GitHubReleaseClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &githubReleaseClient{
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
httpClient: sharedClient,
}
}
......@@ -58,8 +63,13 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
return err
}
client := &http.Client{Timeout: 10 * time.Minute}
resp, err := client.Do(req)
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
resp, err := downloadClient.Do(req)
if err != nil {
return err
}
......
......@@ -42,7 +42,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetSubscriptionType(groupIn.SubscriptionType).
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD)
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetDefaultValidityDays(groupIn.DefaultValidityDays)
created, err := builder.Save(ctx)
if err == nil {
......@@ -79,6 +80,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
......@@ -89,7 +91,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
return err
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
......@@ -239,8 +241,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
// err 为 dbent.ErrTxStarted 时,复用当前 client 参与同一事务。
// Lock the group row to avoid concurrent writes while we cascade.
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分未找到与其他错误。
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id)
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分"未找到"与其他错误。
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 AND deleted_at IS NULL FOR UPDATE", id)
if err != nil {
return nil, err
}
......@@ -263,7 +265,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
var affectedUserIDs []int64
if groupSvc.IsSubscriptionType() {
rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1", id)
// 只查询未软删除的订阅,避免通知已取消订阅的用户
rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1 AND deleted_at IS NULL", id)
if err != nil {
return nil, err
}
......@@ -282,7 +285,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return nil, err
}
if _, err := exec.ExecContext(ctx, "DELETE FROM user_subscriptions WHERE group_id = $1", id); err != nil {
// 软删除订阅:设置 deleted_at 而非硬删除
if _, err := exec.ExecContext(ctx, "UPDATE user_subscriptions SET deleted_at = NOW() WHERE group_id = $1 AND deleted_at IS NULL", id); err != nil {
return nil, err
}
}
......@@ -297,18 +301,11 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return nil, err
}
// 3. Remove the group id from users.allowed_groups array (legacy representation).
// Phase 1 compatibility: also delete from user_allowed_groups join table when present.
// 3. Remove the group id from user_allowed_groups join table.
// Legacy users.allowed_groups 列已弃用,不再同步。
if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil {
return nil, err
}
if _, err := exec.ExecContext(
ctx,
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1) WHERE $1 = ANY(allowed_groups)",
id,
); err != nil {
return nil, err
}
// 4. Delete account_groups join rows.
if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil {
......
......@@ -478,3 +478,58 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count)
}
// --- 软删除过滤测试 ---
func (s *GroupRepoSuite) TestDelete_SoftDelete_NotVisibleInList() {
group := &service.Group{
Name: "to-soft-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
// 获取删除前的列表数量
listBefore, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
beforeCount := len(listBefore)
// 软删除
err = s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete (soft delete)")
// 验证列表中不再包含软删除的 group
listAfter, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
s.Require().Len(listAfter, beforeCount-1, "soft deleted group should not appear in list")
// 验证 GetByID 也无法找到
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err)
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
func (s *GroupRepoSuite) TestDelete_SoftDeletedGroup_lockForUpdate() {
group := &service.Group{
Name: "lock-soft-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
// 软删除
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err)
// 验证软删除的 group 在 GetByID 时返回 ErrGroupNotFound
// 这证明 lockForUpdate 的 deleted_at IS NULL 过滤正在工作
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err, "should fail to get soft-deleted group")
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
package repository
import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// httpUpstreamService is a generic HTTP upstream service that can be used for
// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support.
// 默认配置常量
// 这些值在配置文件未指定时作为回退默认值使用
const (
// directProxyKey: 无代理时的缓存键标识
directProxyKey = "direct"
// defaultMaxIdleConns: 默认最大空闲连接总数
// HTTP/2 场景下,单连接可多路复用,240 足以支撑高并发
defaultMaxIdleConns = 240
// defaultMaxIdleConnsPerHost: 默认每主机最大空闲连接数
defaultMaxIdleConnsPerHost = 120
// defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接)
// 达到上限后新请求会等待,而非无限创建连接
defaultMaxConnsPerHost = 240
// defaultIdleConnTimeout: 默认空闲连接超时时间(5分钟)
// 超时后连接会被关闭,释放系统资源
defaultIdleConnTimeout = 300 * time.Second
// defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟)
// LLM 请求可能排队较久,需要较长超时
defaultResponseHeaderTimeout = 300 * time.Second
// defaultMaxUpstreamClients: 默认最大客户端缓存数量
// 超出后会淘汰最久未使用的客户端
defaultMaxUpstreamClients = 5000
// defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值(15分钟)
defaultClientIdleTTLSeconds = 900
)
var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached")
// poolSettings 连接池配置参数
// 封装 Transport 所需的各项连接池参数
type poolSettings struct {
maxIdleConns int // 最大空闲连接总数
maxIdleConnsPerHost int // 每主机最大空闲连接数
maxConnsPerHost int // 每主机最大连接数(含活跃)
idleConnTimeout time.Duration // 空闲连接超时时间
responseHeaderTimeout time.Duration // 等待响应头超时时间
}
// upstreamClientEntry 上游客户端缓存条目
// 记录客户端实例及其元数据,用于连接池管理和淘汰策略
type upstreamClientEntry struct {
client *http.Client // HTTP 客户端实例
proxyKey string // 代理标识(用于检测代理变更)
poolKey string // 连接池配置标识(用于检测配置变更)
lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰
inFlight int64 // 当前进行中的请求数,>0 时不可淘汰
}
// httpUpstreamService 通用 HTTP 上游服务
// 用于向任意 HTTP API(Claude、OpenAI 等)发送请求,支持可选代理
//
// 架构设计:
// - 根据隔离策略(proxy/account/account_proxy)缓存客户端实例
// - 每个客户端拥有独立的 Transport 连接池
// - 支持 LRU + 空闲时间双重淘汰策略
//
// 性能优化:
// 1. 根据隔离策略缓存客户端实例,避免频繁创建 http.Client
// 2. 复用 Transport 连接池,减少 TCP 握手和 TLS 协商开销
// 3. 支持账号级隔离与空闲回收,降低连接层关联风险
// 4. 达到最大连接数后等待可用连接,而非无限创建
// 5. 仅回收空闲客户端,避免中断活跃请求
// 6. HTTP/2 多路复用,连接上限不等于并发请求上限
// 7. 代理变更时清空旧连接池,避免复用错误代理
// 8. 账号并发数与连接池上限对应(账号隔离策略下)
type httpUpstreamService struct {
defaultClient *http.Client
cfg *config.Config
cfg *config.Config // 全局配置
mu sync.RWMutex // 保护 clients map 的读写锁
clients map[string]*upstreamClientEntry // 客户端缓存池,key 由隔离策略决定
}
// NewHTTPUpstream creates a new generic HTTP upstream service
// NewHTTPUpstream 创建通用 HTTP 上游服务
// 使用配置中的连接池参数构建 Transport
//
// 参数:
// - cfg: 全局配置,包含连接池参数和隔离策略
//
// 返回:
// - service.HTTPUpstream 接口实现
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
return &httpUpstreamService{
cfg: cfg,
clients: make(map[string]*upstreamClientEntry),
}
}
transport := &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
// Do 执行 HTTP 请求
// 根据隔离策略获取或创建客户端,并跟踪请求生命周期
//
// 参数:
// - req: HTTP 请求对象
// - proxyURL: 代理地址,空字符串表示直连
// - accountID: 账户 ID,用于账户级隔离
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
//
// 返回:
// - *http.Response: HTTP 响应(Body 已包装,关闭时自动更新计数)
// - error: 请求错误
//
// 注意:
// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
// 获取或创建对应的客户端,并标记请求占用
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
if err != nil {
return nil, err
}
return &httpUpstreamService{
defaultClient: &http.Client{Transport: transport},
cfg: cfg,
// 执行请求
resp, err := entry.client.Do(req)
if err != nil {
// 请求失败,立即减少计数
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
return nil, err
}
// 包装响应体,在关闭时自动减少计数并更新时间戳
// 这确保了流式响应(如 SSE)在完全读取前不会被淘汰
resp.Body = wrapTrackedBody(resp.Body, func() {
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
})
return resp, nil
}
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
if proxyURL == "" {
return s.defaultClient.Do(req)
// acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true)
}
// getOrCreateClient 获取或创建客户端
// 根据隔离策略和参数决定缓存键,处理代理变更和配置变更
//
// 参数:
// - proxyURL: 代理地址
// - accountID: 账户 ID
// - accountConcurrency: 账户并发限制
//
// 返回:
// - *upstreamClientEntry: 客户端缓存条目
//
// 隔离策略说明:
// - proxy: 按代理地址隔离,同一代理共享客户端
// - account: 按账户隔离,同一账户共享客户端(代理变更时重建)
// - account_proxy: 按账户+代理组合隔离,最细粒度
func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry {
entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
return entry
}
// getClientEntry 获取或创建客户端条目
// markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰
// enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误
func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
// 获取隔离模式
isolation := s.getIsolationMode()
// 标准化代理 URL 并解析
proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
// 构建缓存键(根据隔离策略不同)
cacheKey := buildCacheKey(isolation, proxyKey, accountID)
// 构建连接池配置键(用于检测配置变更)
poolKey := s.buildPoolKey(isolation, accountConcurrency)
now := time.Now()
nowUnix := now.UnixNano()
// 读锁快速路径:命中缓存直接返回,减少锁竞争
s.mu.RLock()
if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.AddInt64(&entry.inFlight, 1)
}
s.mu.RUnlock()
return entry, nil
}
s.mu.RUnlock()
// 写锁慢路径:创建或重建客户端
s.mu.Lock()
if entry, ok := s.clients[cacheKey]; ok {
if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.AddInt64(&entry.inFlight, 1)
}
s.mu.Unlock()
return entry, nil
}
s.removeClientLocked(cacheKey, entry)
}
// 超出缓存上限时尝试淘汰,无法淘汰则拒绝新建
if enforceLimit && s.maxUpstreamClients() > 0 {
s.evictIdleLocked(now)
if len(s.clients) >= s.maxUpstreamClients() {
if !s.evictOldestIdleLocked() {
s.mu.Unlock()
return nil, errUpstreamClientLimitReached
}
}
}
// 缓存未命中或需要重建,创建新客户端
settings := s.resolvePoolSettings(isolation, accountConcurrency)
client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)}
entry := &upstreamClientEntry{
client: client,
proxyKey: proxyKey,
poolKey: poolKey,
}
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.StoreInt64(&entry.inFlight, 1)
}
s.clients[cacheKey] = entry
// 执行淘汰策略:先淘汰空闲超时的,再淘汰超出数量限制的
s.evictIdleLocked(now)
s.evictOverLimitLocked()
s.mu.Unlock()
return entry, nil
}
// shouldReuseEntry 判断缓存条目是否可复用
// 若代理或连接池配置发生变化,则需要重建客户端
func (s *httpUpstreamService) shouldReuseEntry(entry *upstreamClientEntry, isolation, proxyKey, poolKey string) bool {
if entry == nil {
return false
}
if isolation == config.ConnectionPoolIsolationAccount && entry.proxyKey != proxyKey {
return false
}
if entry.poolKey != poolKey {
return false
}
return true
}
// removeClientLocked 移除客户端(需持有锁)
// 从缓存中删除并关闭空闲连接
//
// 参数:
// - key: 缓存键
// - entry: 客户端条目
func (s *httpUpstreamService) removeClientLocked(key string, entry *upstreamClientEntry) {
delete(s.clients, key)
if entry != nil && entry.client != nil {
// 关闭空闲连接,释放系统资源
// 注意:这不会中断活跃连接
entry.client.CloseIdleConnections()
}
}
// evictIdleLocked 淘汰空闲超时的客户端(需持有锁)
// 遍历所有客户端,移除超过 TTL 且无活跃请求的条目
//
// 参数:
// - now: 当前时间
func (s *httpUpstreamService) evictIdleLocked(now time.Time) {
ttl := s.clientIdleTTL()
if ttl <= 0 {
return
}
// 计算淘汰截止时间
cutoff := now.Add(-ttl).UnixNano()
for key, entry := range s.clients {
// 跳过有活跃请求的客户端
if atomic.LoadInt64(&entry.inFlight) != 0 {
continue
}
// 淘汰超时的空闲客户端
if atomic.LoadInt64(&entry.lastUsed) <= cutoff {
s.removeClientLocked(key, entry)
}
}
}
// evictOldestIdleLocked 淘汰最久未使用且无活跃请求的客户端(需持有锁)
func (s *httpUpstreamService) evictOldestIdleLocked() bool {
var (
oldestKey string
oldestEntry *upstreamClientEntry
oldestTime int64
)
// 查找最久未使用且无活跃请求的客户端
for key, entry := range s.clients {
// 跳过有活跃请求的客户端
if atomic.LoadInt64(&entry.inFlight) != 0 {
continue
}
lastUsed := atomic.LoadInt64(&entry.lastUsed)
if oldestEntry == nil || lastUsed < oldestTime {
oldestKey = key
oldestEntry = entry
oldestTime = lastUsed
}
}
// 所有客户端都有活跃请求,无法淘汰
if oldestEntry == nil {
return false
}
s.removeClientLocked(oldestKey, oldestEntry)
return true
}
// evictOverLimitLocked 淘汰超出数量限制的客户端(需持有锁)
// 使用 LRU 策略,优先淘汰最久未使用且无活跃请求的客户端
func (s *httpUpstreamService) evictOverLimitLocked() bool {
maxClients := s.maxUpstreamClients()
if maxClients <= 0 {
return false
}
evicted := false
// 循环淘汰直到满足数量限制
for len(s.clients) > maxClients {
if !s.evictOldestIdleLocked() {
return evicted
}
evicted = true
}
return evicted
}
// getIsolationMode 获取连接池隔离模式
// 从配置中读取,无效值回退到 account_proxy 模式
//
// 返回:
// - string: 隔离模式(proxy/account/account_proxy)
func (s *httpUpstreamService) getIsolationMode() string {
if s.cfg == nil {
return config.ConnectionPoolIsolationAccountProxy
}
mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.ConnectionPoolIsolation))
if mode == "" {
return config.ConnectionPoolIsolationAccountProxy
}
switch mode {
case config.ConnectionPoolIsolationProxy, config.ConnectionPoolIsolationAccount, config.ConnectionPoolIsolationAccountProxy:
return mode
default:
return config.ConnectionPoolIsolationAccountProxy
}
}
// maxUpstreamClients 获取最大客户端缓存数量
// 从配置中读取,无效值使用默认值
func (s *httpUpstreamService) maxUpstreamClients() int {
if s.cfg == nil {
return defaultMaxUpstreamClients
}
if s.cfg.Gateway.MaxUpstreamClients > 0 {
return s.cfg.Gateway.MaxUpstreamClients
}
return defaultMaxUpstreamClients
}
// clientIdleTTL 获取客户端空闲回收阈值
// 从配置中读取,无效值使用默认值
func (s *httpUpstreamService) clientIdleTTL() time.Duration {
if s.cfg == nil {
return time.Duration(defaultClientIdleTTLSeconds) * time.Second
}
if s.cfg.Gateway.ClientIdleTTLSeconds > 0 {
return time.Duration(s.cfg.Gateway.ClientIdleTTLSeconds) * time.Second
}
return time.Duration(defaultClientIdleTTLSeconds) * time.Second
}
// resolvePoolSettings 解析连接池配置
// 根据隔离策略和账户并发数动态调整连接池参数
//
// 参数:
// - isolation: 隔离模式
// - accountConcurrency: 账户并发限制
//
// 返回:
// - poolSettings: 连接池配置
//
// 说明:
// - 账户隔离模式下,连接池大小与账户并发数对应
// - 这确保了单账户不会占用过多连接资源
func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcurrency int) poolSettings {
settings := defaultPoolSettings(s.cfg)
// 账户隔离模式下,根据账户并发数调整连接池大小
if (isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy) && accountConcurrency > 0 {
settings.maxIdleConns = accountConcurrency
settings.maxIdleConnsPerHost = accountConcurrency
settings.maxConnsPerHost = accountConcurrency
}
client := s.createProxyClient(proxyURL)
return client.Do(req)
return settings
}
func (s *httpUpstreamService) createProxyClient(proxyURL string) *http.Client {
parsedURL, err := url.Parse(proxyURL)
// buildPoolKey 构建连接池配置键
// 用于检测配置变更,配置变更时需要重建客户端
//
// 参数:
// - isolation: 隔离模式
// - accountConcurrency: 账户并发限制
//
// 返回:
// - string: 配置键
func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string {
if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy {
if accountConcurrency > 0 {
return fmt.Sprintf("account:%d", accountConcurrency)
}
}
return "default"
}
// buildCacheKey 构建客户端缓存键
// 根据隔离策略决定缓存键的组成
//
// 参数:
// - isolation: 隔离模式
// - proxyKey: 代理标识
// - accountID: 账户 ID
//
// 返回:
// - string: 缓存键
//
// 缓存键格式:
// - proxy 模式: "proxy:{proxyKey}"
// - account 模式: "account:{accountID}"
// - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}"
func buildCacheKey(isolation, proxyKey string, accountID int64) string {
switch isolation {
case config.ConnectionPoolIsolationAccount:
return fmt.Sprintf("account:%d", accountID)
case config.ConnectionPoolIsolationAccountProxy:
return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey)
default:
return fmt.Sprintf("proxy:%s", proxyKey)
}
}
// normalizeProxyURL 标准化代理 URL
// 处理空值和解析错误,返回标准化的键和解析后的 URL
//
// 参数:
// - raw: 原始代理 URL 字符串
//
// 返回:
// - string: 标准化的代理键(空或解析失败返回 "direct")
// - *url.URL: 解析后的 URL(空或解析失败返回 nil)
func normalizeProxyURL(raw string) (string, *url.URL) {
proxyURL := strings.TrimSpace(raw)
if proxyURL == "" {
return directProxyKey, nil
}
parsed, err := url.Parse(proxyURL)
if err != nil {
return s.defaultClient
return directProxyKey, nil
}
parsed.Scheme = strings.ToLower(parsed.Scheme)
parsed.Host = strings.ToLower(parsed.Host)
parsed.Path = ""
parsed.RawPath = ""
parsed.RawQuery = ""
parsed.Fragment = ""
parsed.ForceQuery = false
if hostname := parsed.Hostname(); hostname != "" {
port := parsed.Port()
if (parsed.Scheme == "http" && port == "80") || (parsed.Scheme == "https" && port == "443") {
port = ""
}
hostname = strings.ToLower(hostname)
if port != "" {
parsed.Host = net.JoinHostPort(hostname, port)
} else {
parsed.Host = hostname
}
}
return parsed.String(), parsed
}
// defaultPoolSettings 获取默认连接池配置
// 从全局配置中读取,无效值使用常量默认值
//
// 参数:
// - cfg: 全局配置
//
// 返回:
// - poolSettings: 连接池配置
func defaultPoolSettings(cfg *config.Config) poolSettings {
maxIdleConns := defaultMaxIdleConns
maxIdleConnsPerHost := defaultMaxIdleConnsPerHost
maxConnsPerHost := defaultMaxConnsPerHost
idleConnTimeout := defaultIdleConnTimeout
responseHeaderTimeout := defaultResponseHeaderTimeout
if cfg != nil {
if cfg.Gateway.MaxIdleConns > 0 {
maxIdleConns = cfg.Gateway.MaxIdleConns
}
if cfg.Gateway.MaxIdleConnsPerHost > 0 {
maxIdleConnsPerHost = cfg.Gateway.MaxIdleConnsPerHost
}
if cfg.Gateway.MaxConnsPerHost >= 0 {
maxConnsPerHost = cfg.Gateway.MaxConnsPerHost
}
if cfg.Gateway.IdleConnTimeoutSeconds > 0 {
idleConnTimeout = time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second
}
if cfg.Gateway.ResponseHeaderTimeout > 0 {
responseHeaderTimeout = time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
}
}
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
return poolSettings{
maxIdleConns: maxIdleConns,
maxIdleConnsPerHost: maxIdleConnsPerHost,
maxConnsPerHost: maxConnsPerHost,
idleConnTimeout: idleConnTimeout,
responseHeaderTimeout: responseHeaderTimeout,
}
}
// buildUpstreamTransport 构建上游请求的 Transport
// 使用配置文件中的连接池参数,支持生产环境调优
//
// 参数:
// - settings: 连接池配置
// - proxyURL: 代理 URL(nil 表示直连)
//
// 返回:
// - *http.Transport: 配置好的 Transport 实例
//
// Transport 参数说明:
// - MaxIdleConns: 所有主机的最大空闲连接总数
// - MaxIdleConnsPerHost: 每主机最大空闲连接数(影响连接复用率)
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Transport {
transport := &http.Transport{
Proxy: http.ProxyURL(parsedURL),
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
MaxIdleConns: settings.maxIdleConns,
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
MaxConnsPerHost: settings.maxConnsPerHost,
IdleConnTimeout: settings.idleConnTimeout,
ResponseHeaderTimeout: settings.responseHeaderTimeout,
}
if proxyURL != nil {
transport.Proxy = http.ProxyURL(proxyURL)
}
return transport
}
return &http.Client{Transport: transport}
// trackedBody 带跟踪功能的响应体包装器
// 在 Close 时执行回调,用于更新请求计数
type trackedBody struct {
io.ReadCloser // 原始响应体
once sync.Once
onClose func() // 关闭时的回调函数
}
// Close 关闭响应体并执行回调
// 使用 sync.Once 确保回调只执行一次
func (b *trackedBody) Close() error {
err := b.ReadCloser.Close()
if b.onClose != nil {
b.once.Do(b.onClose)
}
return err
}
// wrapTrackedBody 包装响应体以跟踪关闭事件
// 用于在响应体关闭时更新 inFlight 计数
//
// 参数:
// - body: 原始响应体
// - onClose: 关闭时的回调函数
//
// 返回:
// - io.ReadCloser: 包装后的响应体
func wrapTrackedBody(body io.ReadCloser, onClose func()) io.ReadCloser {
if body == nil {
return body
}
return &trackedBody{ReadCloser: body, onClose: onClose}
}
package repository
import (
"net/http"
"net/url"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// httpClientSink 用于防止编译器优化掉基准测试中的赋值操作
// 这是 Go 基准测试的常见模式,确保测试结果准确
var httpClientSink *http.Client
// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销
//
// 测试目的:
// - 验证连接池复用相比每次新建的性能提升
// - 量化内存分配差异
//
// 预期结果:
// - "复用" 子测试应显著快于 "新建"
// - "复用" 子测试应零内存分配
func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
// 创建测试配置
cfg := &config.Config{
Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300},
}
upstream := NewHTTPUpstream(cfg)
svc, ok := upstream.(*httpUpstreamService)
if !ok {
b.Fatalf("类型断言失败,无法获取 httpUpstreamService")
}
proxyURL := "http://127.0.0.1:8080"
b.ReportAllocs() // 报告内存分配统计
// 子测试:每次新建客户端
// 模拟未优化前的行为,每次请求都创建新的 http.Client
b.Run("新建", func(b *testing.B) {
parsedProxy, err := url.Parse(proxyURL)
if err != nil {
b.Fatalf("解析代理地址失败: %v", err)
}
settings := defaultPoolSettings(cfg)
for i := 0; i < b.N; i++ {
// 每次迭代都创建新客户端,包含 Transport 分配
httpClientSink = &http.Client{
Transport: buildUpstreamTransport(settings, parsedProxy),
}
}
})
// 子测试:复用已缓存的客户端
// 模拟优化后的行为,从缓存获取客户端
b.Run("复用", func(b *testing.B) {
// 预热:确保客户端已缓存
entry := svc.getOrCreateClient(proxyURL, 1, 1)
client := entry.client
b.ResetTimer() // 重置计时器,排除预热时间
for i := 0; i < b.N; i++ {
// 直接使用缓存的客户端,无内存分配
httpClientSink = client
}
})
}
......@@ -4,6 +4,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
......@@ -12,45 +13,86 @@ import (
"github.com/stretchr/testify/suite"
)
// HTTPUpstreamSuite HTTP 上游服务测试套件
// 使用 testify/suite 组织测试,支持 SetupTest 初始化
type HTTPUpstreamSuite struct {
suite.Suite
cfg *config.Config
cfg *config.Config // 测试用配置
}
// SetupTest 每个测试用例执行前的初始化
// 创建空配置,各测试用例可按需覆盖
func (s *HTTPUpstreamSuite) SetupTest() {
s.cfg = &config.Config{}
}
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
// newService 创建测试用的 httpUpstreamService 实例
// 返回具体类型以便访问内部状态进行断言
func (s *HTTPUpstreamSuite) newService() *httpUpstreamService {
up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService")
transport, ok := svc.defaultClient.Transport.(*http.Transport)
return svc
}
// TestDefaultResponseHeaderTimeout 测试默认响应头超时配置
// 验证未配置时使用 300 秒默认值
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
svc := s.newService()
entry := svc.getOrCreateClient("", 0, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}
// TestCustomResponseHeaderTimeout 测试自定义响应头超时配置
// 验证配置值能正确应用到 Transport
func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7}
up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService")
transport, ok := svc.defaultClient.Transport.(*http.Transport)
svc := s.newService()
entry := svc.getOrCreateClient("", 0, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}
func (s *HTTPUpstreamSuite) TestCreateProxyClient_InvalidURLFallsBackToDefault() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5}
up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService")
// TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退
// 验证解析失败时回退到直连模式
func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() {
svc := s.newService()
entry := svc.getOrCreateClient("://bad-proxy-url", 1, 1)
require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback")
}
// TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化
// 验证等价地址能够映射到同一缓存键
func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() {
key1, _ := normalizeProxyURL("http://proxy.local:8080")
key2, _ := normalizeProxyURL("http://proxy.local:8080/")
require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match")
}
// TestAcquireClient_OverLimitReturnsError 测试连接池缓存上限保护
// 验证超限且无可淘汰条目时返回错误
func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy,
MaxUpstreamClients: 1,
}
svc := s.newService()
entry1, err := svc.acquireClient("http://proxy-a:8080", 1, 1)
require.NoError(s.T(), err, "expected first acquire to succeed")
require.NotNil(s.T(), entry1, "expected entry")
got := svc.createProxyClient("://bad-proxy-url")
require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback")
entry2, err := svc.acquireClient("http://proxy-b:8080", 2, 1)
require.Error(s.T(), err, "expected error when cache limit reached")
require.Nil(s.T(), entry2, "expected nil entry when cache limit reached")
}
// TestDo_WithoutProxy_GoesDirect 测试无代理时直连
// 验证空代理 URL 时请求直接发送到目标服务器
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
// 创建模拟上游服务器
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct")
}))
......@@ -60,17 +102,21 @@ func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/x", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, "")
resp, err := up.Do(req, "", 1, 1)
require.NoError(s.T(), err, "Do")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "direct", string(b), "unexpected body")
}
// TestDo_WithHTTPProxy_UsesProxy 测试 HTTP 代理功能
// 验证请求通过代理服务器转发,使用绝对 URI 格式
func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
// 用于接收代理请求的通道
seen := make(chan string, 1)
// 创建模拟代理服务器
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
seen <- r.RequestURI // 记录请求 URI
_, _ = io.WriteString(w, "proxied")
}))
s.T().Cleanup(proxySrv.Close)
......@@ -78,14 +124,16 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 1}
up := NewHTTPUpstream(s.cfg)
// 发送请求到外部地址,应通过代理
req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, proxySrv.URL)
resp, err := up.Do(req, proxySrv.URL, 1, 1)
require.NoError(s.T(), err, "Do")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "proxied", string(b), "unexpected body")
// 验证代理收到的是绝对 URI 格式(HTTP 代理规范要求)
select {
case uri := <-seen:
require.Equal(s.T(), "http://example.com/test", uri, "expected absolute-form request URI")
......@@ -94,6 +142,8 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
}
}
// TestDo_EmptyProxy_UsesDirect 测试空代理字符串
// 验证空字符串代理等同于直连
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct-empty")
......@@ -103,13 +153,134 @@ func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
up := NewHTTPUpstream(s.cfg)
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/y", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, "")
resp, err := up.Do(req, "", 1, 1)
require.NoError(s.T(), err, "Do with empty proxy")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "direct-empty", string(b))
}
// TestAccountIsolation_DifferentAccounts 测试账户隔离模式
// 验证不同账户使用独立的连接池
func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 同一代理,不同账户
entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3)
require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池")
require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端")
}
// TestAccountProxyIsolation_DifferentProxy 测试账户+代理组合隔离模式
// 验证同一账户使用不同代理时创建独立连接池
func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy}
svc := s.newService()
// 同一账户,不同代理
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理")
require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端")
}
// TestAccountModeProxyChangeClearsPool 测试账户模式下代理变更
// 验证账户切换代理时清理旧连接池,避免复用错误代理
func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 同一账户,先后使用不同代理
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池")
require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池")
require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理")
}
// TestAccountConcurrencyOverridesPoolSettings 测试账户并发数覆盖连接池配置
// 验证账户隔离模式下,连接池大小与账户并发数对应
func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 账户并发数为 12
entry := svc.getOrCreateClient("", 1, 12)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
// 连接池参数应与并发数一致
require.Equal(s.T(), 12, transport.MaxConnsPerHost, "MaxConnsPerHost mismatch")
require.Equal(s.T(), 12, transport.MaxIdleConns, "MaxIdleConns mismatch")
require.Equal(s.T(), 12, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost mismatch")
}
// TestAccountConcurrencyFallbackToDefault 测试账户并发数为 0 时回退到默认配置
// 验证未指定并发数时使用全局配置值
func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount,
MaxIdleConns: 77,
MaxIdleConnsPerHost: 55,
MaxConnsPerHost: 66,
}
svc := s.newService()
// 账户并发数为 0,应使用全局配置
entry := svc.getOrCreateClient("", 1, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch")
require.Equal(s.T(), 77, transport.MaxIdleConns, "MaxIdleConns fallback mismatch")
require.Equal(s.T(), 55, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost fallback mismatch")
}
// TestEvictOverLimitRemovesOldestIdle 测试超出数量限制时的 LRU 淘汰
// 验证优先淘汰最久未使用的空闲客户端
func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy,
MaxUpstreamClients: 2, // 最多缓存 2 个客户端
}
svc := s.newService()
// 创建两个客户端,设置不同的最后使用时间
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1)
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久
atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano())
// 创建第三个客户端,触发淘汰
_ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1)
require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内")
require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理")
}
// TestIdleTTLDoesNotEvictActive 测试活跃请求保护
// 验证有进行中请求的客户端不会被空闲超时淘汰
func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount,
ClientIdleTTLSeconds: 1, // 1 秒空闲超时
}
svc := s.newService()
entry1 := svc.getOrCreateClient("", 1, 1)
// 设置为很久之前使用,但有活跃请求
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano())
atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求
// 创建新客户端,触发淘汰检查
_ = svc.getOrCreateClient("", 2, 1)
require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收")
}
// TestHTTPUpstreamSuite 运行测试套件
func TestHTTPUpstreamSuite(t *testing.T) {
suite.Run(t, new(HTTPUpstreamSuite))
}
// hasEntry 检查客户端是否存在于缓存中
// 辅助函数,用于验证淘汰逻辑
func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool {
for _, entry := range svc.clients {
if entry == target {
return true
}
}
return false
}
......@@ -17,7 +17,6 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
......@@ -97,7 +96,7 @@ func TestMain(m *testing.M) {
log.Printf("failed to open sql db: %v", err)
os.Exit(1)
}
if err := infrastructure.ApplyMigrations(ctx, integrationDB); err != nil {
if err := ApplyMigrations(ctx, integrationDB); err != nil {
log.Printf("failed to apply db migrations: %v", err)
os.Exit(1)
}
......@@ -330,7 +329,8 @@ func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
switch strings.ToLower(cmd.Name()) {
case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl",
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists":
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists",
"zadd", "zcard", "zrange", "zrangebyscore", "zrem", "zremrangebyscore", "zrevrange", "zrevrangebyscore", "zscore":
prefixOne(1)
case "del", "unlink":
for i := 1; i < len(args); i++ {
......
package infrastructure
package repository
import (
"context"
......@@ -127,7 +127,15 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if existing != checksum {
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf("migration %s checksum mismatch (db=%s file=%s)", name, existing, checksum)
return fmt.Errorf(
"migration %s checksum mismatch (db=%s file=%s)\n"+
"This means the migration file was modified after being applied to the database.\n"+
"Solutions:\n"+
" 1. Revert to original: git log --oneline -- migrations/%s && git checkout <commit> -- migrations/%s\n"+
" 2. For new changes, create a new migration file instead of modifying existing ones\n"+
"Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments",
name, existing, checksum, name, name,
)
}
continue // 迁移已应用且校验和匹配,跳过
}
......
......@@ -7,7 +7,6 @@ import (
"database/sql"
"testing"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/stretchr/testify/require"
)
......@@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
tx := testTx(t)
// Re-apply migrations to verify idempotency (no errors, no duplicate rows).
require.NoError(t, infrastructure.ApplyMigrations(context.Background(), integrationDB))
require.NoError(t, ApplyMigrations(context.Background(), integrationDB))
// schema_migrations should have at least the current migration set.
var applied int
......@@ -53,6 +52,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
var uagRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))
require.True(t, uagRegclass.Valid, "expected user_allowed_groups table to exist")
// user_subscriptions: deleted_at for soft delete support (migration 012)
requireColumn(t, tx, "user_subscriptions", "deleted_at", "timestamp with time zone", 0, true)
// orphan_allowed_groups_audit table should exist (migration 013)
var orphanAuditRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.orphan_allowed_groups_audit')").Scan(&orphanAuditRegclass))
require.True(t, orphanAuditRegclass.Valid, "expected orphan_allowed_groups_audit table to exist")
// account_groups: created_at should be timestamptz
requireColumn(t, tx, "account_groups", "created_at", "timestamp with time zone", 0, false)
// user_allowed_groups: created_at should be timestamptz
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
......
......@@ -82,12 +82,8 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
}
func createOpenAIReqClient(proxyURL string) *req.Client {
client := req.C().
SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
})
}
......@@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
......@@ -16,10 +17,14 @@ type pricingRemoteClient struct {
}
func NewPricingRemoteClient() service.PricingRemoteClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &pricingRemoteClient{
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
httpClient: sharedClient,
}
}
......
......@@ -2,18 +2,14 @@ package repository
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
"golang.org/x/net/proxy"
)
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
......@@ -27,14 +23,14 @@ type proxyProbeService struct {
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
transport, err := createProxyTransport(proxyURL)
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 15 * time.Second,
InsecureSkipVerify: true,
ProxyStrict: true,
})
if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err)
}
client := &http.Client{
Transport: transport,
Timeout: 15 * time.Second,
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
}
startTime := time.Now()
......@@ -78,31 +74,3 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
Country: ipInfo.Country,
}, latencyMs, nil
}
func createProxyTransport(proxyURL string) (*http.Transport, error) {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
switch parsedURL.Scheme {
case "http", "https":
transport.Proxy = http.ProxyURL(parsedURL)
case "socks5":
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
}
return transport, nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment