Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a14dfb76
"vscode:/vscode.git/clone" did not exist on "73089bbfdf620ad9cbb02dcc8dc925ba81efe2d7"
Commit
a14dfb76
authored
Feb 07, 2026
by
yangjianbo
Browse files
Merge branch 'dev-release'
parents
f3605ddc
2588fa6a
Changes
62
Hide whitespace changes
Inline
Side-by-side
backend/cmd/server/VERSION
View file @
a14dfb76
0.1.70
0.1.70
.2
backend/cmd/server/wire_gen.go
View file @
a14dfb76
...
...
@@ -65,8 +65,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthCacheInvalidator
:=
service
.
ProvideAPIKeyAuthCacheInvalidator
(
apiKeyService
)
promoService
:=
service
.
NewPromoService
(
promoCodeRepository
,
userRepository
,
billingCacheService
,
client
,
apiKeyAuthCacheInvalidator
)
authService
:=
service
.
NewAuthService
(
userRepository
,
redeemCodeRepository
,
refreshTokenCache
,
configConfig
,
settingService
,
emailService
,
turnstileService
,
emailQueueService
,
promoService
)
userService
:=
service
.
NewUserService
(
userRepository
,
apiKeyAuthCacheInvalidator
)
subscriptionService
:=
service
.
NewSubscriptionService
(
groupRepository
,
userSubscriptionRepository
,
billingCacheService
)
userService
:=
service
.
NewUserService
(
userRepository
,
apiKeyAuthCacheInvalidator
,
billingCache
)
subscriptionService
:=
service
.
NewSubscriptionService
(
groupRepository
,
userSubscriptionRepository
,
billingCacheService
,
configConfig
)
redeemCache
:=
repository
.
NewRedeemCache
(
redisClient
)
redeemService
:=
service
.
NewRedeemService
(
redeemCodeRepository
,
userRepository
,
subscriptionService
,
redeemCache
,
billingCacheService
,
client
,
apiKeyAuthCacheInvalidator
)
secretEncryptor
,
err
:=
repository
.
NewAESEncryptor
(
configConfig
)
...
...
backend/internal/config/config.go
View file @
a14dfb76
...
...
@@ -38,31 +38,32 @@ const (
)
type
Config
struct
{
Server
ServerConfig
`mapstructure:"server"`
CORS
CORSConfig
`mapstructure:"cors"`
Security
SecurityConfig
`mapstructure:"security"`
Billing
BillingConfig
`mapstructure:"billing"`
Turnstile
TurnstileConfig
`mapstructure:"turnstile"`
Database
DatabaseConfig
`mapstructure:"database"`
Redis
RedisConfig
`mapstructure:"redis"`
Ops
OpsConfig
`mapstructure:"ops"`
JWT
JWTConfig
`mapstructure:"jwt"`
Totp
TotpConfig
`mapstructure:"totp"`
LinuxDo
LinuxDoConnectConfig
`mapstructure:"linuxdo_connect"`
Default
DefaultConfig
`mapstructure:"default"`
RateLimit
RateLimitConfig
`mapstructure:"rate_limit"`
Pricing
PricingConfig
`mapstructure:"pricing"`
Gateway
GatewayConfig
`mapstructure:"gateway"`
APIKeyAuth
APIKeyAuthCacheConfig
`mapstructure:"api_key_auth_cache"`
Dashboard
DashboardCacheConfig
`mapstructure:"dashboard_cache"`
DashboardAgg
DashboardAggregationConfig
`mapstructure:"dashboard_aggregation"`
UsageCleanup
UsageCleanupConfig
`mapstructure:"usage_cleanup"`
Concurrency
ConcurrencyConfig
`mapstructure:"concurrency"`
TokenRefresh
TokenRefreshConfig
`mapstructure:"token_refresh"`
RunMode
string
`mapstructure:"run_mode" yaml:"run_mode"`
Timezone
string
`mapstructure:"timezone"`
// e.g. "Asia/Shanghai", "UTC"
Gemini
GeminiConfig
`mapstructure:"gemini"`
Update
UpdateConfig
`mapstructure:"update"`
Server
ServerConfig
`mapstructure:"server"`
CORS
CORSConfig
`mapstructure:"cors"`
Security
SecurityConfig
`mapstructure:"security"`
Billing
BillingConfig
`mapstructure:"billing"`
Turnstile
TurnstileConfig
`mapstructure:"turnstile"`
Database
DatabaseConfig
`mapstructure:"database"`
Redis
RedisConfig
`mapstructure:"redis"`
Ops
OpsConfig
`mapstructure:"ops"`
JWT
JWTConfig
`mapstructure:"jwt"`
Totp
TotpConfig
`mapstructure:"totp"`
LinuxDo
LinuxDoConnectConfig
`mapstructure:"linuxdo_connect"`
Default
DefaultConfig
`mapstructure:"default"`
RateLimit
RateLimitConfig
`mapstructure:"rate_limit"`
Pricing
PricingConfig
`mapstructure:"pricing"`
Gateway
GatewayConfig
`mapstructure:"gateway"`
APIKeyAuth
APIKeyAuthCacheConfig
`mapstructure:"api_key_auth_cache"`
SubscriptionCache
SubscriptionCacheConfig
`mapstructure:"subscription_cache"`
Dashboard
DashboardCacheConfig
`mapstructure:"dashboard_cache"`
DashboardAgg
DashboardAggregationConfig
`mapstructure:"dashboard_aggregation"`
UsageCleanup
UsageCleanupConfig
`mapstructure:"usage_cleanup"`
Concurrency
ConcurrencyConfig
`mapstructure:"concurrency"`
TokenRefresh
TokenRefreshConfig
`mapstructure:"token_refresh"`
RunMode
string
`mapstructure:"run_mode" yaml:"run_mode"`
Timezone
string
`mapstructure:"timezone"`
// e.g. "Asia/Shanghai", "UTC"
Gemini
GeminiConfig
`mapstructure:"gemini"`
Update
UpdateConfig
`mapstructure:"update"`
}
type
GeminiConfig
struct
{
...
...
@@ -147,6 +148,7 @@ type ServerConfig struct {
Host
string
`mapstructure:"host"`
Port
int
`mapstructure:"port"`
Mode
string
`mapstructure:"mode"`
// debug/release
FrontendURL
string
`mapstructure:"frontend_url"`
// 前端基础 URL,用于生成邮件中的外部链接
ReadHeaderTimeout
int
`mapstructure:"read_header_timeout"`
// 读取请求头超时(秒)
IdleTimeout
int
`mapstructure:"idle_timeout"`
// 空闲连接超时(秒)
TrustedProxies
[]
string
`mapstructure:"trusted_proxies"`
// 可信代理列表(CIDR/IP)
...
...
@@ -226,6 +228,9 @@ type GatewayConfig struct {
MaxBodySize
int64
`mapstructure:"max_body_size"`
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
ConnectionPoolIsolation
string
`mapstructure:"connection_pool_isolation"`
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
ForceCodexCLI
bool
`mapstructure:"force_codex_cli"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
...
...
@@ -525,6 +530,13 @@ type APIKeyAuthCacheConfig struct {
Singleflight
bool
`mapstructure:"singleflight"`
}
// SubscriptionCacheConfig 订阅认证 L1 缓存配置
type
SubscriptionCacheConfig
struct
{
L1Size
int
`mapstructure:"l1_size"`
L1TTLSeconds
int
`mapstructure:"l1_ttl_seconds"`
JitterPercent
int
`mapstructure:"jitter_percent"`
}
// DashboardCacheConfig 仪表盘统计缓存配置
type
DashboardCacheConfig
struct
{
// Enabled: 是否启用仪表盘缓存
...
...
@@ -630,6 +642,7 @@ func Load() (*Config, error) {
if
cfg
.
Server
.
Mode
==
""
{
cfg
.
Server
.
Mode
=
"debug"
}
cfg
.
Server
.
FrontendURL
=
strings
.
TrimSpace
(
cfg
.
Server
.
FrontendURL
)
cfg
.
JWT
.
Secret
=
strings
.
TrimSpace
(
cfg
.
JWT
.
Secret
)
cfg
.
LinuxDo
.
ClientID
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
ClientID
)
cfg
.
LinuxDo
.
ClientSecret
=
strings
.
TrimSpace
(
cfg
.
LinuxDo
.
ClientSecret
)
...
...
@@ -702,7 +715,8 @@ func setDefaults() {
// Server
viper
.
SetDefault
(
"server.host"
,
"0.0.0.0"
)
viper
.
SetDefault
(
"server.port"
,
8080
)
viper
.
SetDefault
(
"server.mode"
,
"debug"
)
viper
.
SetDefault
(
"server.mode"
,
"release"
)
viper
.
SetDefault
(
"server.frontend_url"
,
""
)
viper
.
SetDefault
(
"server.read_header_timeout"
,
30
)
// 30秒读取请求头
viper
.
SetDefault
(
"server.idle_timeout"
,
120
)
// 120秒空闲超时
viper
.
SetDefault
(
"server.trusted_proxies"
,
[]
string
{})
...
...
@@ -737,7 +751,7 @@ func setDefaults() {
viper
.
SetDefault
(
"security.url_allowlist.crs_hosts"
,
[]
string
{})
viper
.
SetDefault
(
"security.url_allowlist.allow_private_hosts"
,
true
)
viper
.
SetDefault
(
"security.url_allowlist.allow_insecure_http"
,
true
)
viper
.
SetDefault
(
"security.response_headers.enabled"
,
fals
e
)
viper
.
SetDefault
(
"security.response_headers.enabled"
,
tru
e
)
viper
.
SetDefault
(
"security.response_headers.additional_allowed"
,
[]
string
{})
viper
.
SetDefault
(
"security.response_headers.force_remove"
,
[]
string
{})
viper
.
SetDefault
(
"security.csp.enabled"
,
true
)
...
...
@@ -775,9 +789,9 @@ func setDefaults() {
viper
.
SetDefault
(
"database.user"
,
"postgres"
)
viper
.
SetDefault
(
"database.password"
,
"postgres"
)
viper
.
SetDefault
(
"database.dbname"
,
"sub2api"
)
viper
.
SetDefault
(
"database.sslmode"
,
"
disable
"
)
viper
.
SetDefault
(
"database.max_open_conns"
,
50
)
viper
.
SetDefault
(
"database.max_idle_conns"
,
1
0
)
viper
.
SetDefault
(
"database.sslmode"
,
"
prefer
"
)
viper
.
SetDefault
(
"database.max_open_conns"
,
256
)
viper
.
SetDefault
(
"database.max_idle_conns"
,
1
28
)
viper
.
SetDefault
(
"database.conn_max_lifetime_minutes"
,
30
)
viper
.
SetDefault
(
"database.conn_max_idle_time_minutes"
,
5
)
...
...
@@ -789,8 +803,8 @@ func setDefaults() {
viper
.
SetDefault
(
"redis.dial_timeout_seconds"
,
5
)
viper
.
SetDefault
(
"redis.read_timeout_seconds"
,
3
)
viper
.
SetDefault
(
"redis.write_timeout_seconds"
,
3
)
viper
.
SetDefault
(
"redis.pool_size"
,
1
28
)
viper
.
SetDefault
(
"redis.min_idle_conns"
,
1
0
)
viper
.
SetDefault
(
"redis.pool_size"
,
1
024
)
viper
.
SetDefault
(
"redis.min_idle_conns"
,
1
28
)
viper
.
SetDefault
(
"redis.enable_tls"
,
false
)
// Ops (vNext)
...
...
@@ -849,6 +863,11 @@ func setDefaults() {
viper
.
SetDefault
(
"api_key_auth_cache.jitter_percent"
,
10
)
viper
.
SetDefault
(
"api_key_auth_cache.singleflight"
,
true
)
// Subscription auth L1 cache
viper
.
SetDefault
(
"subscription_cache.l1_size"
,
16384
)
viper
.
SetDefault
(
"subscription_cache.l1_ttl_seconds"
,
10
)
viper
.
SetDefault
(
"subscription_cache.jitter_percent"
,
10
)
// Dashboard cache
viper
.
SetDefault
(
"dashboard_cache.enabled"
,
true
)
viper
.
SetDefault
(
"dashboard_cache.key_prefix"
,
"sub2api:"
)
...
...
@@ -882,13 +901,14 @@ func setDefaults() {
viper
.
SetDefault
(
"gateway.failover_on_400"
,
false
)
viper
.
SetDefault
(
"gateway.max_account_switches"
,
10
)
viper
.
SetDefault
(
"gateway.max_account_switches_gemini"
,
3
)
viper
.
SetDefault
(
"gateway.force_codex_cli"
,
false
)
viper
.
SetDefault
(
"gateway.antigravity_fallback_cooldown_minutes"
,
1
)
viper
.
SetDefault
(
"gateway.max_body_size"
,
int64
(
100
*
1024
*
1024
))
viper
.
SetDefault
(
"gateway.connection_pool_isolation"
,
ConnectionPoolIsolationAccountProxy
)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper
.
SetDefault
(
"gateway.max_idle_conns"
,
2
4
0
)
// 最大空闲连接总数(
HTTP/2 场景默认
)
viper
.
SetDefault
(
"gateway.max_idle_conns"
,
2
56
0
)
// 最大空闲连接总数(
高并发场景可调大
)
viper
.
SetDefault
(
"gateway.max_idle_conns_per_host"
,
120
)
// 每主机最大空闲连接(HTTP/2 场景默认)
viper
.
SetDefault
(
"gateway.max_conns_per_host"
,
240
)
// 每主机最大连接数(含活跃
,HTTP/2 场景默认
)
viper
.
SetDefault
(
"gateway.max_conns_per_host"
,
1024
)
// 每主机最大连接数(含活跃
;流式/HTTP1.1 场景可调大,如 2400+
)
viper
.
SetDefault
(
"gateway.idle_conn_timeout_seconds"
,
90
)
// 空闲连接超时(秒)
viper
.
SetDefault
(
"gateway.max_upstream_clients"
,
5000
)
viper
.
SetDefault
(
"gateway.client_idle_ttl_seconds"
,
900
)
...
...
@@ -933,6 +953,22 @@ func setDefaults() {
}
func
(
c
*
Config
)
Validate
()
error
{
if
strings
.
TrimSpace
(
c
.
Server
.
FrontendURL
)
!=
""
{
if
err
:=
ValidateAbsoluteHTTPURL
(
c
.
Server
.
FrontendURL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"server.frontend_url invalid: %w"
,
err
)
}
u
,
err
:=
url
.
Parse
(
strings
.
TrimSpace
(
c
.
Server
.
FrontendURL
))
if
err
!=
nil
{
return
fmt
.
Errorf
(
"server.frontend_url invalid: %w"
,
err
)
}
if
u
.
RawQuery
!=
""
||
u
.
ForceQuery
{
return
fmt
.
Errorf
(
"server.frontend_url invalid: must not include query"
)
}
if
u
.
User
!=
nil
{
return
fmt
.
Errorf
(
"server.frontend_url invalid: must not include userinfo"
)
}
warnIfInsecureURL
(
"server.frontend_url"
,
c
.
Server
.
FrontendURL
)
}
if
c
.
JWT
.
ExpireHour
<=
0
{
return
fmt
.
Errorf
(
"jwt.expire_hour must be positive"
)
}
...
...
backend/internal/config/config_test.go
View file @
a14dfb76
...
...
@@ -87,8 +87,34 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
if
!
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
{
t
.
Fatalf
(
"URLAllowlist.AllowPrivateHosts = false, want true"
)
}
if
cfg
.
Security
.
ResponseHeaders
.
Enabled
{
t
.
Fatalf
(
"ResponseHeaders.Enabled = true, want false"
)
if
!
cfg
.
Security
.
ResponseHeaders
.
Enabled
{
t
.
Fatalf
(
"ResponseHeaders.Enabled = false, want true"
)
}
}
func
TestLoadDefaultServerMode
(
t
*
testing
.
T
)
{
viper
.
Reset
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
if
cfg
.
Server
.
Mode
!=
"release"
{
t
.
Fatalf
(
"Server.Mode = %q, want %q"
,
cfg
.
Server
.
Mode
,
"release"
)
}
}
func
TestLoadDefaultDatabaseSSLMode
(
t
*
testing
.
T
)
{
viper
.
Reset
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
if
cfg
.
Database
.
SSLMode
!=
"prefer"
{
t
.
Fatalf
(
"Database.SSLMode = %q, want %q"
,
cfg
.
Database
.
SSLMode
,
"prefer"
)
}
}
...
...
@@ -424,6 +450,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) {
}
}
func
TestValidateServerFrontendURL
(
t
*
testing
.
T
)
{
viper
.
Reset
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
cfg
.
Server
.
FrontendURL
=
"https://example.com"
if
err
:=
cfg
.
Validate
();
err
!=
nil
{
t
.
Fatalf
(
"Validate() frontend_url valid error: %v"
,
err
)
}
cfg
.
Server
.
FrontendURL
=
"https://example.com/path"
if
err
:=
cfg
.
Validate
();
err
!=
nil
{
t
.
Fatalf
(
"Validate() frontend_url with path valid error: %v"
,
err
)
}
cfg
.
Server
.
FrontendURL
=
"https://example.com?utm=1"
if
err
:=
cfg
.
Validate
();
err
==
nil
{
t
.
Fatalf
(
"Validate() should reject server.frontend_url with query"
)
}
cfg
.
Server
.
FrontendURL
=
"https://user:pass@example.com"
if
err
:=
cfg
.
Validate
();
err
==
nil
{
t
.
Fatalf
(
"Validate() should reject server.frontend_url with userinfo"
)
}
cfg
.
Server
.
FrontendURL
=
"/relative"
if
err
:=
cfg
.
Validate
();
err
==
nil
{
t
.
Fatalf
(
"Validate() should reject relative server.frontend_url"
)
}
}
func
TestValidateFrontendRedirectURL
(
t
*
testing
.
T
)
{
if
err
:=
ValidateFrontendRedirectURL
(
"/auth/callback"
);
err
!=
nil
{
t
.
Fatalf
(
"ValidateFrontendRedirectURL relative error: %v"
,
err
)
...
...
backend/internal/handler/admin/account_handler.go
View file @
a14dfb76
...
...
@@ -3,6 +3,7 @@ package admin
import
(
"errors"
"fmt"
"strconv"
"strings"
"sync"
...
...
@@ -789,57 +790,40 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
}
ctx
:=
c
.
Request
.
Context
()
success
:=
0
failed
:=
0
results
:=
[]
gin
.
H
{}
// 阶段一:预验证所有账号存在,收集 credentials
type
accountUpdate
struct
{
ID
int64
Credentials
map
[
string
]
any
}
updates
:=
make
([]
accountUpdate
,
0
,
len
(
req
.
AccountIDs
))
for
_
,
accountID
:=
range
req
.
AccountIDs
{
// Get account
account
,
err
:=
h
.
adminService
.
GetAccount
(
ctx
,
accountID
)
if
err
!=
nil
{
failed
++
results
=
append
(
results
,
gin
.
H
{
"account_id"
:
accountID
,
"success"
:
false
,
"error"
:
"Account not found"
,
})
continue
response
.
Error
(
c
,
404
,
fmt
.
Sprintf
(
"Account %d not found"
,
accountID
))
return
}
// Update credentials field
if
account
.
Credentials
==
nil
{
account
.
Credentials
=
make
(
map
[
string
]
any
)
}
account
.
Credentials
[
req
.
Field
]
=
req
.
Value
updates
=
append
(
updates
,
accountUpdate
{
ID
:
accountID
,
Credentials
:
account
.
Credentials
})
}
// Update account
// 阶段二:依次更新,任何失败立即返回(避免部分成功部分失败)
for
_
,
u
:=
range
updates
{
updateInput
:=
&
service
.
UpdateAccountInput
{
Credentials
:
account
.
Credentials
,
Credentials
:
u
.
Credentials
,
}
_
,
err
=
h
.
adminService
.
UpdateAccount
(
ctx
,
accountID
,
updateInput
)
if
err
!=
nil
{
failed
++
results
=
append
(
results
,
gin
.
H
{
"account_id"
:
accountID
,
"success"
:
false
,
"error"
:
err
.
Error
(),
})
continue
if
_
,
err
:=
h
.
adminService
.
UpdateAccount
(
ctx
,
u
.
ID
,
updateInput
);
err
!=
nil
{
response
.
Error
(
c
,
500
,
fmt
.
Sprintf
(
"Failed to update account %d: %v"
,
u
.
ID
,
err
))
return
}
success
++
results
=
append
(
results
,
gin
.
H
{
"account_id"
:
accountID
,
"success"
:
true
,
})
}
response
.
Success
(
c
,
gin
.
H
{
"success"
:
success
,
"failed"
:
failed
,
"results"
:
results
,
"success"
:
len
(
updates
),
"failed"
:
0
,
})
}
...
...
backend/internal/handler/admin/batch_update_credentials_test.go
0 → 100644
View file @
a14dfb76
//go:build unit
package
admin
import
(
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。
type
failingAdminService
struct
{
*
stubAdminService
failOnAccountID
int64
updateCallCount
atomic
.
Int64
}
func
(
f
*
failingAdminService
)
UpdateAccount
(
ctx
context
.
Context
,
id
int64
,
input
*
service
.
UpdateAccountInput
)
(
*
service
.
Account
,
error
)
{
f
.
updateCallCount
.
Add
(
1
)
if
id
==
f
.
failOnAccountID
{
return
nil
,
errors
.
New
(
"database error"
)
}
return
f
.
stubAdminService
.
UpdateAccount
(
ctx
,
id
,
input
)
}
func
setupAccountHandlerWithService
(
adminSvc
service
.
AdminService
)
(
*
gin
.
Engine
,
*
AccountHandler
)
{
gin
.
SetMode
(
gin
.
TestMode
)
router
:=
gin
.
New
()
handler
:=
NewAccountHandler
(
adminSvc
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
router
.
POST
(
"/api/v1/admin/accounts/batch-update-credentials"
,
handler
.
BatchUpdateCredentials
)
return
router
,
handler
}
func
TestBatchUpdateCredentials_AllSuccess
(
t
*
testing
.
T
)
{
svc
:=
&
failingAdminService
{
stubAdminService
:
newStubAdminService
()}
router
,
_
:=
setupAccountHandlerWithService
(
svc
)
body
,
_
:=
json
.
Marshal
(
BatchUpdateCredentialsRequest
{
AccountIDs
:
[]
int64
{
1
,
2
,
3
},
Field
:
"account_uuid"
,
Value
:
"test-uuid"
,
})
w
:=
httptest
.
NewRecorder
()
req
,
_
:=
http
.
NewRequest
(
"POST"
,
"/api/v1/admin/accounts/batch-update-credentials"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
,
"全部成功时应返回 200"
)
require
.
Equal
(
t
,
int64
(
3
),
svc
.
updateCallCount
.
Load
(),
"应调用 3 次 UpdateAccount"
)
}
func
TestBatchUpdateCredentials_FailFast
(
t
*
testing
.
T
)
{
// 让第 2 个账号(ID=2)更新时失败
svc
:=
&
failingAdminService
{
stubAdminService
:
newStubAdminService
(),
failOnAccountID
:
2
,
}
router
,
_
:=
setupAccountHandlerWithService
(
svc
)
body
,
_
:=
json
.
Marshal
(
BatchUpdateCredentialsRequest
{
AccountIDs
:
[]
int64
{
1
,
2
,
3
},
Field
:
"org_uuid"
,
Value
:
"test-org"
,
})
w
:=
httptest
.
NewRecorder
()
req
,
_
:=
http
.
NewRequest
(
"POST"
,
"/api/v1/admin/accounts/batch-update-credentials"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
w
.
Code
,
"ID=2 失败时应返回 500"
)
// 验证 fail-fast:ID=1 更新成功,ID=2 失败,ID=3 不应被调用
require
.
Equal
(
t
,
int64
(
2
),
svc
.
updateCallCount
.
Load
(),
"fail-fast: 应只调用 2 次 UpdateAccount(ID=1 成功、ID=2 失败后停止)"
)
}
func
TestBatchUpdateCredentials_FirstAccountNotFound
(
t
*
testing
.
T
)
{
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
svc
:=
&
getAccountFailingService
{
stubAdminService
:
newStubAdminService
(),
failOnAccountID
:
1
,
}
router
,
_
:=
setupAccountHandlerWithService
(
svc
)
body
,
_
:=
json
.
Marshal
(
BatchUpdateCredentialsRequest
{
AccountIDs
:
[]
int64
{
1
,
2
,
3
},
Field
:
"account_uuid"
,
Value
:
"test"
,
})
w
:=
httptest
.
NewRecorder
()
req
,
_
:=
http
.
NewRequest
(
"POST"
,
"/api/v1/admin/accounts/batch-update-credentials"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
w
.
Code
,
"第一阶段验证失败应返回 404"
)
}
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
type
getAccountFailingService
struct
{
*
stubAdminService
failOnAccountID
int64
}
func
(
f
*
getAccountFailingService
)
GetAccount
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Account
,
error
)
{
if
id
==
f
.
failOnAccountID
{
return
nil
,
errors
.
New
(
"not found"
)
}
return
f
.
stubAdminService
.
GetAccount
(
ctx
,
id
)
}
func
TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool
(
t
*
testing
.
T
)
{
svc
:=
&
failingAdminService
{
stubAdminService
:
newStubAdminService
()}
router
,
_
:=
setupAccountHandlerWithService
(
svc
)
// intercept_warmup_requests 传入非 bool 类型(string),应返回 400
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"account_ids"
:
[]
int64
{
1
},
"field"
:
"intercept_warmup_requests"
,
"value"
:
"not-a-bool"
,
})
w
:=
httptest
.
NewRecorder
()
req
,
_
:=
http
.
NewRequest
(
"POST"
,
"/api/v1/admin/accounts/batch-update-credentials"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
w
.
Code
,
"intercept_warmup_requests 传入非 bool 值应返回 400"
)
}
func
TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool
(
t
*
testing
.
T
)
{
svc
:=
&
failingAdminService
{
stubAdminService
:
newStubAdminService
()}
router
,
_
:=
setupAccountHandlerWithService
(
svc
)
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"account_ids"
:
[]
int64
{
1
},
"field"
:
"intercept_warmup_requests"
,
"value"
:
true
,
})
w
:=
httptest
.
NewRecorder
()
req
,
_
:=
http
.
NewRequest
(
"POST"
,
"/api/v1/admin/accounts/batch-update-credentials"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
,
"intercept_warmup_requests 传入合法 bool 值应返回 200"
)
}
func
TestBatchUpdateCredentials_AccountUUID_NonString
(
t
*
testing
.
T
)
{
svc
:=
&
failingAdminService
{
stubAdminService
:
newStubAdminService
()}
router
,
_
:=
setupAccountHandlerWithService
(
svc
)
// account_uuid 传入非 string 类型(number),应返回 400
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"account_ids"
:
[]
int64
{
1
},
"field"
:
"account_uuid"
,
"value"
:
12345
,
})
w
:=
httptest
.
NewRecorder
()
req
,
_
:=
http
.
NewRequest
(
"POST"
,
"/api/v1/admin/accounts/batch-update-credentials"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
w
.
Code
,
"account_uuid 传入非 string 值应返回 400"
)
}
func
TestBatchUpdateCredentials_AccountUUID_NullValue
(
t
*
testing
.
T
)
{
svc
:=
&
failingAdminService
{
stubAdminService
:
newStubAdminService
()}
router
,
_
:=
setupAccountHandlerWithService
(
svc
)
// account_uuid 传入 null(设置为空),应正常通过
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"account_ids"
:
[]
int64
{
1
},
"field"
:
"account_uuid"
,
"value"
:
nil
,
})
w
:=
httptest
.
NewRecorder
()
req
,
_
:=
http
.
NewRequest
(
"POST"
,
"/api/v1/admin/accounts/batch-update-credentials"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
,
"account_uuid 传入 null 应返回 200"
)
}
backend/internal/handler/admin/dashboard_handler.go
View file @
a14dfb76
...
...
@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return
}
stats
,
err
:=
h
.
dashboardService
.
GetBatchUserUsageStats
(
c
.
Request
.
Context
(),
req
.
UserIDs
)
stats
,
err
:=
h
.
dashboardService
.
GetBatchUserUsageStats
(
c
.
Request
.
Context
(),
req
.
UserIDs
,
time
.
Time
{},
time
.
Time
{}
)
if
err
!=
nil
{
response
.
Error
(
c
,
500
,
"Failed to get user usage stats"
)
return
...
...
@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
return
}
stats
,
err
:=
h
.
dashboardService
.
GetBatchAPIKeyUsageStats
(
c
.
Request
.
Context
(),
req
.
APIKeyIDs
)
stats
,
err
:=
h
.
dashboardService
.
GetBatchAPIKeyUsageStats
(
c
.
Request
.
Context
(),
req
.
APIKeyIDs
,
time
.
Time
{},
time
.
Time
{}
)
if
err
!=
nil
{
response
.
Error
(
c
,
500
,
"Failed to get API key usage stats"
)
return
...
...
backend/internal/handler/admin/search_truncate_test.go
0 → 100644
View file @
a14dfb76
//go:build unit
package
admin
import
(
"testing"
"github.com/stretchr/testify/require"
)
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
func
truncateSearchByRune
(
search
string
,
maxRunes
int
)
string
{
if
runes
:=
[]
rune
(
search
);
len
(
runes
)
>
maxRunes
{
return
string
(
runes
[
:
maxRunes
])
}
return
search
}
func
TestTruncateSearchByRune
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
input
string
maxRunes
int
wantLen
int
// 期望的 rune 长度
}{
{
name
:
"纯中文超长"
,
input
:
string
(
make
([]
rune
,
150
)),
maxRunes
:
100
,
wantLen
:
100
,
},
{
name
:
"纯 ASCII 超长"
,
input
:
string
(
make
([]
byte
,
150
)),
maxRunes
:
100
,
wantLen
:
100
,
},
{
name
:
"空字符串"
,
input
:
""
,
maxRunes
:
100
,
wantLen
:
0
,
},
{
name
:
"恰好 100 个字符"
,
input
:
string
(
make
([]
rune
,
100
)),
maxRunes
:
100
,
wantLen
:
100
,
},
{
name
:
"不足 100 字符不截断"
,
input
:
"hello世界"
,
maxRunes
:
100
,
wantLen
:
7
,
},
}
for
_
,
tc
:=
range
tests
{
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
truncateSearchByRune
(
tc
.
input
,
tc
.
maxRunes
)
require
.
Equal
(
t
,
tc
.
wantLen
,
len
([]
rune
(
result
)))
})
}
}
func
TestTruncateSearchByRune_PreservesMultibyte
(
t
*
testing
.
T
)
{
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
input
:=
""
for
i
:=
0
;
i
<
101
;
i
++
{
input
+=
"中"
}
result
:=
truncateSearchByRune
(
input
,
100
)
require
.
Equal
(
t
,
100
,
len
([]
rune
(
result
)))
// 验证截断结果是有效的 UTF-8(每个中文字符 3 字节)
require
.
Equal
(
t
,
300
,
len
(
result
))
}
func
TestTruncateSearchByRune_MixedASCIIAndMultibyte
(
t
*
testing
.
T
)
{
// 50 个 ASCII + 51 个中文 = 101 个 rune
input
:=
""
for
i
:=
0
;
i
<
50
;
i
++
{
input
+=
"a"
}
for
i
:=
0
;
i
<
51
;
i
++
{
input
+=
"中"
}
result
:=
truncateSearchByRune
(
input
,
100
)
runes
:=
[]
rune
(
result
)
require
.
Equal
(
t
,
100
,
len
(
runes
))
// 前 50 个应该是 'a',后 50 个应该是 '中'
require
.
Equal
(
t
,
'a'
,
runes
[
0
])
require
.
Equal
(
t
,
'a'
,
runes
[
49
])
require
.
Equal
(
t
,
'中'
,
runes
[
50
])
require
.
Equal
(
t
,
'中'
,
runes
[
99
])
}
backend/internal/handler/admin/user_handler.go
View file @
a14dfb76
...
...
@@ -70,8 +70,8 @@ func (h *UserHandler) List(c *gin.Context) {
search
:=
c
.
Query
(
"search"
)
// 标准化和验证 search 参数
search
=
strings
.
TrimSpace
(
search
)
if
len
(
search
)
>
100
{
search
=
s
earch
[
:
100
]
if
runes
:=
[]
rune
(
search
);
len
(
runes
)
>
100
{
search
=
s
tring
(
runes
[
:
100
]
)
}
filters
:=
service
.
UserListFilters
{
...
...
backend/internal/handler/auth_handler.go
View file @
a14dfb76
...
...
@@ -2,6 +2,7 @@ package handler
import
(
"log/slog"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
...
...
@@ -448,17 +449,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
return
}
// Build frontend base URL from request
scheme
:=
"https"
if
c
.
Request
.
TLS
==
nil
{
// Check X-Forwarded-Proto header (common in reverse proxy setups)
if
proto
:=
c
.
GetHeader
(
"X-Forwarded-Proto"
);
proto
!=
""
{
scheme
=
proto
}
else
{
scheme
=
"http"
}
frontendBaseURL
:=
strings
.
TrimSpace
(
h
.
cfg
.
Server
.
FrontendURL
)
if
frontendBaseURL
==
""
{
slog
.
Error
(
"server.frontend_url not configured; cannot build password reset link"
)
response
.
InternalError
(
c
,
"Password reset is not configured"
)
return
}
frontendBaseURL
:=
scheme
+
"://"
+
c
.
Request
.
Host
// Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration)
...
...
backend/internal/handler/gateway_handler.go
View file @
a14dfb76
...
...
@@ -236,7 +236,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
""
)
// Gemini 不使用会话限制
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
log
.
Printf
(
"[Gateway] SelectAccount failed: %v"
,
err
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"Service temporarily unavailable"
,
streamStarted
)
return
}
if
lastFailoverErr
!=
nil
{
...
...
@@ -284,12 +285,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
err
==
nil
&&
canWait
{
accountWaitCounted
=
true
}
// Ensure the wait counter is decremented if we exit before acquiring the slot.
defer
func
()
{
releaseWait
:=
func
()
{
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
}
()
}
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
...
...
@@ -301,14 +302,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
)
if
err
!=
nil
{
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
releaseWait
()
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
// Slot acquired: no longer waiting in queue.
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
releaseWait
()
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
...
...
@@ -398,7 +397,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
currentAPIKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
parsedReq
.
MetadataUserID
)
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
log
.
Printf
(
"[Gateway] SelectAccount failed: %v"
,
err
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"Service temporarily unavailable"
,
streamStarted
)
return
}
if
lastFailoverErr
!=
nil
{
...
...
@@ -446,11 +446,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
err
==
nil
&&
canWait
{
accountWaitCounted
=
true
}
defer
func
()
{
releaseWait
:=
func
()
{
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
}
()
}
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
...
...
@@ -462,13 +463,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
)
if
err
!=
nil
{
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
releaseWait
()
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
// Slot acquired: no longer waiting in queue.
releaseWait
()
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
currentAPIKey
.
GroupID
,
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
...
...
@@ -967,7 +967,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 选择支持该模型的账号
account
,
err
:=
h
.
gatewayService
.
SelectAccountForModel
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
parsedReq
.
Model
)
if
err
!=
nil
{
h
.
errorResponse
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
())
log
.
Printf
(
"[Gateway] SelectAccountForModel failed: %v"
,
err
)
h
.
errorResponse
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"Service temporarily unavailable"
)
return
}
setOpsSelectedAccount
(
c
,
account
.
ID
)
...
...
@@ -1238,7 +1239,8 @@ func billingErrorDetails(err error) (status int, code, message string) {
}
msg
:=
pkgerrors
.
Message
(
err
)
if
msg
==
""
{
msg
=
err
.
Error
()
log
.
Printf
(
"[Gateway] billing error details: %v"
,
err
)
msg
=
"Billing error"
}
return
http
.
StatusForbidden
,
"billing_error"
,
msg
}
backend/internal/handler/openai_gateway_handler.go
View file @
a14dfb76
...
...
@@ -28,6 +28,7 @@ type OpenAIGatewayHandler struct {
errorPassthroughService
*
service
.
ErrorPassthroughService
concurrencyHelper
*
ConcurrencyHelper
maxAccountSwitches
int
cfg
*
config
.
Config
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
...
...
@@ -54,6 +55,7 @@ func NewOpenAIGatewayHandler(
errorPassthroughService
:
errorPassthroughService
,
concurrencyHelper
:
NewConcurrencyHelper
(
concurrencyService
,
SSEPingFormatComment
,
pingInterval
),
maxAccountSwitches
:
maxAccountSwitches
,
cfg
:
cfg
,
}
}
...
...
@@ -109,7 +111,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
userAgent
:=
c
.
GetHeader
(
"User-Agent"
)
if
!
openai
.
IsCodexCLIRequest
(
userAgent
)
{
isCodexCLI
:=
openai
.
IsCodexCLIRequest
(
userAgent
)
||
(
h
.
cfg
!=
nil
&&
h
.
cfg
.
Gateway
.
ForceCodexCLI
)
if
!
isCodexCLI
{
existingInstructions
,
_
:=
reqBody
[
"instructions"
]
.
(
string
)
if
strings
.
TrimSpace
(
existingInstructions
)
==
""
{
if
instructions
:=
strings
.
TrimSpace
(
service
.
GetOpenCodeInstructions
());
instructions
!=
""
{
...
...
@@ -218,7 +221,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if
err
!=
nil
{
log
.
Printf
(
"[OpenAI Handler] SelectAccount failed: %v"
,
err
)
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
log
.
Printf
(
"[OpenAI Gateway] SelectAccount failed: %v"
,
err
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"Service temporarily unavailable"
,
streamStarted
)
return
}
if
lastFailoverErr
!=
nil
{
...
...
@@ -251,11 +255,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if
err
==
nil
&&
canWait
{
accountWaitCounted
=
true
}
defer
func
()
{
releaseWait
:=
func
()
{
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
}
()
}
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
...
...
@@ -267,13 +272,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
)
if
err
!=
nil
{
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
releaseWait
()
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
// Slot acquired: no longer waiting in queue.
releaseWait
()
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
...
...
backend/internal/handler/usage_handler.go
View file @
a14dfb76
...
...
@@ -392,7 +392,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
return
}
stats
,
err
:=
h
.
usageService
.
GetBatchAPIKeyUsageStats
(
c
.
Request
.
Context
(),
validAPIKeyIDs
)
stats
,
err
:=
h
.
usageService
.
GetBatchAPIKeyUsageStats
(
c
.
Request
.
Context
(),
validAPIKeyIDs
,
time
.
Time
{},
time
.
Time
{}
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
backend/internal/pkg/antigravity/response_transformer.go
View file @
a14dfb76
package
antigravity
import
(
"crypto/rand"
"encoding/json"
"fmt"
"log"
...
...
@@ -341,12 +342,16 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string {
return
builder
.
String
()
}
// generateRandomID 生成随机 ID
// generateRandomID 生成
密码学安全的
随机 ID
func
generateRandomID
()
string
{
const
chars
=
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result
:=
make
([]
byte
,
12
)
for
i
:=
range
result
{
result
[
i
]
=
chars
[
i
%
len
(
chars
)]
randBytes
:=
make
([]
byte
,
12
)
if
_
,
err
:=
rand
.
Read
(
randBytes
);
err
!=
nil
{
panic
(
"crypto/rand unavailable: "
+
err
.
Error
())
}
for
i
,
b
:=
range
randBytes
{
result
[
i
]
=
chars
[
int
(
b
)
%
len
(
chars
)]
}
return
string
(
result
)
}
backend/internal/pkg/antigravity/response_transformer_test.go
0 → 100644
View file @
a14dfb76
//go:build unit
package
antigravity
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestGenerateRandomID_Uniqueness
(
t
*
testing
.
T
)
{
seen
:=
make
(
map
[
string
]
struct
{},
100
)
for
i
:=
0
;
i
<
100
;
i
++
{
id
:=
generateRandomID
()
require
.
Len
(
t
,
id
,
12
,
"ID 长度应为 12"
)
_
,
dup
:=
seen
[
id
]
require
.
False
(
t
,
dup
,
"第 %d 次调用生成了重复 ID: %s"
,
i
,
id
)
seen
[
id
]
=
struct
{}{}
}
}
func
TestGenerateRandomID_Charset
(
t
*
testing
.
T
)
{
const
validChars
=
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
validSet
:=
make
(
map
[
byte
]
struct
{},
len
(
validChars
))
for
i
:=
0
;
i
<
len
(
validChars
);
i
++
{
validSet
[
validChars
[
i
]]
=
struct
{}{}
}
for
i
:=
0
;
i
<
50
;
i
++
{
id
:=
generateRandomID
()
for
j
:=
0
;
j
<
len
(
id
);
j
++
{
_
,
ok
:=
validSet
[
id
[
j
]]
require
.
True
(
t
,
ok
,
"ID 包含非法字符: %c (ID=%s)"
,
id
[
j
],
id
)
}
}
}
backend/internal/pkg/ip/ip.go
View file @
a14dfb76
...
...
@@ -54,29 +54,34 @@ func normalizeIP(ip string) string {
return
ip
}
// isPrivateIP 检查 IP 是否为私有地址。
func
isPrivateIP
(
ipStr
string
)
bool
{
ip
:=
net
.
ParseIP
(
ipStr
)
if
ip
==
nil
{
return
false
}
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
var
privateNets
[]
*
net
.
IPNet
// 私有 IP 范围
privateBlocks
:=
[]
string
{
func
init
()
{
for
_
,
cidr
:=
range
[]
string
{
"10.0.0.0/8"
,
"172.16.0.0/12"
,
"192.168.0.0/16"
,
"127.0.0.0/8"
,
"::1/128"
,
"fc00::/7"
,
}
for
_
,
block
:=
range
privateBlocks
{
_
,
cidr
,
err
:=
net
.
ParseCIDR
(
block
)
}
{
_
,
block
,
err
:=
net
.
ParseCIDR
(
cidr
)
if
err
!=
nil
{
continue
panic
(
"invalid CIDR: "
+
cidr
)
}
if
cidr
.
Contains
(
ip
)
{
privateNets
=
append
(
privateNets
,
block
)
}
}
// isPrivateIP 检查 IP 是否为私有地址。
func
isPrivateIP
(
ipStr
string
)
bool
{
ip
:=
net
.
ParseIP
(
ipStr
)
if
ip
==
nil
{
return
false
}
for
_
,
block
:=
range
privateNets
{
if
block
.
Contains
(
ip
)
{
return
true
}
}
...
...
backend/internal/pkg/ip/ip_test.go
0 → 100644
View file @
a14dfb76
//go:build unit
package
ip
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestIsPrivateIP
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
ip
string
expected
bool
}{
// 私有 IPv4
{
"10.x 私有地址"
,
"10.0.0.1"
,
true
},
{
"10.x 私有地址段末"
,
"10.255.255.255"
,
true
},
{
"172.16.x 私有地址"
,
"172.16.0.1"
,
true
},
{
"172.31.x 私有地址"
,
"172.31.255.255"
,
true
},
{
"192.168.x 私有地址"
,
"192.168.1.1"
,
true
},
{
"127.0.0.1 本地回环"
,
"127.0.0.1"
,
true
},
{
"127.x 回环段"
,
"127.255.255.255"
,
true
},
// 公网 IPv4
{
"8.8.8.8 公网 DNS"
,
"8.8.8.8"
,
false
},
{
"1.1.1.1 公网"
,
"1.1.1.1"
,
false
},
{
"172.15.255.255 非私有"
,
"172.15.255.255"
,
false
},
{
"172.32.0.0 非私有"
,
"172.32.0.0"
,
false
},
{
"11.0.0.1 公网"
,
"11.0.0.1"
,
false
},
// IPv6
{
"::1 IPv6 回环"
,
"::1"
,
true
},
{
"fc00:: IPv6 私有"
,
"fc00::1"
,
true
},
{
"fd00:: IPv6 私有"
,
"fd00::1"
,
true
},
{
"2001:db8::1 IPv6 公网"
,
"2001:db8::1"
,
false
},
// 无效输入
{
"空字符串"
,
""
,
false
},
{
"非法字符串"
,
"not-an-ip"
,
false
},
{
"不完整 IP"
,
"192.168"
,
false
},
}
for
_
,
tc
:=
range
tests
{
t
.
Run
(
tc
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
isPrivateIP
(
tc
.
ip
)
require
.
Equal
(
t
,
tc
.
expected
,
got
,
"isPrivateIP(%q)"
,
tc
.
ip
)
})
}
}
backend/internal/pkg/oauth/oauth.go
View file @
a14dfb76
...
...
@@ -50,6 +50,7 @@ type OAuthSession struct {
type
SessionStore
struct
{
mu
sync
.
RWMutex
sessions
map
[
string
]
*
OAuthSession
stopOnce
sync
.
Once
stopCh
chan
struct
{}
}
...
...
@@ -65,7 +66,9 @@ func NewSessionStore() *SessionStore {
// Stop stops the cleanup goroutine
func
(
s
*
SessionStore
)
Stop
()
{
close
(
s
.
stopCh
)
s
.
stopOnce
.
Do
(
func
()
{
close
(
s
.
stopCh
)
})
}
// Set stores a session
...
...
backend/internal/pkg/oauth/oauth_test.go
0 → 100644
View file @
a14dfb76
package
oauth
import
(
"sync"
"testing"
"time"
)
func
TestSessionStore_Stop_Idempotent
(
t
*
testing
.
T
)
{
store
:=
NewSessionStore
()
store
.
Stop
()
store
.
Stop
()
select
{
case
<-
store
.
stopCh
:
// ok
case
<-
time
.
After
(
time
.
Second
)
:
t
.
Fatal
(
"stopCh 未关闭"
)
}
}
func
TestSessionStore_Stop_Concurrent
(
t
*
testing
.
T
)
{
store
:=
NewSessionStore
()
var
wg
sync
.
WaitGroup
for
range
50
{
wg
.
Add
(
1
)
go
func
()
{
defer
wg
.
Done
()
store
.
Stop
()
}()
}
wg
.
Wait
()
select
{
case
<-
store
.
stopCh
:
// ok
case
<-
time
.
After
(
time
.
Second
)
:
t
.
Fatal
(
"stopCh 未关闭"
)
}
}
backend/internal/pkg/openai/oauth.go
View file @
a14dfb76
...
...
@@ -47,6 +47,7 @@ type OAuthSession struct {
type
SessionStore
struct
{
mu
sync
.
RWMutex
sessions
map
[
string
]
*
OAuthSession
stopOnce
sync
.
Once
stopCh
chan
struct
{}
}
...
...
@@ -92,7 +93,9 @@ func (s *SessionStore) Delete(sessionID string) {
// Stop stops the cleanup goroutine
func
(
s
*
SessionStore
)
Stop
()
{
close
(
s
.
stopCh
)
s
.
stopOnce
.
Do
(
func
()
{
close
(
s
.
stopCh
)
})
}
// cleanup removes expired sessions periodically
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment