Commit 7331220e authored by Edric Li's avatar Edric Li
Browse files

Merge remote-tracking branch 'upstream/main'

# Conflicts:
#	frontend/src/components/account/CreateAccountModal.vue
parents fb86002e 4f13c8de
package repository
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/redis/go-redis/v9"
)
// 基准测试用 TTL 配置
const benchSlotTTLMinutes = 15
var benchSlotTTL = time.Duration(benchSlotTTLMinutes) * time.Minute
// BenchmarkAccountConcurrency 用于对比 SCAN 与有序集合的计数性能。
func BenchmarkAccountConcurrency(b *testing.B) {
rdb := newBenchmarkRedisClient(b)
defer func() {
_ = rdb.Close()
}()
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache)
ctx := context.Background()
for _, size := range []int{10, 100, 1000} {
size := size
b.Run(fmt.Sprintf("zset/slots=%d", size), func(b *testing.B) {
accountID := time.Now().UnixNano()
key := accountSlotKey(accountID)
b.StopTimer()
members := make([]redis.Z, 0, size)
now := float64(time.Now().Unix())
for i := 0; i < size; i++ {
members = append(members, redis.Z{
Score: now,
Member: fmt.Sprintf("req_%d", i),
})
}
if err := rdb.ZAdd(ctx, key, members...).Err(); err != nil {
b.Fatalf("初始化有序集合失败: %v", err)
}
if err := rdb.Expire(ctx, key, benchSlotTTL).Err(); err != nil {
b.Fatalf("设置有序集合 TTL 失败: %v", err)
}
b.StartTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := cache.GetAccountConcurrency(ctx, accountID); err != nil {
b.Fatalf("获取并发数量失败: %v", err)
}
}
b.StopTimer()
if err := rdb.Del(ctx, key).Err(); err != nil {
b.Fatalf("清理有序集合失败: %v", err)
}
})
b.Run(fmt.Sprintf("scan/slots=%d", size), func(b *testing.B) {
accountID := time.Now().UnixNano()
pattern := fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
keys := make([]string, 0, size)
b.StopTimer()
pipe := rdb.Pipeline()
for i := 0; i < size; i++ {
key := fmt.Sprintf("%s%d:req_%d", accountSlotKeyPrefix, accountID, i)
keys = append(keys, key)
pipe.Set(ctx, key, "1", benchSlotTTL)
}
if _, err := pipe.Exec(ctx); err != nil {
b.Fatalf("初始化扫描键失败: %v", err)
}
b.StartTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := scanSlotCount(ctx, rdb, pattern); err != nil {
b.Fatalf("SCAN 计数失败: %v", err)
}
}
b.StopTimer()
if err := rdb.Del(ctx, keys...).Err(); err != nil {
b.Fatalf("清理扫描键失败: %v", err)
}
})
}
}
func scanSlotCount(ctx context.Context, rdb *redis.Client, pattern string) (int, error) {
var cursor uint64
count := 0
for {
keys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result()
if err != nil {
return 0, err
}
count += len(keys)
if nextCursor == 0 {
break
}
cursor = nextCursor
}
return count, nil
}
func newBenchmarkRedisClient(b *testing.B) *redis.Client {
b.Helper()
redisURL := os.Getenv("TEST_REDIS_URL")
if redisURL == "" {
b.Skip("未设置 TEST_REDIS_URL,跳过 Redis 基准测试")
}
opt, err := redis.ParseURL(redisURL)
if err != nil {
b.Fatalf("解析 TEST_REDIS_URL 失败: %v", err)
}
client := redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
b.Fatalf("Redis 连接失败: %v", err)
}
return client
}
...@@ -14,6 +14,12 @@ import ( ...@@ -14,6 +14,12 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
// 测试用 TTL 配置(15 分钟,与默认值一致)
const testSlotTTLMinutes = 15
// 测试用 TTL Duration,用于 TTL 断言
var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute
type ConcurrencyCacheSuite struct { type ConcurrencyCacheSuite struct {
IntegrationRedisSuite IntegrationRedisSuite
cache service.ConcurrencyCache cache service.ConcurrencyCache
...@@ -21,7 +27,7 @@ type ConcurrencyCacheSuite struct { ...@@ -21,7 +27,7 @@ type ConcurrencyCacheSuite struct {
func (s *ConcurrencyCacheSuite) SetupTest() { func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest() s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb) s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
} }
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
...@@ -54,7 +60,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { ...@@ -54,7 +60,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() { func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
accountID := int64(11) accountID := int64(11)
reqID := "req_ttl_test" reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID) slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID) ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
require.NoError(s.T(), err, "AcquireAccountSlot") require.NoError(s.T(), err, "AcquireAccountSlot")
...@@ -62,7 +68,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() { ...@@ -62,7 +68,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result() ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL") require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
} }
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() { func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
...@@ -139,7 +145,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() { ...@@ -139,7 +145,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() { func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
userID := int64(200) userID := int64(200)
reqID := "req_ttl_test" reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID) slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID) ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
require.NoError(s.T(), err, "AcquireUserSlot") require.NoError(s.T(), err, "AcquireUserSlot")
...@@ -147,7 +153,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() { ...@@ -147,7 +153,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result() ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL") require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
} }
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() { func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
...@@ -168,7 +174,7 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() { ...@@ -168,7 +174,7 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result() ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey") require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount") require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
...@@ -212,6 +218,48 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() { ...@@ -212,6 +218,48 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count") require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
} }
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
accountID := int64(30)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
require.False(s.T(), ok, "expected account wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL account waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected account wait count 1")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
accountID := int64(301)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() { func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
// When no slots exist, GetAccountConcurrency should return 0 // When no slots exist, GetAccountConcurrency should return 0
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999) cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
...@@ -226,6 +274,139 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() { ...@@ -226,6 +274,139 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
require.Equal(s.T(), 0, cur) require.Equal(s.T(), 0, cur)
} }
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
// Setup: Create accounts with different load states
account1 := int64(100)
account2 := int64(101)
account3 := int64(102)
// Account 1: 2/3 slots used, 1 waiting
ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 2: 1/2 slots used, 0 waiting
ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 3: 0/1 slots used, 0 waiting (idle)
// Query batch load
accounts := []service.AccountWithConcurrency{
{ID: account1, MaxConcurrency: 3},
{ID: account2, MaxConcurrency: 2},
{ID: account3, MaxConcurrency: 1},
}
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
require.NoError(s.T(), err)
require.Len(s.T(), loadMap, 3)
// Verify account1: (2 + 1) / 3 = 100%
load1 := loadMap[account1]
require.NotNil(s.T(), load1)
require.Equal(s.T(), account1, load1.AccountID)
require.Equal(s.T(), 2, load1.CurrentConcurrency)
require.Equal(s.T(), 1, load1.WaitingCount)
require.Equal(s.T(), 100, load1.LoadRate)
// Verify account2: (1 + 0) / 2 = 50%
load2 := loadMap[account2]
require.NotNil(s.T(), load2)
require.Equal(s.T(), account2, load2.AccountID)
require.Equal(s.T(), 1, load2.CurrentConcurrency)
require.Equal(s.T(), 0, load2.WaitingCount)
require.Equal(s.T(), 50, load2.LoadRate)
// Verify account3: (0 + 0) / 1 = 0%
load3 := loadMap[account3]
require.NotNil(s.T(), load3)
require.Equal(s.T(), account3, load3.AccountID)
require.Equal(s.T(), 0, load3.CurrentConcurrency)
require.Equal(s.T(), 0, load3.WaitingCount)
require.Equal(s.T(), 0, load3.LoadRate)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
// Test with empty account list
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
require.NoError(s.T(), err)
require.Empty(s.T(), loadMap)
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
accountID := int64(200)
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
// Acquire 3 slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Verify 3 slots exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, cur)
// Manually set old timestamps for req1 and req2 (simulate expired slots)
now := time.Now().Unix()
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
require.NoError(s.T(), err)
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
require.NoError(s.T(), err)
// Run cleanup
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify only 1 slot remains (req3)
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur)
// Verify req3 still exists
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Len(s.T(), members, 1)
require.Equal(s.T(), "req3", members[0])
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
accountID := int64(201)
// Acquire 2 fresh slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Run cleanup (should not remove anything)
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify both slots still exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 2, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) { func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite)) suite.Run(t, new(ConcurrencyCacheSuite))
} }
package repository
import (
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
type dbPoolSettings struct {
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
ConnMaxIdleTime time.Duration
}
func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
return dbPoolSettings{
MaxOpenConns: cfg.Database.MaxOpenConns,
MaxIdleConns: cfg.Database.MaxIdleConns,
ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
}
}
func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
settings := buildDBPoolSettings(cfg)
db.SetMaxOpenConns(settings.MaxOpenConns)
db.SetMaxIdleConns(settings.MaxIdleConns)
db.SetConnMaxLifetime(settings.ConnMaxLifetime)
db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
}
package repository
import (
"database/sql"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
_ "github.com/lib/pq"
)
func TestBuildDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 50,
MaxIdleConns: 10,
ConnMaxLifetimeMinutes: 30,
ConnMaxIdleTimeMinutes: 5,
},
}
settings := buildDBPoolSettings(cfg)
require.Equal(t, 50, settings.MaxOpenConns)
require.Equal(t, 10, settings.MaxIdleConns)
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
}
func TestApplyDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 40,
MaxIdleConns: 8,
ConnMaxLifetimeMinutes: 15,
ConnMaxIdleTimeMinutes: 3,
},
}
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
applyDBPoolSettings(db, cfg)
stats := db.Stats()
require.Equal(t, 40, stats.MaxOpenConnections)
}
// Package infrastructure 提供应用程序的基础设施层组件。 // Package infrastructure 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。 // 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package infrastructure package repository
import ( import (
"context" "context"
...@@ -51,6 +51,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) { ...@@ -51,6 +51,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
applyDBPoolSettings(drv.DB(), cfg)
// 确保数据库 schema 已准备就绪。 // 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源(source of truth)。 // SQL 迁移文件是 schema 的权威来源(source of truth)。
......
package repository package repository
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"strings" "strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/lib/pq" "github.com/lib/pq"
) )
// clientFromContext 从 context 中获取事务 client,如果不存在则返回默认 client。
//
// 这个辅助函数支持 repository 方法在事务上下文中工作:
// - 如果 context 中存在事务(通过 ent.NewTxContext 设置),返回事务的 client
// - 否则返回传入的默认 client
//
// 使用示例:
//
// func (r *someRepo) SomeMethod(ctx context.Context) error {
// client := clientFromContext(ctx, r.client)
// return client.SomeEntity.Create().Save(ctx)
// }
func clientFromContext(ctx context.Context, defaultClient *dbent.Client) *dbent.Client {
if tx := dbent.TxFromContext(ctx); tx != nil {
return tx.Client()
}
return defaultClient
}
// translatePersistenceError 将数据库层错误翻译为业务层错误。 // translatePersistenceError 将数据库层错误翻译为业务层错误。
// //
// 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。 // 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。
......
...@@ -109,9 +109,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh ...@@ -109,9 +109,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
} }
func createGeminiReqClient(proxyURL string) *req.Client { func createGeminiReqClient(proxyURL string) *req.Client {
client := req.C().SetTimeout(60 * time.Second) return getSharedReqClient(reqClientOptions{
if proxyURL != "" { ProxyURL: proxyURL,
client.SetProxyURL(proxyURL) Timeout: 60 * time.Second,
} })
return client
} }
...@@ -76,11 +76,10 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken ...@@ -76,11 +76,10 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
} }
func createGeminiCliReqClient(proxyURL string) *req.Client { func createGeminiCliReqClient(proxyURL string) *req.Client {
client := req.C().SetTimeout(30 * time.Second) return getSharedReqClient(reqClientOptions{
if proxyURL != "" { ProxyURL: proxyURL,
client.SetProxyURL(proxyURL) Timeout: 30 * time.Second,
} })
return client
} }
func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest { func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest {
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"os" "os"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
...@@ -17,10 +18,14 @@ type githubReleaseClient struct { ...@@ -17,10 +18,14 @@ type githubReleaseClient struct {
} }
func NewGitHubReleaseClient() service.GitHubReleaseClient { func NewGitHubReleaseClient() service.GitHubReleaseClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &githubReleaseClient{ return &githubReleaseClient{
httpClient: &http.Client{ httpClient: sharedClient,
Timeout: 30 * time.Second,
},
} }
} }
...@@ -58,8 +63,13 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string ...@@ -58,8 +63,13 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
return err return err
} }
client := &http.Client{Timeout: 10 * time.Minute} downloadClient, err := httpclient.GetClient(httpclient.Options{
resp, err := client.Do(req) Timeout: 10 * time.Minute,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
resp, err := downloadClient.Do(req)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -42,7 +42,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -42,7 +42,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetSubscriptionType(groupIn.SubscriptionType). SetSubscriptionType(groupIn.SubscriptionType).
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD). SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD). SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD) SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetDefaultValidityDays(groupIn.DefaultValidityDays)
created, err := builder.Save(ctx) created, err := builder.Save(ctx)
if err == nil { if err == nil {
...@@ -79,6 +80,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -79,6 +80,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD). SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD). SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD). SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists) return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
...@@ -89,7 +91,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -89,7 +91,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
func (r *groupRepository) Delete(ctx context.Context, id int64) error { func (r *groupRepository) Delete(ctx context.Context, id int64) error {
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx) _, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
return err return translatePersistenceError(err, service.ErrGroupNotFound, nil)
} }
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
...@@ -239,8 +241,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, ...@@ -239,8 +241,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
// err 为 dbent.ErrTxStarted 时,复用当前 client 参与同一事务。 // err 为 dbent.ErrTxStarted 时,复用当前 client 参与同一事务。
// Lock the group row to avoid concurrent writes while we cascade. // Lock the group row to avoid concurrent writes while we cascade.
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分未找到与其他错误。 // 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分"未找到"与其他错误。
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id) rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 AND deleted_at IS NULL FOR UPDATE", id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -263,7 +265,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, ...@@ -263,7 +265,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
var affectedUserIDs []int64 var affectedUserIDs []int64
if groupSvc.IsSubscriptionType() { if groupSvc.IsSubscriptionType() {
rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1", id) // 只查询未软删除的订阅,避免通知已取消订阅的用户
rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1 AND deleted_at IS NULL", id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -282,7 +285,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, ...@@ -282,7 +285,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return nil, err return nil, err
} }
if _, err := exec.ExecContext(ctx, "DELETE FROM user_subscriptions WHERE group_id = $1", id); err != nil { // 软删除订阅:设置 deleted_at 而非硬删除
if _, err := exec.ExecContext(ctx, "UPDATE user_subscriptions SET deleted_at = NOW() WHERE group_id = $1 AND deleted_at IS NULL", id); err != nil {
return nil, err return nil, err
} }
} }
...@@ -297,18 +301,11 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, ...@@ -297,18 +301,11 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return nil, err return nil, err
} }
// 3. Remove the group id from users.allowed_groups array (legacy representation). // 3. Remove the group id from user_allowed_groups join table.
// Phase 1 compatibility: also delete from user_allowed_groups join table when present. // Legacy users.allowed_groups 列已弃用,不再同步。
if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil { if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil {
return nil, err return nil, err
} }
if _, err := exec.ExecContext(
ctx,
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1) WHERE $1 = ANY(allowed_groups)",
id,
); err != nil {
return nil, err
}
// 4. Delete account_groups join rows. // 4. Delete account_groups join rows.
if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil { if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil {
......
...@@ -478,3 +478,58 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { ...@@ -478,3 +478,58 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
count, _ := s.repo.GetAccountCount(s.ctx, g.ID) count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count) s.Require().Zero(count)
} }
// --- 软删除过滤测试 ---
func (s *GroupRepoSuite) TestDelete_SoftDelete_NotVisibleInList() {
group := &service.Group{
Name: "to-soft-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
// 获取删除前的列表数量
listBefore, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
beforeCount := len(listBefore)
// 软删除
err = s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete (soft delete)")
// 验证列表中不再包含软删除的 group
listAfter, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
s.Require().Len(listAfter, beforeCount-1, "soft deleted group should not appear in list")
// 验证 GetByID 也无法找到
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err)
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
func (s *GroupRepoSuite) TestDelete_SoftDeletedGroup_lockForUpdate() {
group := &service.Group{
Name: "lock-soft-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
// 软删除
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err)
// 验证软删除的 group 在 GetByID 时返回 ErrGroupNotFound
// 这证明 lockForUpdate 的 deleted_at IS NULL 过滤正在工作
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err, "should fail to get soft-deleted group")
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
package repository package repository
import ( import (
"errors"
"fmt"
"io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"sync"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
// httpUpstreamService is a generic HTTP upstream service that can be used for // 默认配置常量
// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support. // 这些值在配置文件未指定时作为回退默认值使用
const (
// directProxyKey: 无代理时的缓存键标识
directProxyKey = "direct"
// defaultMaxIdleConns: 默认最大空闲连接总数
// HTTP/2 场景下,单连接可多路复用,240 足以支撑高并发
defaultMaxIdleConns = 240
// defaultMaxIdleConnsPerHost: 默认每主机最大空闲连接数
defaultMaxIdleConnsPerHost = 120
// defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接)
// 达到上限后新请求会等待,而非无限创建连接
defaultMaxConnsPerHost = 240
// defaultIdleConnTimeout: 默认空闲连接超时时间(5分钟)
// 超时后连接会被关闭,释放系统资源
defaultIdleConnTimeout = 300 * time.Second
// defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟)
// LLM 请求可能排队较久,需要较长超时
defaultResponseHeaderTimeout = 300 * time.Second
// defaultMaxUpstreamClients: 默认最大客户端缓存数量
// 超出后会淘汰最久未使用的客户端
defaultMaxUpstreamClients = 5000
// defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值(15分钟)
defaultClientIdleTTLSeconds = 900
)
var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached")
// poolSettings 连接池配置参数
// 封装 Transport 所需的各项连接池参数
type poolSettings struct {
maxIdleConns int // 最大空闲连接总数
maxIdleConnsPerHost int // 每主机最大空闲连接数
maxConnsPerHost int // 每主机最大连接数(含活跃)
idleConnTimeout time.Duration // 空闲连接超时时间
responseHeaderTimeout time.Duration // 等待响应头超时时间
}
// upstreamClientEntry 上游客户端缓存条目
// 记录客户端实例及其元数据,用于连接池管理和淘汰策略
type upstreamClientEntry struct {
client *http.Client // HTTP 客户端实例
proxyKey string // 代理标识(用于检测代理变更)
poolKey string // 连接池配置标识(用于检测配置变更)
lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰
inFlight int64 // 当前进行中的请求数,>0 时不可淘汰
}
// httpUpstreamService 通用 HTTP 上游服务
// 用于向任意 HTTP API(Claude、OpenAI 等)发送请求,支持可选代理
//
// 架构设计:
// - 根据隔离策略(proxy/account/account_proxy)缓存客户端实例
// - 每个客户端拥有独立的 Transport 连接池
// - 支持 LRU + 空闲时间双重淘汰策略
//
// 性能优化:
// 1. 根据隔离策略缓存客户端实例,避免频繁创建 http.Client
// 2. 复用 Transport 连接池,减少 TCP 握手和 TLS 协商开销
// 3. 支持账号级隔离与空闲回收,降低连接层关联风险
// 4. 达到最大连接数后等待可用连接,而非无限创建
// 5. 仅回收空闲客户端,避免中断活跃请求
// 6. HTTP/2 多路复用,连接上限不等于并发请求上限
// 7. 代理变更时清空旧连接池,避免复用错误代理
// 8. 账号并发数与连接池上限对应(账号隔离策略下)
type httpUpstreamService struct { type httpUpstreamService struct {
defaultClient *http.Client cfg *config.Config // 全局配置
cfg *config.Config mu sync.RWMutex // 保护 clients map 的读写锁
clients map[string]*upstreamClientEntry // 客户端缓存池,key 由隔离策略决定
} }
// NewHTTPUpstream creates a new generic HTTP upstream service // NewHTTPUpstream 创建通用 HTTP 上游服务
// 使用配置中的连接池参数构建 Transport
//
// 参数:
// - cfg: 全局配置,包含连接池参数和隔离策略
//
// 返回:
// - service.HTTPUpstream 接口实现
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream { func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second return &httpUpstreamService{
if responseHeaderTimeout == 0 { cfg: cfg,
responseHeaderTimeout = 300 * time.Second clients: make(map[string]*upstreamClientEntry),
} }
}
transport := &http.Transport{ // Do 执行 HTTP 请求
MaxIdleConns: 100, // 根据隔离策略获取或创建客户端,并跟踪请求生命周期
MaxIdleConnsPerHost: 10, //
IdleConnTimeout: 90 * time.Second, // 参数:
ResponseHeaderTimeout: responseHeaderTimeout, // - req: HTTP 请求对象
// - proxyURL: 代理地址,空字符串表示直连
// - accountID: 账户 ID,用于账户级隔离
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
//
// 返回:
// - *http.Response: HTTP 响应(Body 已包装,关闭时自动更新计数)
// - error: 请求错误
//
// 注意:
// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
// 获取或创建对应的客户端,并标记请求占用
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
if err != nil {
return nil, err
} }
return &httpUpstreamService{ // 执行请求
defaultClient: &http.Client{Transport: transport}, resp, err := entry.client.Do(req)
cfg: cfg, if err != nil {
// 请求失败,立即减少计数
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
return nil, err
} }
// 包装响应体,在关闭时自动减少计数并更新时间戳
// 这确保了流式响应(如 SSE)在完全读取前不会被淘汰
resp.Body = wrapTrackedBody(resp.Body, func() {
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
})
return resp, nil
} }
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) { // acquireClient 获取或创建客户端,并标记为进行中请求
if proxyURL == "" { // 用于请求路径,避免在获取后被淘汰
return s.defaultClient.Do(req) func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true)
}
// getOrCreateClient 获取或创建客户端
// 根据隔离策略和参数决定缓存键,处理代理变更和配置变更
//
// 参数:
// - proxyURL: 代理地址
// - accountID: 账户 ID
// - accountConcurrency: 账户并发限制
//
// 返回:
// - *upstreamClientEntry: 客户端缓存条目
//
// 隔离策略说明:
// - proxy: 按代理地址隔离,同一代理共享客户端
// - account: 按账户隔离,同一账户共享客户端(代理变更时重建)
// - account_proxy: 按账户+代理组合隔离,最细粒度
func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry {
entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
return entry
}
// getClientEntry 获取或创建客户端条目
// markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰
// enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误
func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
// 获取隔离模式
isolation := s.getIsolationMode()
// 标准化代理 URL 并解析
proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
// 构建缓存键(根据隔离策略不同)
cacheKey := buildCacheKey(isolation, proxyKey, accountID)
// 构建连接池配置键(用于检测配置变更)
poolKey := s.buildPoolKey(isolation, accountConcurrency)
now := time.Now()
nowUnix := now.UnixNano()
// 读锁快速路径:命中缓存直接返回,减少锁竞争
s.mu.RLock()
if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.AddInt64(&entry.inFlight, 1)
}
s.mu.RUnlock()
return entry, nil
}
s.mu.RUnlock()
// 写锁慢路径:创建或重建客户端
s.mu.Lock()
if entry, ok := s.clients[cacheKey]; ok {
if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.AddInt64(&entry.inFlight, 1)
}
s.mu.Unlock()
return entry, nil
}
s.removeClientLocked(cacheKey, entry)
}
// 超出缓存上限时尝试淘汰,无法淘汰则拒绝新建
if enforceLimit && s.maxUpstreamClients() > 0 {
s.evictIdleLocked(now)
if len(s.clients) >= s.maxUpstreamClients() {
if !s.evictOldestIdleLocked() {
s.mu.Unlock()
return nil, errUpstreamClientLimitReached
}
}
}
// 缓存未命中或需要重建,创建新客户端
settings := s.resolvePoolSettings(isolation, accountConcurrency)
client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)}
entry := &upstreamClientEntry{
client: client,
proxyKey: proxyKey,
poolKey: poolKey,
}
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.StoreInt64(&entry.inFlight, 1)
}
s.clients[cacheKey] = entry
// 执行淘汰策略:先淘汰空闲超时的,再淘汰超出数量限制的
s.evictIdleLocked(now)
s.evictOverLimitLocked()
s.mu.Unlock()
return entry, nil
}
// shouldReuseEntry 判断缓存条目是否可复用
// 若代理或连接池配置发生变化,则需要重建客户端
func (s *httpUpstreamService) shouldReuseEntry(entry *upstreamClientEntry, isolation, proxyKey, poolKey string) bool {
if entry == nil {
return false
}
if isolation == config.ConnectionPoolIsolationAccount && entry.proxyKey != proxyKey {
return false
}
if entry.poolKey != poolKey {
return false
}
return true
}
// removeClientLocked 移除客户端(需持有锁)
// 从缓存中删除并关闭空闲连接
//
// 参数:
// - key: 缓存键
// - entry: 客户端条目
func (s *httpUpstreamService) removeClientLocked(key string, entry *upstreamClientEntry) {
delete(s.clients, key)
if entry != nil && entry.client != nil {
// 关闭空闲连接,释放系统资源
// 注意:这不会中断活跃连接
entry.client.CloseIdleConnections()
}
}
// evictIdleLocked 淘汰空闲超时的客户端(需持有锁)
// 遍历所有客户端,移除超过 TTL 且无活跃请求的条目
//
// 参数:
// - now: 当前时间
func (s *httpUpstreamService) evictIdleLocked(now time.Time) {
ttl := s.clientIdleTTL()
if ttl <= 0 {
return
}
// 计算淘汰截止时间
cutoff := now.Add(-ttl).UnixNano()
for key, entry := range s.clients {
// 跳过有活跃请求的客户端
if atomic.LoadInt64(&entry.inFlight) != 0 {
continue
}
// 淘汰超时的空闲客户端
if atomic.LoadInt64(&entry.lastUsed) <= cutoff {
s.removeClientLocked(key, entry)
}
}
}
// evictOldestIdleLocked 淘汰最久未使用且无活跃请求的客户端(需持有锁)
func (s *httpUpstreamService) evictOldestIdleLocked() bool {
var (
oldestKey string
oldestEntry *upstreamClientEntry
oldestTime int64
)
// 查找最久未使用且无活跃请求的客户端
for key, entry := range s.clients {
// 跳过有活跃请求的客户端
if atomic.LoadInt64(&entry.inFlight) != 0 {
continue
}
lastUsed := atomic.LoadInt64(&entry.lastUsed)
if oldestEntry == nil || lastUsed < oldestTime {
oldestKey = key
oldestEntry = entry
oldestTime = lastUsed
}
}
// 所有客户端都有活跃请求,无法淘汰
if oldestEntry == nil {
return false
}
s.removeClientLocked(oldestKey, oldestEntry)
return true
}
// evictOverLimitLocked 淘汰超出数量限制的客户端(需持有锁)
// 使用 LRU 策略,优先淘汰最久未使用且无活跃请求的客户端
func (s *httpUpstreamService) evictOverLimitLocked() bool {
maxClients := s.maxUpstreamClients()
if maxClients <= 0 {
return false
}
evicted := false
// 循环淘汰直到满足数量限制
for len(s.clients) > maxClients {
if !s.evictOldestIdleLocked() {
return evicted
}
evicted = true
}
return evicted
}
// getIsolationMode 获取连接池隔离模式
// 从配置中读取,无效值回退到 account_proxy 模式
//
// 返回:
// - string: 隔离模式(proxy/account/account_proxy)
func (s *httpUpstreamService) getIsolationMode() string {
if s.cfg == nil {
return config.ConnectionPoolIsolationAccountProxy
}
mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.ConnectionPoolIsolation))
if mode == "" {
return config.ConnectionPoolIsolationAccountProxy
}
switch mode {
case config.ConnectionPoolIsolationProxy, config.ConnectionPoolIsolationAccount, config.ConnectionPoolIsolationAccountProxy:
return mode
default:
return config.ConnectionPoolIsolationAccountProxy
}
}
// maxUpstreamClients 获取最大客户端缓存数量
// 从配置中读取,无效值使用默认值
func (s *httpUpstreamService) maxUpstreamClients() int {
if s.cfg == nil {
return defaultMaxUpstreamClients
}
if s.cfg.Gateway.MaxUpstreamClients > 0 {
return s.cfg.Gateway.MaxUpstreamClients
}
return defaultMaxUpstreamClients
}
// clientIdleTTL 获取客户端空闲回收阈值
// 从配置中读取,无效值使用默认值
func (s *httpUpstreamService) clientIdleTTL() time.Duration {
if s.cfg == nil {
return time.Duration(defaultClientIdleTTLSeconds) * time.Second
}
if s.cfg.Gateway.ClientIdleTTLSeconds > 0 {
return time.Duration(s.cfg.Gateway.ClientIdleTTLSeconds) * time.Second
}
return time.Duration(defaultClientIdleTTLSeconds) * time.Second
}
// resolvePoolSettings 解析连接池配置
// 根据隔离策略和账户并发数动态调整连接池参数
//
// 参数:
// - isolation: 隔离模式
// - accountConcurrency: 账户并发限制
//
// 返回:
// - poolSettings: 连接池配置
//
// 说明:
// - 账户隔离模式下,连接池大小与账户并发数对应
// - 这确保了单账户不会占用过多连接资源
func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcurrency int) poolSettings {
settings := defaultPoolSettings(s.cfg)
// 账户隔离模式下,根据账户并发数调整连接池大小
if (isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy) && accountConcurrency > 0 {
settings.maxIdleConns = accountConcurrency
settings.maxIdleConnsPerHost = accountConcurrency
settings.maxConnsPerHost = accountConcurrency
} }
client := s.createProxyClient(proxyURL) return settings
return client.Do(req)
} }
func (s *httpUpstreamService) createProxyClient(proxyURL string) *http.Client { // buildPoolKey 构建连接池配置键
parsedURL, err := url.Parse(proxyURL) // 用于检测配置变更,配置变更时需要重建客户端
//
// 参数:
// - isolation: 隔离模式
// - accountConcurrency: 账户并发限制
//
// 返回:
// - string: 配置键
func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string {
if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy {
if accountConcurrency > 0 {
return fmt.Sprintf("account:%d", accountConcurrency)
}
}
return "default"
}
// buildCacheKey 构建客户端缓存键
// 根据隔离策略决定缓存键的组成
//
// 参数:
// - isolation: 隔离模式
// - proxyKey: 代理标识
// - accountID: 账户 ID
//
// 返回:
// - string: 缓存键
//
// 缓存键格式:
// - proxy 模式: "proxy:{proxyKey}"
// - account 模式: "account:{accountID}"
// - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}"
func buildCacheKey(isolation, proxyKey string, accountID int64) string {
switch isolation {
case config.ConnectionPoolIsolationAccount:
return fmt.Sprintf("account:%d", accountID)
case config.ConnectionPoolIsolationAccountProxy:
return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey)
default:
return fmt.Sprintf("proxy:%s", proxyKey)
}
}
// normalizeProxyURL 标准化代理 URL
// 处理空值和解析错误,返回标准化的键和解析后的 URL
//
// 参数:
// - raw: 原始代理 URL 字符串
//
// 返回:
// - string: 标准化的代理键(空或解析失败返回 "direct")
// - *url.URL: 解析后的 URL(空或解析失败返回 nil)
func normalizeProxyURL(raw string) (string, *url.URL) {
proxyURL := strings.TrimSpace(raw)
if proxyURL == "" {
return directProxyKey, nil
}
parsed, err := url.Parse(proxyURL)
if err != nil { if err != nil {
return s.defaultClient return directProxyKey, nil
}
parsed.Scheme = strings.ToLower(parsed.Scheme)
parsed.Host = strings.ToLower(parsed.Host)
parsed.Path = ""
parsed.RawPath = ""
parsed.RawQuery = ""
parsed.Fragment = ""
parsed.ForceQuery = false
if hostname := parsed.Hostname(); hostname != "" {
port := parsed.Port()
if (parsed.Scheme == "http" && port == "80") || (parsed.Scheme == "https" && port == "443") {
port = ""
}
hostname = strings.ToLower(hostname)
if port != "" {
parsed.Host = net.JoinHostPort(hostname, port)
} else {
parsed.Host = hostname
}
}
return parsed.String(), parsed
}
// defaultPoolSettings 获取默认连接池配置
// 从全局配置中读取,无效值使用常量默认值
//
// 参数:
// - cfg: 全局配置
//
// 返回:
// - poolSettings: 连接池配置
func defaultPoolSettings(cfg *config.Config) poolSettings {
maxIdleConns := defaultMaxIdleConns
maxIdleConnsPerHost := defaultMaxIdleConnsPerHost
maxConnsPerHost := defaultMaxConnsPerHost
idleConnTimeout := defaultIdleConnTimeout
responseHeaderTimeout := defaultResponseHeaderTimeout
if cfg != nil {
if cfg.Gateway.MaxIdleConns > 0 {
maxIdleConns = cfg.Gateway.MaxIdleConns
}
if cfg.Gateway.MaxIdleConnsPerHost > 0 {
maxIdleConnsPerHost = cfg.Gateway.MaxIdleConnsPerHost
}
if cfg.Gateway.MaxConnsPerHost >= 0 {
maxConnsPerHost = cfg.Gateway.MaxConnsPerHost
}
if cfg.Gateway.IdleConnTimeoutSeconds > 0 {
idleConnTimeout = time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second
}
if cfg.Gateway.ResponseHeaderTimeout > 0 {
responseHeaderTimeout = time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
}
} }
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second return poolSettings{
if responseHeaderTimeout == 0 { maxIdleConns: maxIdleConns,
responseHeaderTimeout = 300 * time.Second maxIdleConnsPerHost: maxIdleConnsPerHost,
maxConnsPerHost: maxConnsPerHost,
idleConnTimeout: idleConnTimeout,
responseHeaderTimeout: responseHeaderTimeout,
} }
}
// buildUpstreamTransport 构建上游请求的 Transport
// 使用配置文件中的连接池参数,支持生产环境调优
//
// 参数:
// - settings: 连接池配置
// - proxyURL: 代理 URL(nil 表示直连)
//
// 返回:
// - *http.Transport: 配置好的 Transport 实例
//
// Transport 参数说明:
// - MaxIdleConns: 所有主机的最大空闲连接总数
// - MaxIdleConnsPerHost: 每主机最大空闲连接数(影响连接复用率)
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Transport {
transport := &http.Transport{ transport := &http.Transport{
Proxy: http.ProxyURL(parsedURL), MaxIdleConns: settings.maxIdleConns,
MaxIdleConns: 100, MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
MaxIdleConnsPerHost: 10, MaxConnsPerHost: settings.maxConnsPerHost,
IdleConnTimeout: 90 * time.Second, IdleConnTimeout: settings.idleConnTimeout,
ResponseHeaderTimeout: responseHeaderTimeout, ResponseHeaderTimeout: settings.responseHeaderTimeout,
} }
if proxyURL != nil {
transport.Proxy = http.ProxyURL(proxyURL)
}
return transport
}
return &http.Client{Transport: transport} // trackedBody 带跟踪功能的响应体包装器
// 在 Close 时执行回调,用于更新请求计数
type trackedBody struct {
io.ReadCloser // 原始响应体
once sync.Once
onClose func() // 关闭时的回调函数
}
// Close 关闭响应体并执行回调
// 使用 sync.Once 确保回调只执行一次
func (b *trackedBody) Close() error {
err := b.ReadCloser.Close()
if b.onClose != nil {
b.once.Do(b.onClose)
}
return err
}
// wrapTrackedBody 包装响应体以跟踪关闭事件
// 用于在响应体关闭时更新 inFlight 计数
//
// 参数:
// - body: 原始响应体
// - onClose: 关闭时的回调函数
//
// 返回:
// - io.ReadCloser: 包装后的响应体
func wrapTrackedBody(body io.ReadCloser, onClose func()) io.ReadCloser {
if body == nil {
return body
}
return &trackedBody{ReadCloser: body, onClose: onClose}
} }
package repository
import (
"net/http"
"net/url"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// httpClientSink 用于防止编译器优化掉基准测试中的赋值操作
// 这是 Go 基准测试的常见模式,确保测试结果准确
var httpClientSink *http.Client
// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销
//
// 测试目的:
// - 验证连接池复用相比每次新建的性能提升
// - 量化内存分配差异
//
// 预期结果:
// - "复用" 子测试应显著快于 "新建"
// - "复用" 子测试应零内存分配
func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
// 创建测试配置
cfg := &config.Config{
Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300},
}
upstream := NewHTTPUpstream(cfg)
svc, ok := upstream.(*httpUpstreamService)
if !ok {
b.Fatalf("类型断言失败,无法获取 httpUpstreamService")
}
proxyURL := "http://127.0.0.1:8080"
b.ReportAllocs() // 报告内存分配统计
// 子测试:每次新建客户端
// 模拟未优化前的行为,每次请求都创建新的 http.Client
b.Run("新建", func(b *testing.B) {
parsedProxy, err := url.Parse(proxyURL)
if err != nil {
b.Fatalf("解析代理地址失败: %v", err)
}
settings := defaultPoolSettings(cfg)
for i := 0; i < b.N; i++ {
// 每次迭代都创建新客户端,包含 Transport 分配
httpClientSink = &http.Client{
Transport: buildUpstreamTransport(settings, parsedProxy),
}
}
})
// 子测试:复用已缓存的客户端
// 模拟优化后的行为,从缓存获取客户端
b.Run("复用", func(b *testing.B) {
// 预热:确保客户端已缓存
entry := svc.getOrCreateClient(proxyURL, 1, 1)
client := entry.client
b.ResetTimer() // 重置计时器,排除预热时间
for i := 0; i < b.N; i++ {
// 直接使用缓存的客户端,无内存分配
httpClientSink = client
}
})
}
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync/atomic"
"testing" "testing"
"time" "time"
...@@ -12,45 +13,86 @@ import ( ...@@ -12,45 +13,86 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
// HTTPUpstreamSuite HTTP 上游服务测试套件
// 使用 testify/suite 组织测试,支持 SetupTest 初始化
type HTTPUpstreamSuite struct { type HTTPUpstreamSuite struct {
suite.Suite suite.Suite
cfg *config.Config cfg *config.Config // 测试用配置
} }
// SetupTest 每个测试用例执行前的初始化
// 创建空配置,各测试用例可按需覆盖
func (s *HTTPUpstreamSuite) SetupTest() { func (s *HTTPUpstreamSuite) SetupTest() {
s.cfg = &config.Config{} s.cfg = &config.Config{}
} }
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { // newService 创建测试用的 httpUpstreamService 实例
// 返回具体类型以便访问内部状态进行断言
func (s *HTTPUpstreamSuite) newService() *httpUpstreamService {
up := NewHTTPUpstream(s.cfg) up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService) svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService") require.True(s.T(), ok, "expected *httpUpstreamService")
transport, ok := svc.defaultClient.Transport.(*http.Transport) return svc
}
// TestDefaultResponseHeaderTimeout 测试默认响应头超时配置
// 验证未配置时使用 300 秒默认值
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
svc := s.newService()
entry := svc.getOrCreateClient("", 0, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport") require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
} }
// TestCustomResponseHeaderTimeout 测试自定义响应头超时配置
// 验证配置值能正确应用到 Transport
func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() { func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7} s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7}
up := NewHTTPUpstream(s.cfg) svc := s.newService()
svc, ok := up.(*httpUpstreamService) entry := svc.getOrCreateClient("", 0, 0)
require.True(s.T(), ok, "expected *httpUpstreamService") transport, ok := entry.client.Transport.(*http.Transport)
transport, ok := svc.defaultClient.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport") require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
} }
func (s *HTTPUpstreamSuite) TestCreateProxyClient_InvalidURLFallsBackToDefault() { // TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5} // 验证解析失败时回退到直连模式
up := NewHTTPUpstream(s.cfg) func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() {
svc, ok := up.(*httpUpstreamService) svc := s.newService()
require.True(s.T(), ok, "expected *httpUpstreamService") entry := svc.getOrCreateClient("://bad-proxy-url", 1, 1)
require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback")
}
// TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化
// 验证等价地址能够映射到同一缓存键
func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() {
key1, _ := normalizeProxyURL("http://proxy.local:8080")
key2, _ := normalizeProxyURL("http://proxy.local:8080/")
require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match")
}
// TestAcquireClient_OverLimitReturnsError 测试连接池缓存上限保护
// 验证超限且无可淘汰条目时返回错误
func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy,
MaxUpstreamClients: 1,
}
svc := s.newService()
entry1, err := svc.acquireClient("http://proxy-a:8080", 1, 1)
require.NoError(s.T(), err, "expected first acquire to succeed")
require.NotNil(s.T(), entry1, "expected entry")
got := svc.createProxyClient("://bad-proxy-url") entry2, err := svc.acquireClient("http://proxy-b:8080", 2, 1)
require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback") require.Error(s.T(), err, "expected error when cache limit reached")
require.Nil(s.T(), entry2, "expected nil entry when cache limit reached")
} }
// TestDo_WithoutProxy_GoesDirect 测试无代理时直连
// 验证空代理 URL 时请求直接发送到目标服务器
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() { func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
// 创建模拟上游服务器
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct") _, _ = io.WriteString(w, "direct")
})) }))
...@@ -60,17 +102,21 @@ func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() { ...@@ -60,17 +102,21 @@ func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/x", nil) req, err := http.NewRequest(http.MethodGet, upstream.URL+"/x", nil)
require.NoError(s.T(), err, "NewRequest") require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, "") resp, err := up.Do(req, "", 1, 1)
require.NoError(s.T(), err, "Do") require.NoError(s.T(), err, "Do")
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body) b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "direct", string(b), "unexpected body") require.Equal(s.T(), "direct", string(b), "unexpected body")
} }
// TestDo_WithHTTPProxy_UsesProxy 测试 HTTP 代理功能
// 验证请求通过代理服务器转发,使用绝对 URI 格式
func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() { func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
// 用于接收代理请求的通道
seen := make(chan string, 1) seen := make(chan string, 1)
// 创建模拟代理服务器
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI seen <- r.RequestURI // 记录请求 URI
_, _ = io.WriteString(w, "proxied") _, _ = io.WriteString(w, "proxied")
})) }))
s.T().Cleanup(proxySrv.Close) s.T().Cleanup(proxySrv.Close)
...@@ -78,14 +124,16 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() { ...@@ -78,14 +124,16 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 1} s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 1}
up := NewHTTPUpstream(s.cfg) up := NewHTTPUpstream(s.cfg)
// 发送请求到外部地址,应通过代理
req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil)
require.NoError(s.T(), err, "NewRequest") require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, proxySrv.URL) resp, err := up.Do(req, proxySrv.URL, 1, 1)
require.NoError(s.T(), err, "Do") require.NoError(s.T(), err, "Do")
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body) b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "proxied", string(b), "unexpected body") require.Equal(s.T(), "proxied", string(b), "unexpected body")
// 验证代理收到的是绝对 URI 格式(HTTP 代理规范要求)
select { select {
case uri := <-seen: case uri := <-seen:
require.Equal(s.T(), "http://example.com/test", uri, "expected absolute-form request URI") require.Equal(s.T(), "http://example.com/test", uri, "expected absolute-form request URI")
...@@ -94,6 +142,8 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() { ...@@ -94,6 +142,8 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
} }
} }
// TestDo_EmptyProxy_UsesDirect 测试空代理字符串
// 验证空字符串代理等同于直连
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() { func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct-empty") _, _ = io.WriteString(w, "direct-empty")
...@@ -103,13 +153,134 @@ func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() { ...@@ -103,13 +153,134 @@ func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
up := NewHTTPUpstream(s.cfg) up := NewHTTPUpstream(s.cfg)
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/y", nil) req, err := http.NewRequest(http.MethodGet, upstream.URL+"/y", nil)
require.NoError(s.T(), err, "NewRequest") require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, "") resp, err := up.Do(req, "", 1, 1)
require.NoError(s.T(), err, "Do with empty proxy") require.NoError(s.T(), err, "Do with empty proxy")
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body) b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "direct-empty", string(b)) require.Equal(s.T(), "direct-empty", string(b))
} }
// TestAccountIsolation_DifferentAccounts 测试账户隔离模式
// 验证不同账户使用独立的连接池
func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 同一代理,不同账户
entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3)
require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池")
require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端")
}
// TestAccountProxyIsolation_DifferentProxy 测试账户+代理组合隔离模式
// 验证同一账户使用不同代理时创建独立连接池
func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy}
svc := s.newService()
// 同一账户,不同代理
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理")
require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端")
}
// TestAccountModeProxyChangeClearsPool 测试账户模式下代理变更
// 验证账户切换代理时清理旧连接池,避免复用错误代理
func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 同一账户,先后使用不同代理
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池")
require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池")
require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理")
}
// TestAccountConcurrencyOverridesPoolSettings 测试账户并发数覆盖连接池配置
// 验证账户隔离模式下,连接池大小与账户并发数对应
func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 账户并发数为 12
entry := svc.getOrCreateClient("", 1, 12)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
// 连接池参数应与并发数一致
require.Equal(s.T(), 12, transport.MaxConnsPerHost, "MaxConnsPerHost mismatch")
require.Equal(s.T(), 12, transport.MaxIdleConns, "MaxIdleConns mismatch")
require.Equal(s.T(), 12, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost mismatch")
}
// TestAccountConcurrencyFallbackToDefault 测试账户并发数为 0 时回退到默认配置
// 验证未指定并发数时使用全局配置值
func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount,
MaxIdleConns: 77,
MaxIdleConnsPerHost: 55,
MaxConnsPerHost: 66,
}
svc := s.newService()
// 账户并发数为 0,应使用全局配置
entry := svc.getOrCreateClient("", 1, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch")
require.Equal(s.T(), 77, transport.MaxIdleConns, "MaxIdleConns fallback mismatch")
require.Equal(s.T(), 55, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost fallback mismatch")
}
// TestEvictOverLimitRemovesOldestIdle 测试超出数量限制时的 LRU 淘汰
// 验证优先淘汰最久未使用的空闲客户端
func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy,
MaxUpstreamClients: 2, // 最多缓存 2 个客户端
}
svc := s.newService()
// 创建两个客户端,设置不同的最后使用时间
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1)
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久
atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano())
// 创建第三个客户端,触发淘汰
_ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1)
require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内")
require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理")
}
// TestIdleTTLDoesNotEvictActive 测试活跃请求保护
// 验证有进行中请求的客户端不会被空闲超时淘汰
func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount,
ClientIdleTTLSeconds: 1, // 1 秒空闲超时
}
svc := s.newService()
entry1 := svc.getOrCreateClient("", 1, 1)
// 设置为很久之前使用,但有活跃请求
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano())
atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求
// 创建新客户端,触发淘汰检查
_ = svc.getOrCreateClient("", 2, 1)
require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收")
}
// TestHTTPUpstreamSuite 运行测试套件
func TestHTTPUpstreamSuite(t *testing.T) { func TestHTTPUpstreamSuite(t *testing.T) {
suite.Run(t, new(HTTPUpstreamSuite)) suite.Run(t, new(HTTPUpstreamSuite))
} }
// hasEntry 检查客户端是否存在于缓存中
// 辅助函数,用于验证淘汰逻辑
func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool {
for _, entry := range svc.clients {
if entry == target {
return true
}
}
return false
}
...@@ -17,7 +17,6 @@ import ( ...@@ -17,7 +17,6 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
_ "github.com/Wei-Shaw/sub2api/ent/runtime" _ "github.com/Wei-Shaw/sub2api/ent/runtime"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
...@@ -97,7 +96,7 @@ func TestMain(m *testing.M) { ...@@ -97,7 +96,7 @@ func TestMain(m *testing.M) {
log.Printf("failed to open sql db: %v", err) log.Printf("failed to open sql db: %v", err)
os.Exit(1) os.Exit(1)
} }
if err := infrastructure.ApplyMigrations(ctx, integrationDB); err != nil { if err := ApplyMigrations(ctx, integrationDB); err != nil {
log.Printf("failed to apply db migrations: %v", err) log.Printf("failed to apply db migrations: %v", err)
os.Exit(1) os.Exit(1)
} }
...@@ -330,7 +329,8 @@ func (h prefixHook) prefixCmd(cmd redisclient.Cmder) { ...@@ -330,7 +329,8 @@ func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
switch strings.ToLower(cmd.Name()) { switch strings.ToLower(cmd.Name()) {
case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl", case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl",
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists": "hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists",
"zadd", "zcard", "zrange", "zrangebyscore", "zrem", "zremrangebyscore", "zrevrange", "zrevrangebyscore", "zscore":
prefixOne(1) prefixOne(1)
case "del", "unlink": case "del", "unlink":
for i := 1; i < len(args); i++ { for i := 1; i < len(args); i++ {
......
package infrastructure package repository
import ( import (
"context" "context"
...@@ -127,7 +127,15 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { ...@@ -127,7 +127,15 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if existing != checksum { if existing != checksum {
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。 // 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。 // 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf("migration %s checksum mismatch (db=%s file=%s)", name, existing, checksum) return fmt.Errorf(
"migration %s checksum mismatch (db=%s file=%s)\n"+
"This means the migration file was modified after being applied to the database.\n"+
"Solutions:\n"+
" 1. Revert to original: git log --oneline -- migrations/%s && git checkout <commit> -- migrations/%s\n"+
" 2. For new changes, create a new migration file instead of modifying existing ones\n"+
"Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments",
name, existing, checksum, name, name,
)
} }
continue // 迁移已应用且校验和匹配,跳过 continue // 迁移已应用且校验和匹配,跳过
} }
......
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"database/sql" "database/sql"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { ...@@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
tx := testTx(t) tx := testTx(t)
// Re-apply migrations to verify idempotency (no errors, no duplicate rows). // Re-apply migrations to verify idempotency (no errors, no duplicate rows).
require.NoError(t, infrastructure.ApplyMigrations(context.Background(), integrationDB)) require.NoError(t, ApplyMigrations(context.Background(), integrationDB))
// schema_migrations should have at least the current migration set. // schema_migrations should have at least the current migration set.
var applied int var applied int
...@@ -53,6 +52,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { ...@@ -53,6 +52,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
var uagRegclass sql.NullString var uagRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass)) require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))
require.True(t, uagRegclass.Valid, "expected user_allowed_groups table to exist") require.True(t, uagRegclass.Valid, "expected user_allowed_groups table to exist")
// user_subscriptions: deleted_at for soft delete support (migration 012)
requireColumn(t, tx, "user_subscriptions", "deleted_at", "timestamp with time zone", 0, true)
// orphan_allowed_groups_audit table should exist (migration 013)
var orphanAuditRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.orphan_allowed_groups_audit')").Scan(&orphanAuditRegclass))
require.True(t, orphanAuditRegclass.Valid, "expected orphan_allowed_groups_audit table to exist")
// account_groups: created_at should be timestamptz
requireColumn(t, tx, "account_groups", "created_at", "timestamp with time zone", 0, false)
// user_allowed_groups: created_at should be timestamptz
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
} }
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) { func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
......
...@@ -82,12 +82,8 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro ...@@ -82,12 +82,8 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
} }
func createOpenAIReqClient(proxyURL string) *req.Client { func createOpenAIReqClient(proxyURL string) *req.Client {
client := req.C(). return getSharedReqClient(reqClientOptions{
SetTimeout(60 * time.Second) ProxyURL: proxyURL,
Timeout: 60 * time.Second,
if proxyURL != "" { })
client.SetProxyURL(proxyURL)
}
return client
} }
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
...@@ -16,10 +17,14 @@ type pricingRemoteClient struct { ...@@ -16,10 +17,14 @@ type pricingRemoteClient struct {
} }
func NewPricingRemoteClient() service.PricingRemoteClient { func NewPricingRemoteClient() service.PricingRemoteClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &pricingRemoteClient{ return &pricingRemoteClient{
httpClient: &http.Client{ httpClient: sharedClient,
Timeout: 30 * time.Second,
},
} }
} }
......
...@@ -2,18 +2,14 @@ package repository ...@@ -2,18 +2,14 @@ package repository
import ( import (
"context" "context"
"crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"golang.org/x/net/proxy"
) )
func NewProxyExitInfoProber() service.ProxyExitInfoProber { func NewProxyExitInfoProber() service.ProxyExitInfoProber {
...@@ -27,14 +23,14 @@ type proxyProbeService struct { ...@@ -27,14 +23,14 @@ type proxyProbeService struct {
} }
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
transport, err := createProxyTransport(proxyURL) client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 15 * time.Second,
InsecureSkipVerify: true,
ProxyStrict: true,
})
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err) return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
}
client := &http.Client{
Transport: transport,
Timeout: 15 * time.Second,
} }
startTime := time.Now() startTime := time.Now()
...@@ -78,31 +74,3 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s ...@@ -78,31 +74,3 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
Country: ipInfo.Country, Country: ipInfo.Country,
}, latencyMs, nil }, latencyMs, nil
} }
func createProxyTransport(proxyURL string) (*http.Transport, error) {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
switch parsedURL.Scheme {
case "http", "https":
transport.Proxy = http.ProxyURL(parsedURL)
case "socks5":
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
}
return transport, nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment