Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
7331220e
Commit
7331220e
authored
Jan 01, 2026
by
Edric Li
Browse files
Merge remote-tracking branch 'upstream/main'
# Conflicts: # frontend/src/components/account/CreateAccountModal.vue
parents
fb86002e
4f13c8de
Changes
215
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/concurrency_cache_benchmark_test.go
0 → 100644
View file @
7331220e
package
repository
import
(
"context"
"fmt"
"os"
"testing"
"time"
"github.com/redis/go-redis/v9"
)
// 基准测试用 TTL 配置
const
benchSlotTTLMinutes
=
15
var
benchSlotTTL
=
time
.
Duration
(
benchSlotTTLMinutes
)
*
time
.
Minute
// BenchmarkAccountConcurrency 用于对比 SCAN 与有序集合的计数性能。
func
BenchmarkAccountConcurrency
(
b
*
testing
.
B
)
{
rdb
:=
newBenchmarkRedisClient
(
b
)
defer
func
()
{
_
=
rdb
.
Close
()
}()
cache
,
_
:=
NewConcurrencyCache
(
rdb
,
benchSlotTTLMinutes
,
int
(
benchSlotTTL
.
Seconds
()))
.
(
*
concurrencyCache
)
ctx
:=
context
.
Background
()
for
_
,
size
:=
range
[]
int
{
10
,
100
,
1000
}
{
size
:=
size
b
.
Run
(
fmt
.
Sprintf
(
"zset/slots=%d"
,
size
),
func
(
b
*
testing
.
B
)
{
accountID
:=
time
.
Now
()
.
UnixNano
()
key
:=
accountSlotKey
(
accountID
)
b
.
StopTimer
()
members
:=
make
([]
redis
.
Z
,
0
,
size
)
now
:=
float64
(
time
.
Now
()
.
Unix
())
for
i
:=
0
;
i
<
size
;
i
++
{
members
=
append
(
members
,
redis
.
Z
{
Score
:
now
,
Member
:
fmt
.
Sprintf
(
"req_%d"
,
i
),
})
}
if
err
:=
rdb
.
ZAdd
(
ctx
,
key
,
members
...
)
.
Err
();
err
!=
nil
{
b
.
Fatalf
(
"初始化有序集合失败: %v"
,
err
)
}
if
err
:=
rdb
.
Expire
(
ctx
,
key
,
benchSlotTTL
)
.
Err
();
err
!=
nil
{
b
.
Fatalf
(
"设置有序集合 TTL 失败: %v"
,
err
)
}
b
.
StartTimer
()
b
.
ReportAllocs
()
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
if
_
,
err
:=
cache
.
GetAccountConcurrency
(
ctx
,
accountID
);
err
!=
nil
{
b
.
Fatalf
(
"获取并发数量失败: %v"
,
err
)
}
}
b
.
StopTimer
()
if
err
:=
rdb
.
Del
(
ctx
,
key
)
.
Err
();
err
!=
nil
{
b
.
Fatalf
(
"清理有序集合失败: %v"
,
err
)
}
})
b
.
Run
(
fmt
.
Sprintf
(
"scan/slots=%d"
,
size
),
func
(
b
*
testing
.
B
)
{
accountID
:=
time
.
Now
()
.
UnixNano
()
pattern
:=
fmt
.
Sprintf
(
"%s%d:*"
,
accountSlotKeyPrefix
,
accountID
)
keys
:=
make
([]
string
,
0
,
size
)
b
.
StopTimer
()
pipe
:=
rdb
.
Pipeline
()
for
i
:=
0
;
i
<
size
;
i
++
{
key
:=
fmt
.
Sprintf
(
"%s%d:req_%d"
,
accountSlotKeyPrefix
,
accountID
,
i
)
keys
=
append
(
keys
,
key
)
pipe
.
Set
(
ctx
,
key
,
"1"
,
benchSlotTTL
)
}
if
_
,
err
:=
pipe
.
Exec
(
ctx
);
err
!=
nil
{
b
.
Fatalf
(
"初始化扫描键失败: %v"
,
err
)
}
b
.
StartTimer
()
b
.
ReportAllocs
()
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
if
_
,
err
:=
scanSlotCount
(
ctx
,
rdb
,
pattern
);
err
!=
nil
{
b
.
Fatalf
(
"SCAN 计数失败: %v"
,
err
)
}
}
b
.
StopTimer
()
if
err
:=
rdb
.
Del
(
ctx
,
keys
...
)
.
Err
();
err
!=
nil
{
b
.
Fatalf
(
"清理扫描键失败: %v"
,
err
)
}
})
}
}
func
scanSlotCount
(
ctx
context
.
Context
,
rdb
*
redis
.
Client
,
pattern
string
)
(
int
,
error
)
{
var
cursor
uint64
count
:=
0
for
{
keys
,
nextCursor
,
err
:=
rdb
.
Scan
(
ctx
,
cursor
,
pattern
,
100
)
.
Result
()
if
err
!=
nil
{
return
0
,
err
}
count
+=
len
(
keys
)
if
nextCursor
==
0
{
break
}
cursor
=
nextCursor
}
return
count
,
nil
}
func
newBenchmarkRedisClient
(
b
*
testing
.
B
)
*
redis
.
Client
{
b
.
Helper
()
redisURL
:=
os
.
Getenv
(
"TEST_REDIS_URL"
)
if
redisURL
==
""
{
b
.
Skip
(
"未设置 TEST_REDIS_URL,跳过 Redis 基准测试"
)
}
opt
,
err
:=
redis
.
ParseURL
(
redisURL
)
if
err
!=
nil
{
b
.
Fatalf
(
"解析 TEST_REDIS_URL 失败: %v"
,
err
)
}
client
:=
redis
.
NewClient
(
opt
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
3
*
time
.
Second
)
defer
cancel
()
if
err
:=
client
.
Ping
(
ctx
)
.
Err
();
err
!=
nil
{
b
.
Fatalf
(
"Redis 连接失败: %v"
,
err
)
}
return
client
}
backend/internal/repository/concurrency_cache_integration_test.go
View file @
7331220e
...
...
@@ -14,6 +14,12 @@ import (
"github.com/stretchr/testify/suite"
)
// 测试用 TTL 配置(15 分钟,与默认值一致)
const
testSlotTTLMinutes
=
15
// 测试用 TTL Duration,用于 TTL 断言
var
testSlotTTL
=
time
.
Duration
(
testSlotTTLMinutes
)
*
time
.
Minute
type
ConcurrencyCacheSuite
struct
{
IntegrationRedisSuite
cache
service
.
ConcurrencyCache
...
...
@@ -21,7 +27,7 @@ type ConcurrencyCacheSuite struct {
func
(
s
*
ConcurrencyCacheSuite
)
SetupTest
()
{
s
.
IntegrationRedisSuite
.
SetupTest
()
s
.
cache
=
NewConcurrencyCache
(
s
.
rdb
)
s
.
cache
=
NewConcurrencyCache
(
s
.
rdb
,
testSlotTTLMinutes
,
int
(
testSlotTTL
.
Seconds
())
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountSlot_AcquireAndRelease
()
{
...
...
@@ -54,7 +60,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountSlot_TTL
()
{
accountID
:=
int64
(
11
)
reqID
:=
"req_ttl_test"
slotKey
:=
fmt
.
Sprintf
(
"%s%d
:%s
"
,
accountSlotKeyPrefix
,
accountID
,
reqID
)
slotKey
:=
fmt
.
Sprintf
(
"%s%d"
,
accountSlotKeyPrefix
,
accountID
)
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
reqID
)
require
.
NoError
(
s
.
T
(),
err
,
"AcquireAccountSlot"
)
...
...
@@ -62,7 +68,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
ttl
,
err
:=
s
.
rdb
.
TTL
(
s
.
ctx
,
slotKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
s
lotTTL
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
testS
lotTTL
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountSlot_DuplicateReqID
()
{
...
...
@@ -139,7 +145,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
func
(
s
*
ConcurrencyCacheSuite
)
TestUserSlot_TTL
()
{
userID
:=
int64
(
200
)
reqID
:=
"req_ttl_test"
slotKey
:=
fmt
.
Sprintf
(
"%s%d
:%s
"
,
userSlotKeyPrefix
,
userID
,
reqID
)
slotKey
:=
fmt
.
Sprintf
(
"%s%d"
,
userSlotKeyPrefix
,
userID
)
ok
,
err
:=
s
.
cache
.
AcquireUserSlot
(
s
.
ctx
,
userID
,
5
,
reqID
)
require
.
NoError
(
s
.
T
(),
err
,
"AcquireUserSlot"
)
...
...
@@ -147,7 +153,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
ttl
,
err
:=
s
.
rdb
.
TTL
(
s
.
ctx
,
slotKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
s
lotTTL
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
testS
lotTTL
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestWaitQueue_IncrementAndDecrement
()
{
...
...
@@ -168,7 +174,7 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
ttl
,
err
:=
s
.
rdb
.
TTL
(
s
.
ctx
,
waitKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL waitKey"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
s
lotTTL
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
testS
lotTTL
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DecrementWaitCount
(
s
.
ctx
,
userID
),
"DecrementWaitCount"
)
...
...
@@ -212,6 +218,48 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
require
.
GreaterOrEqual
(
s
.
T
(),
val
,
0
,
"expected non-negative wait count"
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountWaitQueue_IncrementAndDecrement
()
{
accountID
:=
int64
(
30
)
waitKey
:=
fmt
.
Sprintf
(
"%s%d"
,
accountWaitKeyPrefix
,
accountID
)
ok
,
err
:=
s
.
cache
.
IncrementAccountWaitCount
(
s
.
ctx
,
accountID
,
2
)
require
.
NoError
(
s
.
T
(),
err
,
"IncrementAccountWaitCount 1"
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
IncrementAccountWaitCount
(
s
.
ctx
,
accountID
,
2
)
require
.
NoError
(
s
.
T
(),
err
,
"IncrementAccountWaitCount 2"
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
IncrementAccountWaitCount
(
s
.
ctx
,
accountID
,
2
)
require
.
NoError
(
s
.
T
(),
err
,
"IncrementAccountWaitCount 3"
)
require
.
False
(
s
.
T
(),
ok
,
"expected account wait increment over max to fail"
)
ttl
,
err
:=
s
.
rdb
.
TTL
(
s
.
ctx
,
waitKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL account waitKey"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
testSlotTTL
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DecrementAccountWaitCount
(
s
.
ctx
,
accountID
),
"DecrementAccountWaitCount"
)
val
,
err
:=
s
.
rdb
.
Get
(
s
.
ctx
,
waitKey
)
.
Int
()
if
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
require
.
NoError
(
s
.
T
(),
err
,
"Get waitKey"
)
}
require
.
Equal
(
s
.
T
(),
1
,
val
,
"expected account wait count 1"
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountWaitQueue_DecrementNoNegative
()
{
accountID
:=
int64
(
301
)
waitKey
:=
fmt
.
Sprintf
(
"%s%d"
,
accountWaitKeyPrefix
,
accountID
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DecrementAccountWaitCount
(
s
.
ctx
,
accountID
),
"DecrementAccountWaitCount on non-existent key"
)
val
,
err
:=
s
.
rdb
.
Get
(
s
.
ctx
,
waitKey
)
.
Int
()
if
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
require
.
NoError
(
s
.
T
(),
err
,
"Get waitKey"
)
}
require
.
GreaterOrEqual
(
s
.
T
(),
val
,
0
,
"expected non-negative account wait count after decrement on empty"
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestGetAccountConcurrency_Missing
()
{
// When no slots exist, GetAccountConcurrency should return 0
cur
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
999
)
...
...
@@ -226,6 +274,139 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
require
.
Equal
(
s
.
T
(),
0
,
cur
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestGetAccountsLoadBatch
()
{
s
.
T
()
.
Skip
(
"TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI"
)
// Setup: Create accounts with different load states
account1
:=
int64
(
100
)
account2
:=
int64
(
101
)
account3
:=
int64
(
102
)
// Account 1: 2/3 slots used, 1 waiting
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
account1
,
3
,
"req1"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
account1
,
3
,
"req2"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
IncrementAccountWaitCount
(
s
.
ctx
,
account1
,
5
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
// Account 2: 1/2 slots used, 0 waiting
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
account2
,
2
,
"req3"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
// Account 3: 0/1 slots used, 0 waiting (idle)
// Query batch load
accounts
:=
[]
service
.
AccountWithConcurrency
{
{
ID
:
account1
,
MaxConcurrency
:
3
},
{
ID
:
account2
,
MaxConcurrency
:
2
},
{
ID
:
account3
,
MaxConcurrency
:
1
},
}
loadMap
,
err
:=
s
.
cache
.
GetAccountsLoadBatch
(
s
.
ctx
,
accounts
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Len
(
s
.
T
(),
loadMap
,
3
)
// Verify account1: (2 + 1) / 3 = 100%
load1
:=
loadMap
[
account1
]
require
.
NotNil
(
s
.
T
(),
load1
)
require
.
Equal
(
s
.
T
(),
account1
,
load1
.
AccountID
)
require
.
Equal
(
s
.
T
(),
2
,
load1
.
CurrentConcurrency
)
require
.
Equal
(
s
.
T
(),
1
,
load1
.
WaitingCount
)
require
.
Equal
(
s
.
T
(),
100
,
load1
.
LoadRate
)
// Verify account2: (1 + 0) / 2 = 50%
load2
:=
loadMap
[
account2
]
require
.
NotNil
(
s
.
T
(),
load2
)
require
.
Equal
(
s
.
T
(),
account2
,
load2
.
AccountID
)
require
.
Equal
(
s
.
T
(),
1
,
load2
.
CurrentConcurrency
)
require
.
Equal
(
s
.
T
(),
0
,
load2
.
WaitingCount
)
require
.
Equal
(
s
.
T
(),
50
,
load2
.
LoadRate
)
// Verify account3: (0 + 0) / 1 = 0%
load3
:=
loadMap
[
account3
]
require
.
NotNil
(
s
.
T
(),
load3
)
require
.
Equal
(
s
.
T
(),
account3
,
load3
.
AccountID
)
require
.
Equal
(
s
.
T
(),
0
,
load3
.
CurrentConcurrency
)
require
.
Equal
(
s
.
T
(),
0
,
load3
.
WaitingCount
)
require
.
Equal
(
s
.
T
(),
0
,
load3
.
LoadRate
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestGetAccountsLoadBatch_Empty
()
{
// Test with empty account list
loadMap
,
err
:=
s
.
cache
.
GetAccountsLoadBatch
(
s
.
ctx
,
[]
service
.
AccountWithConcurrency
{})
require
.
NoError
(
s
.
T
(),
err
)
require
.
Empty
(
s
.
T
(),
loadMap
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestCleanupExpiredAccountSlots
()
{
accountID
:=
int64
(
200
)
slotKey
:=
fmt
.
Sprintf
(
"%s%d"
,
accountSlotKeyPrefix
,
accountID
)
// Acquire 3 slots
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req1"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req2"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req3"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
// Verify 3 slots exist
cur
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
3
,
cur
)
// Manually set old timestamps for req1 and req2 (simulate expired slots)
now
:=
time
.
Now
()
.
Unix
()
expiredTime
:=
now
-
int64
(
testSlotTTL
.
Seconds
())
-
10
// 10 seconds past TTL
err
=
s
.
rdb
.
ZAdd
(
s
.
ctx
,
slotKey
,
redis
.
Z
{
Score
:
float64
(
expiredTime
),
Member
:
"req1"
})
.
Err
()
require
.
NoError
(
s
.
T
(),
err
)
err
=
s
.
rdb
.
ZAdd
(
s
.
ctx
,
slotKey
,
redis
.
Z
{
Score
:
float64
(
expiredTime
),
Member
:
"req2"
})
.
Err
()
require
.
NoError
(
s
.
T
(),
err
)
// Run cleanup
err
=
s
.
cache
.
CleanupExpiredAccountSlots
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
// Verify only 1 slot remains (req3)
cur
,
err
=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
1
,
cur
)
// Verify req3 still exists
members
,
err
:=
s
.
rdb
.
ZRange
(
s
.
ctx
,
slotKey
,
0
,
-
1
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
)
require
.
Len
(
s
.
T
(),
members
,
1
)
require
.
Equal
(
s
.
T
(),
"req3"
,
members
[
0
])
}
func
(
s
*
ConcurrencyCacheSuite
)
TestCleanupExpiredAccountSlots_NoExpired
()
{
accountID
:=
int64
(
201
)
// Acquire 2 fresh slots
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req1"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req2"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
// Run cleanup (should not remove anything)
err
=
s
.
cache
.
CleanupExpiredAccountSlots
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
// Verify both slots still exist
cur
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
2
,
cur
)
}
func
TestConcurrencyCacheSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
ConcurrencyCacheSuite
))
}
backend/internal/repository/db_pool.go
0 → 100644
View file @
7331220e
package
repository
import
(
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// dbPoolSettings bundles the *sql.DB connection-pool parameters derived from
// the application configuration.
type dbPoolSettings struct {
	MaxOpenConns    int           // upper bound on open connections
	MaxIdleConns    int           // upper bound on idle connections kept around
	ConnMaxLifetime time.Duration // max total lifetime of a connection
	ConnMaxIdleTime time.Duration // max time a connection may sit idle
}
func
buildDBPoolSettings
(
cfg
*
config
.
Config
)
dbPoolSettings
{
return
dbPoolSettings
{
MaxOpenConns
:
cfg
.
Database
.
MaxOpenConns
,
MaxIdleConns
:
cfg
.
Database
.
MaxIdleConns
,
ConnMaxLifetime
:
time
.
Duration
(
cfg
.
Database
.
ConnMaxLifetimeMinutes
)
*
time
.
Minute
,
ConnMaxIdleTime
:
time
.
Duration
(
cfg
.
Database
.
ConnMaxIdleTimeMinutes
)
*
time
.
Minute
,
}
}
func
applyDBPoolSettings
(
db
*
sql
.
DB
,
cfg
*
config
.
Config
)
{
settings
:=
buildDBPoolSettings
(
cfg
)
db
.
SetMaxOpenConns
(
settings
.
MaxOpenConns
)
db
.
SetMaxIdleConns
(
settings
.
MaxIdleConns
)
db
.
SetConnMaxLifetime
(
settings
.
ConnMaxLifetime
)
db
.
SetConnMaxIdleTime
(
settings
.
ConnMaxIdleTime
)
}
backend/internal/repository/db_pool_test.go
0 → 100644
View file @
7331220e
package
repository
import
(
"database/sql"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
_
"github.com/lib/pq"
)
func
TestBuildDBPoolSettings
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Database
:
config
.
DatabaseConfig
{
MaxOpenConns
:
50
,
MaxIdleConns
:
10
,
ConnMaxLifetimeMinutes
:
30
,
ConnMaxIdleTimeMinutes
:
5
,
},
}
settings
:=
buildDBPoolSettings
(
cfg
)
require
.
Equal
(
t
,
50
,
settings
.
MaxOpenConns
)
require
.
Equal
(
t
,
10
,
settings
.
MaxIdleConns
)
require
.
Equal
(
t
,
30
*
time
.
Minute
,
settings
.
ConnMaxLifetime
)
require
.
Equal
(
t
,
5
*
time
.
Minute
,
settings
.
ConnMaxIdleTime
)
}
func
TestApplyDBPoolSettings
(
t
*
testing
.
T
)
{
cfg
:=
&
config
.
Config
{
Database
:
config
.
DatabaseConfig
{
MaxOpenConns
:
40
,
MaxIdleConns
:
8
,
ConnMaxLifetimeMinutes
:
15
,
ConnMaxIdleTimeMinutes
:
3
,
},
}
db
,
err
:=
sql
.
Open
(
"postgres"
,
"host=127.0.0.1 port=5432 user=postgres sslmode=disable"
)
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
db
.
Close
()
})
applyDBPoolSettings
(
db
,
cfg
)
stats
:=
db
.
Stats
()
require
.
Equal
(
t
,
40
,
stats
.
MaxOpenConnections
)
}
backend/internal/
infrastructure
/ent.go
→
backend/internal/
repository
/ent.go
View file @
7331220e
// Package infrastructure 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package
infrastructure
package
repository
import
(
"context"
...
...
@@ -51,6 +51,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
if
err
!=
nil
{
return
nil
,
nil
,
err
}
applyDBPoolSettings
(
drv
.
DB
(),
cfg
)
// 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源(source of truth)。
...
...
backend/internal/repository/error_translate.go
View file @
7331220e
package
repository
import
(
"context"
"database/sql"
"errors"
"strings"
dbent
"github.com/Wei-Shaw/sub2api/ent"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
infrastructure
/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/
pkg
/errors"
"github.com/lib/pq"
)
// clientFromContext 从 context 中获取事务 client,如果不存在则返回默认 client。
//
// 这个辅助函数支持 repository 方法在事务上下文中工作:
// - 如果 context 中存在事务(通过 ent.NewTxContext 设置),返回事务的 client
// - 否则返回传入的默认 client
//
// 使用示例:
//
// func (r *someRepo) SomeMethod(ctx context.Context) error {
// client := clientFromContext(ctx, r.client)
// return client.SomeEntity.Create().Save(ctx)
// }
func
clientFromContext
(
ctx
context
.
Context
,
defaultClient
*
dbent
.
Client
)
*
dbent
.
Client
{
if
tx
:=
dbent
.
TxFromContext
(
ctx
);
tx
!=
nil
{
return
tx
.
Client
()
}
return
defaultClient
}
// translatePersistenceError 将数据库层错误翻译为业务层错误。
//
// 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。
...
...
backend/internal/repository/gemini_oauth_client.go
View file @
7331220e
...
...
@@ -109,9 +109,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
}
func
createGeminiReqClient
(
proxyURL
string
)
*
req
.
Client
{
client
:=
req
.
C
()
.
SetTimeout
(
60
*
time
.
Second
)
if
proxyURL
!=
""
{
client
.
SetProxyURL
(
proxyURL
)
}
return
client
return
getSharedReqClient
(
reqClientOptions
{
ProxyURL
:
proxyURL
,
Timeout
:
60
*
time
.
Second
,
})
}
backend/internal/repository/geminicli_codeassist_client.go
View file @
7331220e
...
...
@@ -76,11 +76,10 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
}
func
createGeminiCliReqClient
(
proxyURL
string
)
*
req
.
Client
{
client
:=
req
.
C
()
.
SetTimeout
(
30
*
time
.
Second
)
if
proxyURL
!=
""
{
client
.
SetProxyURL
(
proxyURL
)
}
return
client
return
getSharedReqClient
(
reqClientOptions
{
ProxyURL
:
proxyURL
,
Timeout
:
30
*
time
.
Second
,
})
}
func
defaultLoadCodeAssistRequest
()
*
geminicli
.
LoadCodeAssistRequest
{
...
...
backend/internal/repository/github_release_service.go
View file @
7331220e
...
...
@@ -9,6 +9,7 @@ import (
"os"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
...
...
@@ -17,10 +18,14 @@ type githubReleaseClient struct {
}
func
NewGitHubReleaseClient
()
service
.
GitHubReleaseClient
{
sharedClient
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
30
*
time
.
Second
,
})
if
err
!=
nil
{
sharedClient
=
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
}
}
return
&
githubReleaseClient
{
httpClient
:
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
,
},
httpClient
:
sharedClient
,
}
}
...
...
@@ -58,8 +63,13 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
return
err
}
client
:=
&
http
.
Client
{
Timeout
:
10
*
time
.
Minute
}
resp
,
err
:=
client
.
Do
(
req
)
downloadClient
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
10
*
time
.
Minute
,
})
if
err
!=
nil
{
downloadClient
=
&
http
.
Client
{
Timeout
:
10
*
time
.
Minute
}
}
resp
,
err
:=
downloadClient
.
Do
(
req
)
if
err
!=
nil
{
return
err
}
...
...
backend/internal/repository/group_repo.go
View file @
7331220e
...
...
@@ -42,7 +42,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetSubscriptionType
(
groupIn
.
SubscriptionType
)
.
SetNillableDailyLimitUsd
(
groupIn
.
DailyLimitUSD
)
.
SetNillableWeeklyLimitUsd
(
groupIn
.
WeeklyLimitUSD
)
.
SetNillableMonthlyLimitUsd
(
groupIn
.
MonthlyLimitUSD
)
SetNillableMonthlyLimitUsd
(
groupIn
.
MonthlyLimitUSD
)
.
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
created
,
err
:=
builder
.
Save
(
ctx
)
if
err
==
nil
{
...
...
@@ -79,6 +80,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableDailyLimitUsd
(
groupIn
.
DailyLimitUSD
)
.
SetNillableWeeklyLimitUsd
(
groupIn
.
WeeklyLimitUSD
)
.
SetNillableMonthlyLimitUsd
(
groupIn
.
MonthlyLimitUSD
)
.
SetDefaultValidityDays
(
groupIn
.
DefaultValidityDays
)
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
service
.
ErrGroupExists
)
...
...
@@ -89,7 +91,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
func
(
r
*
groupRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
_
,
err
:=
r
.
client
.
Group
.
Delete
()
.
Where
(
group
.
IDEQ
(
id
))
.
Exec
(
ctx
)
return
err
return
translatePersistenceError
(
err
,
service
.
ErrGroupNotFound
,
nil
)
}
func
(
r
*
groupRepository
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
service
.
Group
,
*
pagination
.
PaginationResult
,
error
)
{
...
...
@@ -239,8 +241,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
// err 为 dbent.ErrTxStarted 时,复用当前 client 参与同一事务。
// Lock the group row to avoid concurrent writes while we cascade.
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分
“
未找到
”
与其他错误。
rows
,
err
:=
exec
.
QueryContext
(
ctx
,
"SELECT id FROM groups WHERE id = $1 FOR UPDATE"
,
id
)
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分
"
未找到
"
与其他错误。
rows
,
err
:=
exec
.
QueryContext
(
ctx
,
"SELECT id FROM groups WHERE id = $1
AND deleted_at IS NULL
FOR UPDATE"
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -263,7 +265,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
var
affectedUserIDs
[]
int64
if
groupSvc
.
IsSubscriptionType
()
{
rows
,
err
:=
exec
.
QueryContext
(
ctx
,
"SELECT user_id FROM user_subscriptions WHERE group_id = $1"
,
id
)
// 只查询未软删除的订阅,避免通知已取消订阅的用户
rows
,
err
:=
exec
.
QueryContext
(
ctx
,
"SELECT user_id FROM user_subscriptions WHERE group_id = $1 AND deleted_at IS NULL"
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -282,7 +285,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return
nil
,
err
}
if
_
,
err
:=
exec
.
ExecContext
(
ctx
,
"DELETE FROM user_subscriptions WHERE group_id = $1"
,
id
);
err
!=
nil
{
// 软删除订阅:设置 deleted_at 而非硬删除
if
_
,
err
:=
exec
.
ExecContext
(
ctx
,
"UPDATE user_subscriptions SET deleted_at = NOW() WHERE group_id = $1 AND deleted_at IS NULL"
,
id
);
err
!=
nil
{
return
nil
,
err
}
}
...
...
@@ -297,18 +301,11 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return
nil
,
err
}
// 3. Remove the group id from user
s.
allowed_groups
array (legacy representation)
.
//
Phase 1 compatibility: also delete from
user
_
allowed_groups
join table when present.
// 3. Remove the group id from user
_
allowed_groups
join table
.
//
Legacy
user
s.
allowed_groups
列已弃用,不再同步。
if
_
,
err
:=
exec
.
ExecContext
(
ctx
,
"DELETE FROM user_allowed_groups WHERE group_id = $1"
,
id
);
err
!=
nil
{
return
nil
,
err
}
if
_
,
err
:=
exec
.
ExecContext
(
ctx
,
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1) WHERE $1 = ANY(allowed_groups)"
,
id
,
);
err
!=
nil
{
return
nil
,
err
}
// 4. Delete account_groups join rows.
if
_
,
err
:=
exec
.
ExecContext
(
ctx
,
"DELETE FROM account_groups WHERE group_id = $1"
,
id
);
err
!=
nil
{
...
...
backend/internal/repository/group_repo_integration_test.go
View file @
7331220e
...
...
@@ -478,3 +478,58 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
count
,
_
:=
s
.
repo
.
GetAccountCount
(
s
.
ctx
,
g
.
ID
)
s
.
Require
()
.
Zero
(
count
)
}
// --- 软删除过滤测试 ---
func
(
s
*
GroupRepoSuite
)
TestDelete_SoftDelete_NotVisibleInList
()
{
group
:=
&
service
.
Group
{
Name
:
"to-soft-delete"
,
Platform
:
service
.
PlatformAnthropic
,
RateMultiplier
:
1.0
,
IsExclusive
:
false
,
Status
:
service
.
StatusActive
,
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
}
s
.
Require
()
.
NoError
(
s
.
repo
.
Create
(
s
.
ctx
,
group
))
// 获取删除前的列表数量
listBefore
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
100
})
s
.
Require
()
.
NoError
(
err
)
beforeCount
:=
len
(
listBefore
)
// 软删除
err
=
s
.
repo
.
Delete
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
,
"Delete (soft delete)"
)
// 验证列表中不再包含软删除的 group
listAfter
,
_
,
err
:=
s
.
repo
.
List
(
s
.
ctx
,
pagination
.
PaginationParams
{
Page
:
1
,
PageSize
:
100
})
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Len
(
listAfter
,
beforeCount
-
1
,
"soft deleted group should not appear in list"
)
// 验证 GetByID 也无法找到
_
,
err
=
s
.
repo
.
GetByID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
Error
(
err
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrGroupNotFound
)
}
func
(
s
*
GroupRepoSuite
)
TestDelete_SoftDeletedGroup_lockForUpdate
()
{
group
:=
&
service
.
Group
{
Name
:
"lock-soft-delete"
,
Platform
:
service
.
PlatformAnthropic
,
RateMultiplier
:
1.0
,
IsExclusive
:
false
,
Status
:
service
.
StatusActive
,
SubscriptionType
:
service
.
SubscriptionTypeStandard
,
}
s
.
Require
()
.
NoError
(
s
.
repo
.
Create
(
s
.
ctx
,
group
))
// 软删除
err
:=
s
.
repo
.
Delete
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
NoError
(
err
)
// 验证软删除的 group 在 GetByID 时返回 ErrGroupNotFound
// 这证明 lockForUpdate 的 deleted_at IS NULL 过滤正在工作
_
,
err
=
s
.
repo
.
GetByID
(
s
.
ctx
,
group
.
ID
)
s
.
Require
()
.
Error
(
err
,
"should fail to get soft-deleted group"
)
s
.
Require
()
.
ErrorIs
(
err
,
service
.
ErrGroupNotFound
)
}
backend/internal/repository/http_upstream.go
View file @
7331220e
package
repository
import
(
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// httpUpstreamService is a generic HTTP upstream service that can be used for
// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support.
// 默认配置常量
// 这些值在配置文件未指定时作为回退默认值使用
const
(
// directProxyKey: 无代理时的缓存键标识
directProxyKey
=
"direct"
// defaultMaxIdleConns: 默认最大空闲连接总数
// HTTP/2 场景下,单连接可多路复用,240 足以支撑高并发
defaultMaxIdleConns
=
240
// defaultMaxIdleConnsPerHost: 默认每主机最大空闲连接数
defaultMaxIdleConnsPerHost
=
120
// defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接)
// 达到上限后新请求会等待,而非无限创建连接
defaultMaxConnsPerHost
=
240
// defaultIdleConnTimeout: 默认空闲连接超时时间(5分钟)
// 超时后连接会被关闭,释放系统资源
defaultIdleConnTimeout
=
300
*
time
.
Second
// defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟)
// LLM 请求可能排队较久,需要较长超时
defaultResponseHeaderTimeout
=
300
*
time
.
Second
// defaultMaxUpstreamClients: 默认最大客户端缓存数量
// 超出后会淘汰最久未使用的客户端
defaultMaxUpstreamClients
=
5000
// defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值(15分钟)
defaultClientIdleTTLSeconds
=
900
)
var
errUpstreamClientLimitReached
=
errors
.
New
(
"upstream client cache limit reached"
)
// poolSettings 连接池配置参数
// 封装 Transport 所需的各项连接池参数
type
poolSettings
struct
{
maxIdleConns
int
// 最大空闲连接总数
maxIdleConnsPerHost
int
// 每主机最大空闲连接数
maxConnsPerHost
int
// 每主机最大连接数(含活跃)
idleConnTimeout
time
.
Duration
// 空闲连接超时时间
responseHeaderTimeout
time
.
Duration
// 等待响应头超时时间
}
// upstreamClientEntry 上游客户端缓存条目
// 记录客户端实例及其元数据,用于连接池管理和淘汰策略
type
upstreamClientEntry
struct
{
client
*
http
.
Client
// HTTP 客户端实例
proxyKey
string
// 代理标识(用于检测代理变更)
poolKey
string
// 连接池配置标识(用于检测配置变更)
lastUsed
int64
// 最后使用时间戳(纳秒),用于 LRU 淘汰
inFlight
int64
// 当前进行中的请求数,>0 时不可淘汰
}
// httpUpstreamService 通用 HTTP 上游服务
// 用于向任意 HTTP API(Claude、OpenAI 等)发送请求,支持可选代理
//
// 架构设计:
// - 根据隔离策略(proxy/account/account_proxy)缓存客户端实例
// - 每个客户端拥有独立的 Transport 连接池
// - 支持 LRU + 空闲时间双重淘汰策略
//
// 性能优化:
// 1. 根据隔离策略缓存客户端实例,避免频繁创建 http.Client
// 2. 复用 Transport 连接池,减少 TCP 握手和 TLS 协商开销
// 3. 支持账号级隔离与空闲回收,降低连接层关联风险
// 4. 达到最大连接数后等待可用连接,而非无限创建
// 5. 仅回收空闲客户端,避免中断活跃请求
// 6. HTTP/2 多路复用,连接上限不等于并发请求上限
// 7. 代理变更时清空旧连接池,避免复用错误代理
// 8. 账号并发数与连接池上限对应(账号隔离策略下)
type
httpUpstreamService
struct
{
defaultClient
*
http
.
Client
cfg
*
config
.
Config
cfg
*
config
.
Config
// 全局配置
mu
sync
.
RWMutex
// 保护 clients map 的读写锁
clients
map
[
string
]
*
upstreamClientEntry
// 客户端缓存池,key 由隔离策略决定
}
// NewHTTPUpstream creates a new generic HTTP upstream service
// NewHTTPUpstream 创建通用 HTTP 上游服务
// 使用配置中的连接池参数构建 Transport
//
// 参数:
// - cfg: 全局配置,包含连接池参数和隔离策略
//
// 返回:
// - service.HTTPUpstream 接口实现
func
NewHTTPUpstream
(
cfg
*
config
.
Config
)
service
.
HTTPUpstream
{
re
sponseHeaderTimeout
:=
time
.
Duration
(
cfg
.
Gateway
.
ResponseHeaderTimeout
)
*
time
.
Second
if
responseHeaderTimeout
==
0
{
responseHeaderTimeout
=
300
*
time
.
Second
re
turn
&
httpUpstreamService
{
cfg
:
cfg
,
clients
:
make
(
map
[
string
]
*
upstreamClientEntry
),
}
}
transport
:=
&
http
.
Transport
{
MaxIdleConns
:
100
,
MaxIdleConnsPerHost
:
10
,
IdleConnTimeout
:
90
*
time
.
Second
,
ResponseHeaderTimeout
:
responseHeaderTimeout
,
// Do 执行 HTTP 请求
// 根据隔离策略获取或创建客户端,并跟踪请求生命周期
//
// 参数:
// - req: HTTP 请求对象
// - proxyURL: 代理地址,空字符串表示直连
// - accountID: 账户 ID,用于账户级隔离
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
//
// 返回:
// - *http.Response: HTTP 响应(Body 已包装,关闭时自动更新计数)
// - error: 请求错误
//
// 注意:
// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
func
(
s
*
httpUpstreamService
)
Do
(
req
*
http
.
Request
,
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
)
(
*
http
.
Response
,
error
)
{
// 获取或创建对应的客户端,并标记请求占用
entry
,
err
:=
s
.
acquireClient
(
proxyURL
,
accountID
,
accountConcurrency
)
if
err
!=
nil
{
return
nil
,
err
}
return
&
httpUpstreamService
{
defaultClient
:
&
http
.
Client
{
Transport
:
transport
},
cfg
:
cfg
,
// 执行请求
resp
,
err
:=
entry
.
client
.
Do
(
req
)
if
err
!=
nil
{
// 请求失败,立即减少计数
atomic
.
AddInt64
(
&
entry
.
inFlight
,
-
1
)
atomic
.
StoreInt64
(
&
entry
.
lastUsed
,
time
.
Now
()
.
UnixNano
())
return
nil
,
err
}
// 包装响应体,在关闭时自动减少计数并更新时间戳
// 这确保了流式响应(如 SSE)在完全读取前不会被淘汰
resp
.
Body
=
wrapTrackedBody
(
resp
.
Body
,
func
()
{
atomic
.
AddInt64
(
&
entry
.
inFlight
,
-
1
)
atomic
.
StoreInt64
(
&
entry
.
lastUsed
,
time
.
Now
()
.
UnixNano
())
})
return
resp
,
nil
}
func
(
s
*
httpUpstreamService
)
Do
(
req
*
http
.
Request
,
proxyURL
string
)
(
*
http
.
Response
,
error
)
{
if
proxyURL
==
""
{
return
s
.
defaultClient
.
Do
(
req
)
// acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰
func
(
s
*
httpUpstreamService
)
acquireClient
(
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
)
(
*
upstreamClientEntry
,
error
)
{
return
s
.
getClientEntry
(
proxyURL
,
accountID
,
accountConcurrency
,
true
,
true
)
}
// getOrCreateClient 获取或创建客户端
// 根据隔离策略和参数决定缓存键,处理代理变更和配置变更
//
// 参数:
// - proxyURL: 代理地址
// - accountID: 账户 ID
// - accountConcurrency: 账户并发限制
//
// 返回:
// - *upstreamClientEntry: 客户端缓存条目
//
// 隔离策略说明:
// - proxy: 按代理地址隔离,同一代理共享客户端
// - account: 按账户隔离,同一账户共享客户端(代理变更时重建)
// - account_proxy: 按账户+代理组合隔离,最细粒度
func
(
s
*
httpUpstreamService
)
getOrCreateClient
(
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
)
*
upstreamClientEntry
{
entry
,
_
:=
s
.
getClientEntry
(
proxyURL
,
accountID
,
accountConcurrency
,
false
,
false
)
return
entry
}
// getClientEntry 获取或创建客户端条目
// markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰
// enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误
func
(
s
*
httpUpstreamService
)
getClientEntry
(
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
,
markInFlight
bool
,
enforceLimit
bool
)
(
*
upstreamClientEntry
,
error
)
{
// 获取隔离模式
isolation
:=
s
.
getIsolationMode
()
// 标准化代理 URL 并解析
proxyKey
,
parsedProxy
:=
normalizeProxyURL
(
proxyURL
)
// 构建缓存键(根据隔离策略不同)
cacheKey
:=
buildCacheKey
(
isolation
,
proxyKey
,
accountID
)
// 构建连接池配置键(用于检测配置变更)
poolKey
:=
s
.
buildPoolKey
(
isolation
,
accountConcurrency
)
now
:=
time
.
Now
()
nowUnix
:=
now
.
UnixNano
()
// 读锁快速路径:命中缓存直接返回,减少锁竞争
s
.
mu
.
RLock
()
if
entry
,
ok
:=
s
.
clients
[
cacheKey
];
ok
&&
s
.
shouldReuseEntry
(
entry
,
isolation
,
proxyKey
,
poolKey
)
{
atomic
.
StoreInt64
(
&
entry
.
lastUsed
,
nowUnix
)
if
markInFlight
{
atomic
.
AddInt64
(
&
entry
.
inFlight
,
1
)
}
s
.
mu
.
RUnlock
()
return
entry
,
nil
}
s
.
mu
.
RUnlock
()
// 写锁慢路径:创建或重建客户端
s
.
mu
.
Lock
()
if
entry
,
ok
:=
s
.
clients
[
cacheKey
];
ok
{
if
s
.
shouldReuseEntry
(
entry
,
isolation
,
proxyKey
,
poolKey
)
{
atomic
.
StoreInt64
(
&
entry
.
lastUsed
,
nowUnix
)
if
markInFlight
{
atomic
.
AddInt64
(
&
entry
.
inFlight
,
1
)
}
s
.
mu
.
Unlock
()
return
entry
,
nil
}
s
.
removeClientLocked
(
cacheKey
,
entry
)
}
// 超出缓存上限时尝试淘汰,无法淘汰则拒绝新建
if
enforceLimit
&&
s
.
maxUpstreamClients
()
>
0
{
s
.
evictIdleLocked
(
now
)
if
len
(
s
.
clients
)
>=
s
.
maxUpstreamClients
()
{
if
!
s
.
evictOldestIdleLocked
()
{
s
.
mu
.
Unlock
()
return
nil
,
errUpstreamClientLimitReached
}
}
}
// 缓存未命中或需要重建,创建新客户端
settings
:=
s
.
resolvePoolSettings
(
isolation
,
accountConcurrency
)
client
:=
&
http
.
Client
{
Transport
:
buildUpstreamTransport
(
settings
,
parsedProxy
)}
entry
:=
&
upstreamClientEntry
{
client
:
client
,
proxyKey
:
proxyKey
,
poolKey
:
poolKey
,
}
atomic
.
StoreInt64
(
&
entry
.
lastUsed
,
nowUnix
)
if
markInFlight
{
atomic
.
StoreInt64
(
&
entry
.
inFlight
,
1
)
}
s
.
clients
[
cacheKey
]
=
entry
// 执行淘汰策略:先淘汰空闲超时的,再淘汰超出数量限制的
s
.
evictIdleLocked
(
now
)
s
.
evictOverLimitLocked
()
s
.
mu
.
Unlock
()
return
entry
,
nil
}
// shouldReuseEntry 判断缓存条目是否可复用
// 若代理或连接池配置发生变化,则需要重建客户端
func
(
s
*
httpUpstreamService
)
shouldReuseEntry
(
entry
*
upstreamClientEntry
,
isolation
,
proxyKey
,
poolKey
string
)
bool
{
if
entry
==
nil
{
return
false
}
if
isolation
==
config
.
ConnectionPoolIsolationAccount
&&
entry
.
proxyKey
!=
proxyKey
{
return
false
}
if
entry
.
poolKey
!=
poolKey
{
return
false
}
return
true
}
// removeClientLocked 移除客户端(需持有锁)
// 从缓存中删除并关闭空闲连接
//
// 参数:
// - key: 缓存键
// - entry: 客户端条目
func
(
s
*
httpUpstreamService
)
removeClientLocked
(
key
string
,
entry
*
upstreamClientEntry
)
{
delete
(
s
.
clients
,
key
)
if
entry
!=
nil
&&
entry
.
client
!=
nil
{
// 关闭空闲连接,释放系统资源
// 注意:这不会中断活跃连接
entry
.
client
.
CloseIdleConnections
()
}
}
// evictIdleLocked 淘汰空闲超时的客户端(需持有锁)
// 遍历所有客户端,移除超过 TTL 且无活跃请求的条目
//
// 参数:
// - now: 当前时间
func
(
s
*
httpUpstreamService
)
evictIdleLocked
(
now
time
.
Time
)
{
ttl
:=
s
.
clientIdleTTL
()
if
ttl
<=
0
{
return
}
// 计算淘汰截止时间
cutoff
:=
now
.
Add
(
-
ttl
)
.
UnixNano
()
for
key
,
entry
:=
range
s
.
clients
{
// 跳过有活跃请求的客户端
if
atomic
.
LoadInt64
(
&
entry
.
inFlight
)
!=
0
{
continue
}
// 淘汰超时的空闲客户端
if
atomic
.
LoadInt64
(
&
entry
.
lastUsed
)
<=
cutoff
{
s
.
removeClientLocked
(
key
,
entry
)
}
}
}
// evictOldestIdleLocked 淘汰最久未使用且无活跃请求的客户端(需持有锁)
func
(
s
*
httpUpstreamService
)
evictOldestIdleLocked
()
bool
{
var
(
oldestKey
string
oldestEntry
*
upstreamClientEntry
oldestTime
int64
)
// 查找最久未使用且无活跃请求的客户端
for
key
,
entry
:=
range
s
.
clients
{
// 跳过有活跃请求的客户端
if
atomic
.
LoadInt64
(
&
entry
.
inFlight
)
!=
0
{
continue
}
lastUsed
:=
atomic
.
LoadInt64
(
&
entry
.
lastUsed
)
if
oldestEntry
==
nil
||
lastUsed
<
oldestTime
{
oldestKey
=
key
oldestEntry
=
entry
oldestTime
=
lastUsed
}
}
// 所有客户端都有活跃请求,无法淘汰
if
oldestEntry
==
nil
{
return
false
}
s
.
removeClientLocked
(
oldestKey
,
oldestEntry
)
return
true
}
// evictOverLimitLocked 淘汰超出数量限制的客户端(需持有锁)
// 使用 LRU 策略,优先淘汰最久未使用且无活跃请求的客户端
func
(
s
*
httpUpstreamService
)
evictOverLimitLocked
()
bool
{
maxClients
:=
s
.
maxUpstreamClients
()
if
maxClients
<=
0
{
return
false
}
evicted
:=
false
// 循环淘汰直到满足数量限制
for
len
(
s
.
clients
)
>
maxClients
{
if
!
s
.
evictOldestIdleLocked
()
{
return
evicted
}
evicted
=
true
}
return
evicted
}
// getIsolationMode 获取连接池隔离模式
// 从配置中读取,无效值回退到 account_proxy 模式
//
// 返回:
// - string: 隔离模式(proxy/account/account_proxy)
func
(
s
*
httpUpstreamService
)
getIsolationMode
()
string
{
if
s
.
cfg
==
nil
{
return
config
.
ConnectionPoolIsolationAccountProxy
}
mode
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
s
.
cfg
.
Gateway
.
ConnectionPoolIsolation
))
if
mode
==
""
{
return
config
.
ConnectionPoolIsolationAccountProxy
}
switch
mode
{
case
config
.
ConnectionPoolIsolationProxy
,
config
.
ConnectionPoolIsolationAccount
,
config
.
ConnectionPoolIsolationAccountProxy
:
return
mode
default
:
return
config
.
ConnectionPoolIsolationAccountProxy
}
}
// maxUpstreamClients 获取最大客户端缓存数量
// 从配置中读取,无效值使用默认值
func
(
s
*
httpUpstreamService
)
maxUpstreamClients
()
int
{
if
s
.
cfg
==
nil
{
return
defaultMaxUpstreamClients
}
if
s
.
cfg
.
Gateway
.
MaxUpstreamClients
>
0
{
return
s
.
cfg
.
Gateway
.
MaxUpstreamClients
}
return
defaultMaxUpstreamClients
}
// clientIdleTTL 获取客户端空闲回收阈值
// 从配置中读取,无效值使用默认值
func
(
s
*
httpUpstreamService
)
clientIdleTTL
()
time
.
Duration
{
if
s
.
cfg
==
nil
{
return
time
.
Duration
(
defaultClientIdleTTLSeconds
)
*
time
.
Second
}
if
s
.
cfg
.
Gateway
.
ClientIdleTTLSeconds
>
0
{
return
time
.
Duration
(
s
.
cfg
.
Gateway
.
ClientIdleTTLSeconds
)
*
time
.
Second
}
return
time
.
Duration
(
defaultClientIdleTTLSeconds
)
*
time
.
Second
}
// resolvePoolSettings 解析连接池配置
// 根据隔离策略和账户并发数动态调整连接池参数
//
// 参数:
// - isolation: 隔离模式
// - accountConcurrency: 账户并发限制
//
// 返回:
// - poolSettings: 连接池配置
//
// 说明:
// - 账户隔离模式下,连接池大小与账户并发数对应
// - 这确保了单账户不会占用过多连接资源
func
(
s
*
httpUpstreamService
)
resolvePoolSettings
(
isolation
string
,
accountConcurrency
int
)
poolSettings
{
settings
:=
defaultPoolSettings
(
s
.
cfg
)
// 账户隔离模式下,根据账户并发数调整连接池大小
if
(
isolation
==
config
.
ConnectionPoolIsolationAccount
||
isolation
==
config
.
ConnectionPoolIsolationAccountProxy
)
&&
accountConcurrency
>
0
{
settings
.
maxIdleConns
=
accountConcurrency
settings
.
maxIdleConnsPerHost
=
accountConcurrency
settings
.
maxConnsPerHost
=
accountConcurrency
}
client
:=
s
.
createProxyClient
(
proxyURL
)
return
client
.
Do
(
req
)
return
settings
}
func
(
s
*
httpUpstreamService
)
createProxyClient
(
proxyURL
string
)
*
http
.
Client
{
parsedURL
,
err
:=
url
.
Parse
(
proxyURL
)
// buildPoolKey 构建连接池配置键
// 用于检测配置变更,配置变更时需要重建客户端
//
// 参数:
// - isolation: 隔离模式
// - accountConcurrency: 账户并发限制
//
// 返回:
// - string: 配置键
func
(
s
*
httpUpstreamService
)
buildPoolKey
(
isolation
string
,
accountConcurrency
int
)
string
{
if
isolation
==
config
.
ConnectionPoolIsolationAccount
||
isolation
==
config
.
ConnectionPoolIsolationAccountProxy
{
if
accountConcurrency
>
0
{
return
fmt
.
Sprintf
(
"account:%d"
,
accountConcurrency
)
}
}
return
"default"
}
// buildCacheKey 构建客户端缓存键
// 根据隔离策略决定缓存键的组成
//
// 参数:
// - isolation: 隔离模式
// - proxyKey: 代理标识
// - accountID: 账户 ID
//
// 返回:
// - string: 缓存键
//
// 缓存键格式:
// - proxy 模式: "proxy:{proxyKey}"
// - account 模式: "account:{accountID}"
// - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}"
func
buildCacheKey
(
isolation
,
proxyKey
string
,
accountID
int64
)
string
{
switch
isolation
{
case
config
.
ConnectionPoolIsolationAccount
:
return
fmt
.
Sprintf
(
"account:%d"
,
accountID
)
case
config
.
ConnectionPoolIsolationAccountProxy
:
return
fmt
.
Sprintf
(
"account:%d|proxy:%s"
,
accountID
,
proxyKey
)
default
:
return
fmt
.
Sprintf
(
"proxy:%s"
,
proxyKey
)
}
}
// normalizeProxyURL 标准化代理 URL
// 处理空值和解析错误,返回标准化的键和解析后的 URL
//
// 参数:
// - raw: 原始代理 URL 字符串
//
// 返回:
// - string: 标准化的代理键(空或解析失败返回 "direct")
// - *url.URL: 解析后的 URL(空或解析失败返回 nil)
func
normalizeProxyURL
(
raw
string
)
(
string
,
*
url
.
URL
)
{
proxyURL
:=
strings
.
TrimSpace
(
raw
)
if
proxyURL
==
""
{
return
directProxyKey
,
nil
}
parsed
,
err
:=
url
.
Parse
(
proxyURL
)
if
err
!=
nil
{
return
s
.
defaultClient
return
directProxyKey
,
nil
}
parsed
.
Scheme
=
strings
.
ToLower
(
parsed
.
Scheme
)
parsed
.
Host
=
strings
.
ToLower
(
parsed
.
Host
)
parsed
.
Path
=
""
parsed
.
RawPath
=
""
parsed
.
RawQuery
=
""
parsed
.
Fragment
=
""
parsed
.
ForceQuery
=
false
if
hostname
:=
parsed
.
Hostname
();
hostname
!=
""
{
port
:=
parsed
.
Port
()
if
(
parsed
.
Scheme
==
"http"
&&
port
==
"80"
)
||
(
parsed
.
Scheme
==
"https"
&&
port
==
"443"
)
{
port
=
""
}
hostname
=
strings
.
ToLower
(
hostname
)
if
port
!=
""
{
parsed
.
Host
=
net
.
JoinHostPort
(
hostname
,
port
)
}
else
{
parsed
.
Host
=
hostname
}
}
return
parsed
.
String
(),
parsed
}
// defaultPoolSettings 获取默认连接池配置
// 从全局配置中读取,无效值使用常量默认值
//
// 参数:
// - cfg: 全局配置
//
// 返回:
// - poolSettings: 连接池配置
func
defaultPoolSettings
(
cfg
*
config
.
Config
)
poolSettings
{
maxIdleConns
:=
defaultMaxIdleConns
maxIdleConnsPerHost
:=
defaultMaxIdleConnsPerHost
maxConnsPerHost
:=
defaultMaxConnsPerHost
idleConnTimeout
:=
defaultIdleConnTimeout
responseHeaderTimeout
:=
defaultResponseHeaderTimeout
if
cfg
!=
nil
{
if
cfg
.
Gateway
.
MaxIdleConns
>
0
{
maxIdleConns
=
cfg
.
Gateway
.
MaxIdleConns
}
if
cfg
.
Gateway
.
MaxIdleConnsPerHost
>
0
{
maxIdleConnsPerHost
=
cfg
.
Gateway
.
MaxIdleConnsPerHost
}
if
cfg
.
Gateway
.
MaxConnsPerHost
>=
0
{
maxConnsPerHost
=
cfg
.
Gateway
.
MaxConnsPerHost
}
if
cfg
.
Gateway
.
IdleConnTimeoutSeconds
>
0
{
idleConnTimeout
=
time
.
Duration
(
cfg
.
Gateway
.
IdleConnTimeoutSeconds
)
*
time
.
Second
}
if
cfg
.
Gateway
.
ResponseHeaderTimeout
>
0
{
responseHeaderTimeout
=
time
.
Duration
(
cfg
.
Gateway
.
ResponseHeaderTimeout
)
*
time
.
Second
}
}
responseHeaderTimeout
:=
time
.
Duration
(
s
.
cfg
.
Gateway
.
ResponseHeaderTimeout
)
*
time
.
Second
if
responseHeaderTimeout
==
0
{
responseHeaderTimeout
=
300
*
time
.
Second
return
poolSettings
{
maxIdleConns
:
maxIdleConns
,
maxIdleConnsPerHost
:
maxIdleConnsPerHost
,
maxConnsPerHost
:
maxConnsPerHost
,
idleConnTimeout
:
idleConnTimeout
,
responseHeaderTimeout
:
responseHeaderTimeout
,
}
}
// buildUpstreamTransport 构建上游请求的 Transport
// 使用配置文件中的连接池参数,支持生产环境调优
//
// 参数:
// - settings: 连接池配置
// - proxyURL: 代理 URL(nil 表示直连)
//
// 返回:
// - *http.Transport: 配置好的 Transport 实例
//
// Transport 参数说明:
// - MaxIdleConns: 所有主机的最大空闲连接总数
// - MaxIdleConnsPerHost: 每主机最大空闲连接数(影响连接复用率)
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
func
buildUpstreamTransport
(
settings
poolSettings
,
proxyURL
*
url
.
URL
)
*
http
.
Transport
{
transport
:=
&
http
.
Transport
{
Proxy
:
http
.
ProxyURL
(
parsedURL
)
,
MaxIdleConns
:
100
,
Max
Idle
ConnsPerHost
:
10
,
IdleConnTimeout
:
90
*
time
.
Second
,
ResponseHeaderTimeout
:
responseHeaderTimeout
,
MaxIdleConns
:
settings
.
maxIdleConns
,
MaxIdleConns
PerHost
:
settings
.
maxIdleConnsPerHost
,
MaxConnsPerHost
:
settings
.
maxConnsPerHost
,
IdleConnTimeout
:
settings
.
idleConnTimeout
,
ResponseHeaderTimeout
:
settings
.
responseHeaderTimeout
,
}
if
proxyURL
!=
nil
{
transport
.
Proxy
=
http
.
ProxyURL
(
proxyURL
)
}
return
transport
}
return
&
http
.
Client
{
Transport
:
transport
}
// trackedBody wraps a response body so that a callback fires exactly once when
// the body is closed; used to release in-flight request accounting.
type trackedBody struct {
	io.ReadCloser // the original response body

	once    sync.Once // guarantees the callback runs at most once
	onClose func()    // invoked on the first Close
}

// Close closes the underlying body and then runs the callback.
// sync.Once makes repeated Close calls safe.
func (b *trackedBody) Close() error {
	err := b.ReadCloser.Close()
	if b.onClose != nil {
		b.once.Do(b.onClose)
	}
	return err
}

// wrapTrackedBody wraps body so that onClose runs when it is closed; used to
// decrement the inFlight counter once the caller is done with the response.
//
// Parameters:
//   - body: the original response body (a nil body is returned unchanged)
//   - onClose: callback to run on close
//
// Returns:
//   - io.ReadCloser: the wrapped body
func wrapTrackedBody(body io.ReadCloser, onClose func()) io.ReadCloser {
	if body == nil {
		return body
	}
	return &trackedBody{ReadCloser: body, onClose: onClose}
}
backend/internal/repository/http_upstream_benchmark_test.go
0 → 100644
View file @
7331220e
package
repository
import
(
"net/http"
"net/url"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// httpClientSink 用于防止编译器优化掉基准测试中的赋值操作
// 这是 Go 基准测试的常见模式,确保测试结果准确
var
httpClientSink
*
http
.
Client
// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销
//
// 测试目的:
// - 验证连接池复用相比每次新建的性能提升
// - 量化内存分配差异
//
// 预期结果:
// - "复用" 子测试应显著快于 "新建"
// - "复用" 子测试应零内存分配
func
BenchmarkHTTPUpstreamProxyClient
(
b
*
testing
.
B
)
{
// 创建测试配置
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
ResponseHeaderTimeout
:
300
},
}
upstream
:=
NewHTTPUpstream
(
cfg
)
svc
,
ok
:=
upstream
.
(
*
httpUpstreamService
)
if
!
ok
{
b
.
Fatalf
(
"类型断言失败,无法获取 httpUpstreamService"
)
}
proxyURL
:=
"http://127.0.0.1:8080"
b
.
ReportAllocs
()
// 报告内存分配统计
// 子测试:每次新建客户端
// 模拟未优化前的行为,每次请求都创建新的 http.Client
b
.
Run
(
"新建"
,
func
(
b
*
testing
.
B
)
{
parsedProxy
,
err
:=
url
.
Parse
(
proxyURL
)
if
err
!=
nil
{
b
.
Fatalf
(
"解析代理地址失败: %v"
,
err
)
}
settings
:=
defaultPoolSettings
(
cfg
)
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
// 每次迭代都创建新客户端,包含 Transport 分配
httpClientSink
=
&
http
.
Client
{
Transport
:
buildUpstreamTransport
(
settings
,
parsedProxy
),
}
}
})
// 子测试:复用已缓存的客户端
// 模拟优化后的行为,从缓存获取客户端
b
.
Run
(
"复用"
,
func
(
b
*
testing
.
B
)
{
// 预热:确保客户端已缓存
entry
:=
svc
.
getOrCreateClient
(
proxyURL
,
1
,
1
)
client
:=
entry
.
client
b
.
ResetTimer
()
// 重置计时器,排除预热时间
for
i
:=
0
;
i
<
b
.
N
;
i
++
{
// 直接使用缓存的客户端,无内存分配
httpClientSink
=
client
}
})
}
backend/internal/repository/http_upstream_test.go
View file @
7331220e
...
...
@@ -4,6 +4,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
...
...
@@ -12,45 +13,86 @@ import (
"github.com/stretchr/testify/suite"
)
// HTTPUpstreamSuite HTTP 上游服务测试套件
// 使用 testify/suite 组织测试,支持 SetupTest 初始化
type
HTTPUpstreamSuite
struct
{
suite
.
Suite
cfg
*
config
.
Config
cfg
*
config
.
Config
// 测试用配置
}
// SetupTest 每个测试用例执行前的初始化
// 创建空配置,各测试用例可按需覆盖
func
(
s
*
HTTPUpstreamSuite
)
SetupTest
()
{
s
.
cfg
=
&
config
.
Config
{}
}
func
(
s
*
HTTPUpstreamSuite
)
TestDefaultResponseHeaderTimeout
()
{
// newService 创建测试用的 httpUpstreamService 实例
// 返回具体类型以便访问内部状态进行断言
func
(
s
*
HTTPUpstreamSuite
)
newService
()
*
httpUpstreamService
{
up
:=
NewHTTPUpstream
(
s
.
cfg
)
svc
,
ok
:=
up
.
(
*
httpUpstreamService
)
require
.
True
(
s
.
T
(),
ok
,
"expected *httpUpstreamService"
)
transport
,
ok
:=
svc
.
defaultClient
.
Transport
.
(
*
http
.
Transport
)
return
svc
}
// TestDefaultResponseHeaderTimeout 测试默认响应头超时配置
// 验证未配置时使用 300 秒默认值
func
(
s
*
HTTPUpstreamSuite
)
TestDefaultResponseHeaderTimeout
()
{
svc
:=
s
.
newService
()
entry
:=
svc
.
getOrCreateClient
(
""
,
0
,
0
)
transport
,
ok
:=
entry
.
client
.
Transport
.
(
*
http
.
Transport
)
require
.
True
(
s
.
T
(),
ok
,
"expected *http.Transport"
)
require
.
Equal
(
s
.
T
(),
300
*
time
.
Second
,
transport
.
ResponseHeaderTimeout
,
"ResponseHeaderTimeout mismatch"
)
}
// TestCustomResponseHeaderTimeout 测试自定义响应头超时配置
// 验证配置值能正确应用到 Transport
func
(
s
*
HTTPUpstreamSuite
)
TestCustomResponseHeaderTimeout
()
{
s
.
cfg
.
Gateway
=
config
.
GatewayConfig
{
ResponseHeaderTimeout
:
7
}
up
:=
NewHTTPUpstream
(
s
.
cfg
)
svc
,
ok
:=
up
.
(
*
httpUpstreamService
)
require
.
True
(
s
.
T
(),
ok
,
"expected *httpUpstreamService"
)
transport
,
ok
:=
svc
.
defaultClient
.
Transport
.
(
*
http
.
Transport
)
svc
:=
s
.
newService
()
entry
:=
svc
.
getOrCreateClient
(
""
,
0
,
0
)
transport
,
ok
:=
entry
.
client
.
Transport
.
(
*
http
.
Transport
)
require
.
True
(
s
.
T
(),
ok
,
"expected *http.Transport"
)
require
.
Equal
(
s
.
T
(),
7
*
time
.
Second
,
transport
.
ResponseHeaderTimeout
,
"ResponseHeaderTimeout mismatch"
)
}
func
(
s
*
HTTPUpstreamSuite
)
TestCreateProxyClient_InvalidURLFallsBackToDefault
()
{
s
.
cfg
.
Gateway
=
config
.
GatewayConfig
{
ResponseHeaderTimeout
:
5
}
up
:=
NewHTTPUpstream
(
s
.
cfg
)
svc
,
ok
:=
up
.
(
*
httpUpstreamService
)
require
.
True
(
s
.
T
(),
ok
,
"expected *httpUpstreamService"
)
// TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退
// 验证解析失败时回退到直连模式
func
(
s
*
HTTPUpstreamSuite
)
TestGetOrCreateClient_InvalidURLFallsBackToDirect
()
{
svc
:=
s
.
newService
()
entry
:=
svc
.
getOrCreateClient
(
"://bad-proxy-url"
,
1
,
1
)
require
.
Equal
(
s
.
T
(),
directProxyKey
,
entry
.
proxyKey
,
"expected direct proxy fallback"
)
}
// TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化
// 验证等价地址能够映射到同一缓存键
func
(
s
*
HTTPUpstreamSuite
)
TestNormalizeProxyURL_Canonicalizes
()
{
key1
,
_
:=
normalizeProxyURL
(
"http://proxy.local:8080"
)
key2
,
_
:=
normalizeProxyURL
(
"http://proxy.local:8080/"
)
require
.
Equal
(
s
.
T
(),
key1
,
key2
,
"expected normalized proxy keys to match"
)
}
// TestAcquireClient_OverLimitReturnsError 测试连接池缓存上限保护
// 验证超限且无可淘汰条目时返回错误
func
(
s
*
HTTPUpstreamSuite
)
TestAcquireClient_OverLimitReturnsError
()
{
s
.
cfg
.
Gateway
=
config
.
GatewayConfig
{
ConnectionPoolIsolation
:
config
.
ConnectionPoolIsolationAccountProxy
,
MaxUpstreamClients
:
1
,
}
svc
:=
s
.
newService
()
entry1
,
err
:=
svc
.
acquireClient
(
"http://proxy-a:8080"
,
1
,
1
)
require
.
NoError
(
s
.
T
(),
err
,
"expected first acquire to succeed"
)
require
.
NotNil
(
s
.
T
(),
entry1
,
"expected entry"
)
got
:=
svc
.
createProxyClient
(
"://bad-proxy-url"
)
require
.
Equal
(
s
.
T
(),
svc
.
defaultClient
,
got
,
"expected defaultClient fallback"
)
entry2
,
err
:=
svc
.
acquireClient
(
"http://proxy-b:8080"
,
2
,
1
)
require
.
Error
(
s
.
T
(),
err
,
"expected error when cache limit reached"
)
require
.
Nil
(
s
.
T
(),
entry2
,
"expected nil entry when cache limit reached"
)
}
// TestDo_WithoutProxy_GoesDirect 测试无代理时直连
// 验证空代理 URL 时请求直接发送到目标服务器
func
(
s
*
HTTPUpstreamSuite
)
TestDo_WithoutProxy_GoesDirect
()
{
// 创建模拟上游服务器
upstream
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
_
,
_
=
io
.
WriteString
(
w
,
"direct"
)
}))
...
...
@@ -60,17 +102,21 @@ func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
req
,
err
:=
http
.
NewRequest
(
http
.
MethodGet
,
upstream
.
URL
+
"/x"
,
nil
)
require
.
NoError
(
s
.
T
(),
err
,
"NewRequest"
)
resp
,
err
:=
up
.
Do
(
req
,
""
)
resp
,
err
:=
up
.
Do
(
req
,
""
,
1
,
1
)
require
.
NoError
(
s
.
T
(),
err
,
"Do"
)
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
b
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
require
.
Equal
(
s
.
T
(),
"direct"
,
string
(
b
),
"unexpected body"
)
}
// TestDo_WithHTTPProxy_UsesProxy 测试 HTTP 代理功能
// 验证请求通过代理服务器转发,使用绝对 URI 格式
func
(
s
*
HTTPUpstreamSuite
)
TestDo_WithHTTPProxy_UsesProxy
()
{
// 用于接收代理请求的通道
seen
:=
make
(
chan
string
,
1
)
// 创建模拟代理服务器
proxySrv
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
seen
<-
r
.
RequestURI
seen
<-
r
.
RequestURI
// 记录请求 URI
_
,
_
=
io
.
WriteString
(
w
,
"proxied"
)
}))
s
.
T
()
.
Cleanup
(
proxySrv
.
Close
)
...
...
@@ -78,14 +124,16 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
s
.
cfg
.
Gateway
=
config
.
GatewayConfig
{
ResponseHeaderTimeout
:
1
}
up
:=
NewHTTPUpstream
(
s
.
cfg
)
// 发送请求到外部地址,应通过代理
req
,
err
:=
http
.
NewRequest
(
http
.
MethodGet
,
"http://example.com/test"
,
nil
)
require
.
NoError
(
s
.
T
(),
err
,
"NewRequest"
)
resp
,
err
:=
up
.
Do
(
req
,
proxySrv
.
URL
)
resp
,
err
:=
up
.
Do
(
req
,
proxySrv
.
URL
,
1
,
1
)
require
.
NoError
(
s
.
T
(),
err
,
"Do"
)
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
b
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
require
.
Equal
(
s
.
T
(),
"proxied"
,
string
(
b
),
"unexpected body"
)
// 验证代理收到的是绝对 URI 格式(HTTP 代理规范要求)
select
{
case
uri
:=
<-
seen
:
require
.
Equal
(
s
.
T
(),
"http://example.com/test"
,
uri
,
"expected absolute-form request URI"
)
...
...
@@ -94,6 +142,8 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
}
}
// TestDo_EmptyProxy_UsesDirect 测试空代理字符串
// 验证空字符串代理等同于直连
func
(
s
*
HTTPUpstreamSuite
)
TestDo_EmptyProxy_UsesDirect
()
{
upstream
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
_
,
_
=
io
.
WriteString
(
w
,
"direct-empty"
)
...
...
@@ -103,13 +153,134 @@ func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
up
:=
NewHTTPUpstream
(
s
.
cfg
)
req
,
err
:=
http
.
NewRequest
(
http
.
MethodGet
,
upstream
.
URL
+
"/y"
,
nil
)
require
.
NoError
(
s
.
T
(),
err
,
"NewRequest"
)
resp
,
err
:=
up
.
Do
(
req
,
""
)
resp
,
err
:=
up
.
Do
(
req
,
""
,
1
,
1
)
require
.
NoError
(
s
.
T
(),
err
,
"Do with empty proxy"
)
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
b
,
_
:=
io
.
ReadAll
(
resp
.
Body
)
require
.
Equal
(
s
.
T
(),
"direct-empty"
,
string
(
b
))
}
// TestAccountIsolation_DifferentAccounts 测试账户隔离模式
// 验证不同账户使用独立的连接池
func
(
s
*
HTTPUpstreamSuite
)
TestAccountIsolation_DifferentAccounts
()
{
s
.
cfg
.
Gateway
=
config
.
GatewayConfig
{
ConnectionPoolIsolation
:
config
.
ConnectionPoolIsolationAccount
}
svc
:=
s
.
newService
()
// 同一代理,不同账户
entry1
:=
svc
.
getOrCreateClient
(
"http://proxy.local:8080"
,
1
,
3
)
entry2
:=
svc
.
getOrCreateClient
(
"http://proxy.local:8080"
,
2
,
3
)
require
.
NotSame
(
s
.
T
(),
entry1
,
entry2
,
"不同账号不应共享连接池"
)
require
.
Equal
(
s
.
T
(),
2
,
len
(
svc
.
clients
),
"账号隔离应缓存两个客户端"
)
}
// TestAccountProxyIsolation_DifferentProxy 测试账户+代理组合隔离模式
// 验证同一账户使用不同代理时创建独立连接池
func
(
s
*
HTTPUpstreamSuite
)
TestAccountProxyIsolation_DifferentProxy
()
{
s
.
cfg
.
Gateway
=
config
.
GatewayConfig
{
ConnectionPoolIsolation
:
config
.
ConnectionPoolIsolationAccountProxy
}
svc
:=
s
.
newService
()
// 同一账户,不同代理
entry1
:=
svc
.
getOrCreateClient
(
"http://proxy-a:8080"
,
1
,
3
)
entry2
:=
svc
.
getOrCreateClient
(
"http://proxy-b:8080"
,
1
,
3
)
require
.
NotSame
(
s
.
T
(),
entry1
,
entry2
,
"账号+代理隔离应区分不同代理"
)
require
.
Equal
(
s
.
T
(),
2
,
len
(
svc
.
clients
),
"账号+代理隔离应缓存两个客户端"
)
}
// TestAccountModeProxyChangeClearsPool 测试账户模式下代理变更
// 验证账户切换代理时清理旧连接池,避免复用错误代理
func
(
s
*
HTTPUpstreamSuite
)
TestAccountModeProxyChangeClearsPool
()
{
s
.
cfg
.
Gateway
=
config
.
GatewayConfig
{
ConnectionPoolIsolation
:
config
.
ConnectionPoolIsolationAccount
}
svc
:=
s
.
newService
()
// 同一账户,先后使用不同代理
entry1
:=
svc
.
getOrCreateClient
(
"http://proxy-a:8080"
,
1
,
3
)
entry2
:=
svc
.
getOrCreateClient
(
"http://proxy-b:8080"
,
1
,
3
)
require
.
NotSame
(
s
.
T
(),
entry1
,
entry2
,
"账号切换代理应创建新连接池"
)
require
.
Equal
(
s
.
T
(),
1
,
len
(
svc
.
clients
),
"账号模式下应仅保留一个连接池"
)
require
.
False
(
s
.
T
(),
hasEntry
(
svc
,
entry1
),
"旧连接池应被清理"
)
}
// TestAccountConcurrencyOverridesPoolSettings 测试账户并发数覆盖连接池配置
// 验证账户隔离模式下,连接池大小与账户并发数对应
func
(
s
*
HTTPUpstreamSuite
)
TestAccountConcurrencyOverridesPoolSettings
()
{
s
.
cfg
.
Gateway
=
config
.
GatewayConfig
{
ConnectionPoolIsolation
:
config
.
ConnectionPoolIsolationAccount
}
svc
:=
s
.
newService
()
// 账户并发数为 12
entry
:=
svc
.
getOrCreateClient
(
""
,
1
,
12
)
transport
,
ok
:=
entry
.
client
.
Transport
.
(
*
http
.
Transport
)
require
.
True
(
s
.
T
(),
ok
,
"expected *http.Transport"
)
// 连接池参数应与并发数一致
require
.
Equal
(
s
.
T
(),
12
,
transport
.
MaxConnsPerHost
,
"MaxConnsPerHost mismatch"
)
require
.
Equal
(
s
.
T
(),
12
,
transport
.
MaxIdleConns
,
"MaxIdleConns mismatch"
)
require
.
Equal
(
s
.
T
(),
12
,
transport
.
MaxIdleConnsPerHost
,
"MaxIdleConnsPerHost mismatch"
)
}
// TestAccountConcurrencyFallbackToDefault 测试账户并发数为 0 时回退到默认配置
// 验证未指定并发数时使用全局配置值
func
(
s
*
HTTPUpstreamSuite
)
TestAccountConcurrencyFallbackToDefault
()
{
s
.
cfg
.
Gateway
=
config
.
GatewayConfig
{
ConnectionPoolIsolation
:
config
.
ConnectionPoolIsolationAccount
,
MaxIdleConns
:
77
,
MaxIdleConnsPerHost
:
55
,
MaxConnsPerHost
:
66
,
}
svc
:=
s
.
newService
()
// 账户并发数为 0,应使用全局配置
entry
:=
svc
.
getOrCreateClient
(
""
,
1
,
0
)
transport
,
ok
:=
entry
.
client
.
Transport
.
(
*
http
.
Transport
)
require
.
True
(
s
.
T
(),
ok
,
"expected *http.Transport"
)
require
.
Equal
(
s
.
T
(),
66
,
transport
.
MaxConnsPerHost
,
"MaxConnsPerHost fallback mismatch"
)
require
.
Equal
(
s
.
T
(),
77
,
transport
.
MaxIdleConns
,
"MaxIdleConns fallback mismatch"
)
require
.
Equal
(
s
.
T
(),
55
,
transport
.
MaxIdleConnsPerHost
,
"MaxIdleConnsPerHost fallback mismatch"
)
}
// TestEvictOverLimitRemovesOldestIdle 测试超出数量限制时的 LRU 淘汰
// 验证优先淘汰最久未使用的空闲客户端
func
(
s
*
HTTPUpstreamSuite
)
TestEvictOverLimitRemovesOldestIdle
()
{
s
.
cfg
.
Gateway
=
config
.
GatewayConfig
{
ConnectionPoolIsolation
:
config
.
ConnectionPoolIsolationAccountProxy
,
MaxUpstreamClients
:
2
,
// 最多缓存 2 个客户端
}
svc
:=
s
.
newService
()
// 创建两个客户端,设置不同的最后使用时间
entry1
:=
svc
.
getOrCreateClient
(
"http://proxy-a:8080"
,
1
,
1
)
entry2
:=
svc
.
getOrCreateClient
(
"http://proxy-b:8080"
,
2
,
1
)
atomic
.
StoreInt64
(
&
entry1
.
lastUsed
,
time
.
Now
()
.
Add
(
-
2
*
time
.
Hour
)
.
UnixNano
())
// 最久
atomic
.
StoreInt64
(
&
entry2
.
lastUsed
,
time
.
Now
()
.
Add
(
-
time
.
Hour
)
.
UnixNano
())
// 创建第三个客户端,触发淘汰
_
=
svc
.
getOrCreateClient
(
"http://proxy-c:8080"
,
3
,
1
)
require
.
LessOrEqual
(
s
.
T
(),
len
(
svc
.
clients
),
2
,
"应保持在缓存上限内"
)
require
.
False
(
s
.
T
(),
hasEntry
(
svc
,
entry1
),
"最久未使用的连接池应被清理"
)
}
// TestIdleTTLDoesNotEvictActive 测试活跃请求保护
// 验证有进行中请求的客户端不会被空闲超时淘汰
func
(
s
*
HTTPUpstreamSuite
)
TestIdleTTLDoesNotEvictActive
()
{
s
.
cfg
.
Gateway
=
config
.
GatewayConfig
{
ConnectionPoolIsolation
:
config
.
ConnectionPoolIsolationAccount
,
ClientIdleTTLSeconds
:
1
,
// 1 秒空闲超时
}
svc
:=
s
.
newService
()
entry1
:=
svc
.
getOrCreateClient
(
""
,
1
,
1
)
// 设置为很久之前使用,但有活跃请求
atomic
.
StoreInt64
(
&
entry1
.
lastUsed
,
time
.
Now
()
.
Add
(
-
2
*
time
.
Minute
)
.
UnixNano
())
atomic
.
StoreInt64
(
&
entry1
.
inFlight
,
1
)
// 模拟有活跃请求
// 创建新客户端,触发淘汰检查
_
=
svc
.
getOrCreateClient
(
""
,
2
,
1
)
require
.
True
(
s
.
T
(),
hasEntry
(
svc
,
entry1
),
"有活跃请求时不应回收"
)
}
// TestHTTPUpstreamSuite 运行测试套件
func
TestHTTPUpstreamSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
HTTPUpstreamSuite
))
}
// hasEntry 检查客户端是否存在于缓存中
// 辅助函数,用于验证淘汰逻辑
func
hasEntry
(
svc
*
httpUpstreamService
,
target
*
upstreamClientEntry
)
bool
{
for
_
,
entry
:=
range
svc
.
clients
{
if
entry
==
target
{
return
true
}
}
return
false
}
backend/internal/repository/integration_harness_test.go
View file @
7331220e
...
...
@@ -17,7 +17,6 @@ import (
dbent
"github.com/Wei-Shaw/sub2api/ent"
_
"github.com/Wei-Shaw/sub2api/ent/runtime"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
...
...
@@ -97,7 +96,7 @@ func TestMain(m *testing.M) {
log
.
Printf
(
"failed to open sql db: %v"
,
err
)
os
.
Exit
(
1
)
}
if
err
:=
infrastructure
.
ApplyMigrations
(
ctx
,
integrationDB
);
err
!=
nil
{
if
err
:=
ApplyMigrations
(
ctx
,
integrationDB
);
err
!=
nil
{
log
.
Printf
(
"failed to apply db migrations: %v"
,
err
)
os
.
Exit
(
1
)
}
...
...
@@ -330,7 +329,8 @@ func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
switch
strings
.
ToLower
(
cmd
.
Name
())
{
case
"get"
,
"set"
,
"setnx"
,
"setex"
,
"psetex"
,
"incr"
,
"decr"
,
"incrby"
,
"expire"
,
"pexpire"
,
"ttl"
,
"pttl"
,
"hgetall"
,
"hget"
,
"hset"
,
"hdel"
,
"hincrbyfloat"
,
"exists"
:
"hgetall"
,
"hget"
,
"hset"
,
"hdel"
,
"hincrbyfloat"
,
"exists"
,
"zadd"
,
"zcard"
,
"zrange"
,
"zrangebyscore"
,
"zrem"
,
"zremrangebyscore"
,
"zrevrange"
,
"zrevrangebyscore"
,
"zscore"
:
prefixOne
(
1
)
case
"del"
,
"unlink"
:
for
i
:=
1
;
i
<
len
(
args
);
i
++
{
...
...
backend/internal/
infrastructure
/migrations_runner.go
→
backend/internal/
repository
/migrations_runner.go
View file @
7331220e
package
infrastructure
package
repository
import
(
"context"
...
...
@@ -127,7 +127,15 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if
existing
!=
checksum
{
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return
fmt
.
Errorf
(
"migration %s checksum mismatch (db=%s file=%s)"
,
name
,
existing
,
checksum
)
return
fmt
.
Errorf
(
"migration %s checksum mismatch (db=%s file=%s)
\n
"
+
"This means the migration file was modified after being applied to the database.
\n
"
+
"Solutions:
\n
"
+
" 1. Revert to original: git log --oneline -- migrations/%s && git checkout <commit> -- migrations/%s
\n
"
+
" 2. For new changes, create a new migration file instead of modifying existing ones
\n
"
+
"Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments"
,
name
,
existing
,
checksum
,
name
,
name
,
)
}
continue
// 迁移已应用且校验和匹配,跳过
}
...
...
backend/internal/repository/migrations_schema_integration_test.go
View file @
7331220e
...
...
@@ -7,7 +7,6 @@ import (
"database/sql"
"testing"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/stretchr/testify/require"
)
...
...
@@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
tx
:=
testTx
(
t
)
// Re-apply migrations to verify idempotency (no errors, no duplicate rows).
require
.
NoError
(
t
,
infrastructure
.
ApplyMigrations
(
context
.
Background
(),
integrationDB
))
require
.
NoError
(
t
,
ApplyMigrations
(
context
.
Background
(),
integrationDB
))
// schema_migrations should have at least the current migration set.
var
applied
int
...
...
@@ -53,6 +52,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
var
uagRegclass
sql
.
NullString
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
context
.
Background
(),
"SELECT to_regclass('public.user_allowed_groups')"
)
.
Scan
(
&
uagRegclass
))
require
.
True
(
t
,
uagRegclass
.
Valid
,
"expected user_allowed_groups table to exist"
)
// user_subscriptions: deleted_at for soft delete support (migration 012)
requireColumn
(
t
,
tx
,
"user_subscriptions"
,
"deleted_at"
,
"timestamp with time zone"
,
0
,
true
)
// orphan_allowed_groups_audit table should exist (migration 013)
var
orphanAuditRegclass
sql
.
NullString
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
context
.
Background
(),
"SELECT to_regclass('public.orphan_allowed_groups_audit')"
)
.
Scan
(
&
orphanAuditRegclass
))
require
.
True
(
t
,
orphanAuditRegclass
.
Valid
,
"expected orphan_allowed_groups_audit table to exist"
)
// account_groups: created_at should be timestamptz
requireColumn
(
t
,
tx
,
"account_groups"
,
"created_at"
,
"timestamp with time zone"
,
0
,
false
)
// user_allowed_groups: created_at should be timestamptz
requireColumn
(
t
,
tx
,
"user_allowed_groups"
,
"created_at"
,
"timestamp with time zone"
,
0
,
false
)
}
func
requireColumn
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
table
,
column
,
dataType
string
,
maxLen
int
,
nullable
bool
)
{
...
...
backend/internal/repository/openai_oauth_service.go
View file @
7331220e
...
...
@@ -82,12 +82,8 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
}
func
createOpenAIReqClient
(
proxyURL
string
)
*
req
.
Client
{
client
:=
req
.
C
()
.
SetTimeout
(
60
*
time
.
Second
)
if
proxyURL
!=
""
{
client
.
SetProxyURL
(
proxyURL
)
}
return
client
return
getSharedReqClient
(
reqClientOptions
{
ProxyURL
:
proxyURL
,
Timeout
:
60
*
time
.
Second
,
})
}
backend/internal/repository/pricing_service.go
View file @
7331220e
...
...
@@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
...
...
@@ -16,10 +17,14 @@ type pricingRemoteClient struct {
}
func
NewPricingRemoteClient
()
service
.
PricingRemoteClient
{
sharedClient
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
30
*
time
.
Second
,
})
if
err
!=
nil
{
sharedClient
=
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
}
}
return
&
pricingRemoteClient
{
httpClient
:
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
,
},
httpClient
:
sharedClient
,
}
}
...
...
backend/internal/repository/proxy_probe_service.go
View file @
7331220e
...
...
@@ -2,18 +2,14 @@ package repository
import
(
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
"golang.org/x/net/proxy"
)
func
NewProxyExitInfoProber
()
service
.
ProxyExitInfoProber
{
...
...
@@ -27,14 +23,14 @@ type proxyProbeService struct {
}
func
(
s
*
proxyProbeService
)
ProbeProxy
(
ctx
context
.
Context
,
proxyURL
string
)
(
*
service
.
ProxyExitInfo
,
int64
,
error
)
{
transport
,
err
:=
createProxyTransport
(
proxyURL
)
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
ProxyURL
:
proxyURL
,
Timeout
:
15
*
time
.
Second
,
InsecureSkipVerify
:
true
,
ProxyStrict
:
true
,
})
if
err
!=
nil
{
return
nil
,
0
,
fmt
.
Errorf
(
"failed to create proxy transport: %w"
,
err
)
}
client
:=
&
http
.
Client
{
Transport
:
transport
,
Timeout
:
15
*
time
.
Second
,
return
nil
,
0
,
fmt
.
Errorf
(
"failed to create proxy client: %w"
,
err
)
}
startTime
:=
time
.
Now
()
...
...
@@ -78,31 +74,3 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
Country
:
ipInfo
.
Country
,
},
latencyMs
,
nil
}
func
createProxyTransport
(
proxyURL
string
)
(
*
http
.
Transport
,
error
)
{
parsedURL
,
err
:=
url
.
Parse
(
proxyURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid proxy URL: %w"
,
err
)
}
transport
:=
&
http
.
Transport
{
TLSClientConfig
:
&
tls
.
Config
{
InsecureSkipVerify
:
true
},
}
switch
parsedURL
.
Scheme
{
case
"http"
,
"https"
:
transport
.
Proxy
=
http
.
ProxyURL
(
parsedURL
)
case
"socks5"
:
dialer
,
err
:=
proxy
.
FromURL
(
parsedURL
,
proxy
.
Direct
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"failed to create socks5 dialer: %w"
,
err
)
}
transport
.
DialContext
=
func
(
ctx
context
.
Context
,
network
,
addr
string
)
(
net
.
Conn
,
error
)
{
return
dialer
.
Dial
(
network
,
addr
)
}
default
:
return
nil
,
fmt
.
Errorf
(
"unsupported proxy protocol: %s"
,
parsedURL
.
Scheme
)
}
return
transport
,
nil
}
Prev
1
2
3
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment