Commit a14dfb76 authored by yangjianbo's avatar yangjianbo
Browse files

Merge branch 'dev-release'

parents f3605ddc 2588fa6a
...@@ -65,8 +65,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -65,8 +65,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, configConfig)
redeemCache := repository.NewRedeemCache(redisClient) redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
secretEncryptor, err := repository.NewAESEncryptor(configConfig) secretEncryptor, err := repository.NewAESEncryptor(configConfig)
......
...@@ -54,6 +54,7 @@ type Config struct { ...@@ -54,6 +54,7 @@ type Config struct {
Pricing PricingConfig `mapstructure:"pricing"` Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"` Gateway GatewayConfig `mapstructure:"gateway"`
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
...@@ -147,6 +148,7 @@ type ServerConfig struct { ...@@ -147,6 +148,7 @@ type ServerConfig struct {
Host string `mapstructure:"host"` Host string `mapstructure:"host"`
Port int `mapstructure:"port"` Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release Mode string `mapstructure:"mode"` // debug/release
FrontendURL string `mapstructure:"frontend_url"` // 前端基础 URL,用于生成邮件中的外部链接
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP)
...@@ -226,6 +228,9 @@ type GatewayConfig struct { ...@@ -226,6 +228,9 @@ type GatewayConfig struct {
MaxBodySize int64 `mapstructure:"max_body_size"` MaxBodySize int64 `mapstructure:"max_body_size"`
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy) // ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"` ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
ForceCodexCLI bool `mapstructure:"force_codex_cli"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优) // HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数 // MaxIdleConns: 所有主机的最大空闲连接总数
...@@ -525,6 +530,13 @@ type APIKeyAuthCacheConfig struct { ...@@ -525,6 +530,13 @@ type APIKeyAuthCacheConfig struct {
Singleflight bool `mapstructure:"singleflight"` Singleflight bool `mapstructure:"singleflight"`
} }
// SubscriptionCacheConfig 订阅认证 L1 缓存配置
type SubscriptionCacheConfig struct {
L1Size int `mapstructure:"l1_size"`
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
JitterPercent int `mapstructure:"jitter_percent"`
}
// DashboardCacheConfig 仪表盘统计缓存配置 // DashboardCacheConfig 仪表盘统计缓存配置
type DashboardCacheConfig struct { type DashboardCacheConfig struct {
// Enabled: 是否启用仪表盘缓存 // Enabled: 是否启用仪表盘缓存
...@@ -630,6 +642,7 @@ func Load() (*Config, error) { ...@@ -630,6 +642,7 @@ func Load() (*Config, error) {
if cfg.Server.Mode == "" { if cfg.Server.Mode == "" {
cfg.Server.Mode = "debug" cfg.Server.Mode = "debug"
} }
cfg.Server.FrontendURL = strings.TrimSpace(cfg.Server.FrontendURL)
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID) cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret) cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
...@@ -702,7 +715,8 @@ func setDefaults() { ...@@ -702,7 +715,8 @@ func setDefaults() {
// Server // Server
viper.SetDefault("server.host", "0.0.0.0") viper.SetDefault("server.host", "0.0.0.0")
viper.SetDefault("server.port", 8080) viper.SetDefault("server.port", 8080)
viper.SetDefault("server.mode", "debug") viper.SetDefault("server.mode", "release")
viper.SetDefault("server.frontend_url", "")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.trusted_proxies", []string{}) viper.SetDefault("server.trusted_proxies", []string{})
...@@ -737,7 +751,7 @@ func setDefaults() { ...@@ -737,7 +751,7 @@ func setDefaults() {
viper.SetDefault("security.url_allowlist.crs_hosts", []string{}) viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
viper.SetDefault("security.url_allowlist.allow_private_hosts", true) viper.SetDefault("security.url_allowlist.allow_private_hosts", true)
viper.SetDefault("security.url_allowlist.allow_insecure_http", true) viper.SetDefault("security.url_allowlist.allow_insecure_http", true)
viper.SetDefault("security.response_headers.enabled", false) viper.SetDefault("security.response_headers.enabled", true)
viper.SetDefault("security.response_headers.additional_allowed", []string{}) viper.SetDefault("security.response_headers.additional_allowed", []string{})
viper.SetDefault("security.response_headers.force_remove", []string{}) viper.SetDefault("security.response_headers.force_remove", []string{})
viper.SetDefault("security.csp.enabled", true) viper.SetDefault("security.csp.enabled", true)
...@@ -775,9 +789,9 @@ func setDefaults() { ...@@ -775,9 +789,9 @@ func setDefaults() {
viper.SetDefault("database.user", "postgres") viper.SetDefault("database.user", "postgres")
viper.SetDefault("database.password", "postgres") viper.SetDefault("database.password", "postgres")
viper.SetDefault("database.dbname", "sub2api") viper.SetDefault("database.dbname", "sub2api")
viper.SetDefault("database.sslmode", "disable") viper.SetDefault("database.sslmode", "prefer")
viper.SetDefault("database.max_open_conns", 50) viper.SetDefault("database.max_open_conns", 256)
viper.SetDefault("database.max_idle_conns", 10) viper.SetDefault("database.max_idle_conns", 128)
viper.SetDefault("database.conn_max_lifetime_minutes", 30) viper.SetDefault("database.conn_max_lifetime_minutes", 30)
viper.SetDefault("database.conn_max_idle_time_minutes", 5) viper.SetDefault("database.conn_max_idle_time_minutes", 5)
...@@ -789,8 +803,8 @@ func setDefaults() { ...@@ -789,8 +803,8 @@ func setDefaults() {
viper.SetDefault("redis.dial_timeout_seconds", 5) viper.SetDefault("redis.dial_timeout_seconds", 5)
viper.SetDefault("redis.read_timeout_seconds", 3) viper.SetDefault("redis.read_timeout_seconds", 3)
viper.SetDefault("redis.write_timeout_seconds", 3) viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 128) viper.SetDefault("redis.pool_size", 1024)
viper.SetDefault("redis.min_idle_conns", 10) viper.SetDefault("redis.min_idle_conns", 128)
viper.SetDefault("redis.enable_tls", false) viper.SetDefault("redis.enable_tls", false)
// Ops (vNext) // Ops (vNext)
...@@ -849,6 +863,11 @@ func setDefaults() { ...@@ -849,6 +863,11 @@ func setDefaults() {
viper.SetDefault("api_key_auth_cache.jitter_percent", 10) viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
viper.SetDefault("api_key_auth_cache.singleflight", true) viper.SetDefault("api_key_auth_cache.singleflight", true)
// Subscription auth L1 cache
viper.SetDefault("subscription_cache.l1_size", 16384)
viper.SetDefault("subscription_cache.l1_ttl_seconds", 10)
viper.SetDefault("subscription_cache.jitter_percent", 10)
// Dashboard cache // Dashboard cache
viper.SetDefault("dashboard_cache.enabled", true) viper.SetDefault("dashboard_cache.enabled", true)
viper.SetDefault("dashboard_cache.key_prefix", "sub2api:") viper.SetDefault("dashboard_cache.key_prefix", "sub2api:")
...@@ -882,13 +901,14 @@ func setDefaults() { ...@@ -882,13 +901,14 @@ func setDefaults() {
viper.SetDefault("gateway.failover_on_400", false) viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_account_switches", 10) viper.SetDefault("gateway.max_account_switches", 10)
viper.SetDefault("gateway.max_account_switches_gemini", 3) viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.force_codex_cli", false)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化) // HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认 viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认 viper.SetDefault("gateway.max_conns_per_host", 1024) // 每主机最大连接数(含活跃;流式/HTTP1.1 场景可调大,如 2400+
viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒) viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900) viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
...@@ -933,6 +953,22 @@ func setDefaults() { ...@@ -933,6 +953,22 @@ func setDefaults() {
} }
func (c *Config) Validate() error { func (c *Config) Validate() error {
if strings.TrimSpace(c.Server.FrontendURL) != "" {
if err := ValidateAbsoluteHTTPURL(c.Server.FrontendURL); err != nil {
return fmt.Errorf("server.frontend_url invalid: %w", err)
}
u, err := url.Parse(strings.TrimSpace(c.Server.FrontendURL))
if err != nil {
return fmt.Errorf("server.frontend_url invalid: %w", err)
}
if u.RawQuery != "" || u.ForceQuery {
return fmt.Errorf("server.frontend_url invalid: must not include query")
}
if u.User != nil {
return fmt.Errorf("server.frontend_url invalid: must not include userinfo")
}
warnIfInsecureURL("server.frontend_url", c.Server.FrontendURL)
}
if c.JWT.ExpireHour <= 0 { if c.JWT.ExpireHour <= 0 {
return fmt.Errorf("jwt.expire_hour must be positive") return fmt.Errorf("jwt.expire_hour must be positive")
} }
......
...@@ -87,8 +87,34 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { ...@@ -87,8 +87,34 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
if !cfg.Security.URLAllowlist.AllowPrivateHosts { if !cfg.Security.URLAllowlist.AllowPrivateHosts {
t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true") t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true")
} }
if cfg.Security.ResponseHeaders.Enabled { if !cfg.Security.ResponseHeaders.Enabled {
t.Fatalf("ResponseHeaders.Enabled = true, want false") t.Fatalf("ResponseHeaders.Enabled = false, want true")
}
}
func TestLoadDefaultServerMode(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Server.Mode != "release" {
t.Fatalf("Server.Mode = %q, want %q", cfg.Server.Mode, "release")
}
}
func TestLoadDefaultDatabaseSSLMode(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Database.SSLMode != "prefer" {
t.Fatalf("Database.SSLMode = %q, want %q", cfg.Database.SSLMode, "prefer")
} }
} }
...@@ -424,6 +450,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) { ...@@ -424,6 +450,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) {
} }
} }
func TestValidateServerFrontendURL(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Server.FrontendURL = "https://example.com"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() frontend_url valid error: %v", err)
}
cfg.Server.FrontendURL = "https://example.com/path"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() frontend_url with path valid error: %v", err)
}
cfg.Server.FrontendURL = "https://example.com?utm=1"
if err := cfg.Validate(); err == nil {
t.Fatalf("Validate() should reject server.frontend_url with query")
}
cfg.Server.FrontendURL = "https://user:pass@example.com"
if err := cfg.Validate(); err == nil {
t.Fatalf("Validate() should reject server.frontend_url with userinfo")
}
cfg.Server.FrontendURL = "/relative"
if err := cfg.Validate(); err == nil {
t.Fatalf("Validate() should reject relative server.frontend_url")
}
}
func TestValidateFrontendRedirectURL(t *testing.T) { func TestValidateFrontendRedirectURL(t *testing.T) {
if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil { if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil {
t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err) t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err)
......
...@@ -3,6 +3,7 @@ package admin ...@@ -3,6 +3,7 @@ package admin
import ( import (
"errors" "errors"
"fmt"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
...@@ -789,57 +790,40 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) { ...@@ -789,57 +790,40 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
} }
ctx := c.Request.Context() ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
// 阶段一:预验证所有账号存在,收集 credentials
type accountUpdate struct {
ID int64
Credentials map[string]any
}
updates := make([]accountUpdate, 0, len(req.AccountIDs))
for _, accountID := range req.AccountIDs { for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID) account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil { if err != nil {
failed++ response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID))
results = append(results, gin.H{ return
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
} }
// Update credentials field
if account.Credentials == nil { if account.Credentials == nil {
account.Credentials = make(map[string]any) account.Credentials = make(map[string]any)
} }
account.Credentials[req.Field] = req.Value account.Credentials[req.Field] = req.Value
updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials})
}
// Update account // 阶段二:依次更新,任何失败立即返回(避免部分成功部分失败)
for _, u := range updates {
updateInput := &service.UpdateAccountInput{ updateInput := &service.UpdateAccountInput{
Credentials: account.Credentials, Credentials: u.Credentials,
} }
if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil {
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput) response.Error(c, 500, fmt.Sprintf("Failed to update account %d: %v", u.ID, err))
if err != nil { return
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": err.Error(),
})
continue
} }
success++
results = append(results, gin.H{
"account_id": accountID,
"success": true,
})
} }
response.Success(c, gin.H{ response.Success(c, gin.H{
"success": success, "success": len(updates),
"failed": failed, "failed": 0,
"results": results,
}) })
} }
......
//go:build unit
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。
type failingAdminService struct {
*stubAdminService
failOnAccountID int64
updateCallCount atomic.Int64
}
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
f.updateCallCount.Add(1)
if id == f.failOnAccountID {
return nil, errors.New("database error")
}
return f.stubAdminService.UpdateAccount(ctx, id, input)
}
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
return router, handler
}
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test-uuid",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
}
func TestBatchUpdateCredentials_FailFast(t *testing.T) {
// 让第 2 个账号(ID=2)更新时失败
svc := &failingAdminService{
stubAdminService: newStubAdminService(),
failOnAccountID: 2,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "org_uuid",
Value: "test-org",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusInternalServerError, w.Code, "ID=2 失败时应返回 500")
// 验证 fail-fast:ID=1 更新成功,ID=2 失败,ID=3 不应被调用
require.Equal(t, int64(2), svc.updateCallCount.Load(),
"fail-fast: 应只调用 2 次 UpdateAccount(ID=1 成功、ID=2 失败后停止)")
}
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
svc := &getAccountFailingService{
stubAdminService: newStubAdminService(),
failOnAccountID: 1,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
}
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
type getAccountFailingService struct {
*stubAdminService
failOnAccountID int64
}
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
if id == f.failOnAccountID {
return nil, errors.New("not found")
}
return f.stubAdminService.GetAccount(ctx, id)
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// intercept_warmup_requests 传入非 bool 类型(string),应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": "not-a-bool",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"intercept_warmup_requests 传入非 bool 值应返回 400")
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": true,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"intercept_warmup_requests 传入合法 bool 值应返回 200")
}
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入非 string 类型(number),应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": 12345,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"account_uuid 传入非 string 值应返回 400")
}
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入 null(设置为空),应正常通过
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": nil,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"account_uuid 传入 null 应返回 200")
}
...@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { ...@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return return
} }
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs) stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get user usage stats") response.Error(c, 500, "Failed to get user usage stats")
return return
...@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) { ...@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
return return
} }
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs) stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get API key usage stats") response.Error(c, 500, "Failed to get API key usage stats")
return return
......
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
func truncateSearchByRune(search string, maxRunes int) string {
if runes := []rune(search); len(runes) > maxRunes {
return string(runes[:maxRunes])
}
return search
}
func TestTruncateSearchByRune(t *testing.T) {
tests := []struct {
name string
input string
maxRunes int
wantLen int // 期望的 rune 长度
}{
{
name: "纯中文超长",
input: string(make([]rune, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "纯 ASCII 超长",
input: string(make([]byte, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "空字符串",
input: "",
maxRunes: 100,
wantLen: 0,
},
{
name: "恰好 100 个字符",
input: string(make([]rune, 100)),
maxRunes: 100,
wantLen: 100,
},
{
name: "不足 100 字符不截断",
input: "hello世界",
maxRunes: 100,
wantLen: 7,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := truncateSearchByRune(tc.input, tc.maxRunes)
require.Equal(t, tc.wantLen, len([]rune(result)))
})
}
}
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
input := ""
for i := 0; i < 101; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
require.Equal(t, 100, len([]rune(result)))
// 验证截断结果是有效的 UTF-8(每个中文字符 3 字节)
require.Equal(t, 300, len(result))
}
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
// 50 个 ASCII + 51 个中文 = 101 个 rune
input := ""
for i := 0; i < 50; i++ {
input += "a"
}
for i := 0; i < 51; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
runes := []rune(result)
require.Equal(t, 100, len(runes))
// 前 50 个应该是 'a',后 50 个应该是 '中'
require.Equal(t, 'a', runes[0])
require.Equal(t, 'a', runes[49])
require.Equal(t, '中', runes[50])
require.Equal(t, '中', runes[99])
}
...@@ -70,8 +70,8 @@ func (h *UserHandler) List(c *gin.Context) { ...@@ -70,8 +70,8 @@ func (h *UserHandler) List(c *gin.Context) {
search := c.Query("search") search := c.Query("search")
// 标准化和验证 search 参数 // 标准化和验证 search 参数
search = strings.TrimSpace(search) search = strings.TrimSpace(search)
if len(search) > 100 { if runes := []rune(search); len(runes) > 100 {
search = search[:100] search = string(runes[:100])
} }
filters := service.UserListFilters{ filters := service.UserListFilters{
......
...@@ -2,6 +2,7 @@ package handler ...@@ -2,6 +2,7 @@ package handler
import ( import (
"log/slog" "log/slog"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
...@@ -448,17 +449,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) { ...@@ -448,17 +449,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
return return
} }
// Build frontend base URL from request frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
scheme := "https" if frontendBaseURL == "" {
if c.Request.TLS == nil { slog.Error("server.frontend_url not configured; cannot build password reset link")
// Check X-Forwarded-Proto header (common in reverse proxy setups) response.InternalError(c, "Password reset is not configured")
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" { return
scheme = proto
} else {
scheme = "http"
}
} }
frontendBaseURL := scheme + "://" + c.Request.Host
// Request password reset (async) // Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration) // Note: This returns success even if email doesn't exist (to prevent enumeration)
......
...@@ -236,7 +236,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -236,7 +236,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) log.Printf("[Gateway] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return return
} }
if lastFailoverErr != nil { if lastFailoverErr != nil {
...@@ -284,12 +285,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -284,12 +285,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err == nil && canWait { if err == nil && canWait {
accountWaitCounted = true accountWaitCounted = true
} }
// Ensure the wait counter is decremented if we exit before acquiring the slot. releaseWait := func() {
defer func() {
if accountWaitCounted { if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
} }
}()
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c, c,
...@@ -301,14 +302,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -301,14 +302,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
) )
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
// Slot acquired: no longer waiting in queue. // Slot acquired: no longer waiting in queue.
if accountWaitCounted { releaseWait()
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
...@@ -398,7 +397,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -398,7 +397,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) log.Printf("[Gateway] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return return
} }
if lastFailoverErr != nil { if lastFailoverErr != nil {
...@@ -446,11 +446,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -446,11 +446,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err == nil && canWait { if err == nil && canWait {
accountWaitCounted = true accountWaitCounted = true
} }
defer func() { releaseWait := func() {
if accountWaitCounted { if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
} }
}()
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c, c,
...@@ -462,13 +463,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -462,13 +463,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
) )
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
if accountWaitCounted { // Slot acquired: no longer waiting in queue.
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) releaseWait()
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
...@@ -967,7 +967,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -967,7 +967,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 选择支持该模型的账号 // 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model) account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil { if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) log.Printf("[Gateway] SelectAccountForModel failed: %v", err)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return return
} }
setOpsSelectedAccount(c, account.ID) setOpsSelectedAccount(c, account.ID)
...@@ -1238,7 +1239,8 @@ func billingErrorDetails(err error) (status int, code, message string) { ...@@ -1238,7 +1239,8 @@ func billingErrorDetails(err error) (status int, code, message string) {
} }
msg := pkgerrors.Message(err) msg := pkgerrors.Message(err)
if msg == "" { if msg == "" {
msg = err.Error() log.Printf("[Gateway] billing error details: %v", err)
msg = "Billing error"
} }
return http.StatusForbidden, "billing_error", msg return http.StatusForbidden, "billing_error", msg
} }
...@@ -28,6 +28,7 @@ type OpenAIGatewayHandler struct { ...@@ -28,6 +28,7 @@ type OpenAIGatewayHandler struct {
errorPassthroughService *service.ErrorPassthroughService errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
cfg *config.Config
} }
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
...@@ -54,6 +55,7 @@ func NewOpenAIGatewayHandler( ...@@ -54,6 +55,7 @@ func NewOpenAIGatewayHandler(
errorPassthroughService: errorPassthroughService, errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
cfg: cfg,
} }
} }
...@@ -109,7 +111,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -109,7 +111,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} }
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
if !openai.IsCodexCLIRequest(userAgent) { isCodexCLI := openai.IsCodexCLIRequest(userAgent) || (h.cfg != nil && h.cfg.Gateway.ForceCodexCLI)
if !isCodexCLI {
existingInstructions, _ := reqBody["instructions"].(string) existingInstructions, _ := reqBody["instructions"].(string)
if strings.TrimSpace(existingInstructions) == "" { if strings.TrimSpace(existingInstructions) == "" {
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" { if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
...@@ -218,7 +221,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -218,7 +221,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if err != nil { if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) log.Printf("[OpenAI Gateway] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return return
} }
if lastFailoverErr != nil { if lastFailoverErr != nil {
...@@ -251,11 +255,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -251,11 +255,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if err == nil && canWait { if err == nil && canWait {
accountWaitCounted = true accountWaitCounted = true
} }
defer func() { releaseWait := func() {
if accountWaitCounted { if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
} }
}()
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c, c,
...@@ -267,13 +272,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -267,13 +272,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
) )
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
if accountWaitCounted { // Slot acquired: no longer waiting in queue.
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) releaseWait()
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
......
...@@ -392,7 +392,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) { ...@@ -392,7 +392,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
return return
} }
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs) stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{})
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
package antigravity package antigravity
import ( import (
"crypto/rand"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
...@@ -341,12 +342,16 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string { ...@@ -341,12 +342,16 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string {
return builder.String() return builder.String()
} }
// generateRandomID 生成随机 ID // generateRandomID 生成密码学安全的随机 ID
func generateRandomID() string { func generateRandomID() string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, 12) result := make([]byte, 12)
for i := range result { randBytes := make([]byte, 12)
result[i] = chars[i%len(chars)] if _, err := rand.Read(randBytes); err != nil {
panic("crypto/rand unavailable: " + err.Error())
}
for i, b := range randBytes {
result[i] = chars[int(b)%len(chars)]
} }
return string(result) return string(result)
} }
//go:build unit
package antigravity
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGenerateRandomID_Uniqueness(t *testing.T) {
seen := make(map[string]struct{}, 100)
for i := 0; i < 100; i++ {
id := generateRandomID()
require.Len(t, id, 12, "ID 长度应为 12")
_, dup := seen[id]
require.False(t, dup, "第 %d 次调用生成了重复 ID: %s", i, id)
seen[id] = struct{}{}
}
}
func TestGenerateRandomID_Charset(t *testing.T) {
const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
validSet := make(map[byte]struct{}, len(validChars))
for i := 0; i < len(validChars); i++ {
validSet[validChars[i]] = struct{}{}
}
for i := 0; i < 50; i++ {
id := generateRandomID()
for j := 0; j < len(id); j++ {
_, ok := validSet[id[j]]
require.True(t, ok, "ID 包含非法字符: %c (ID=%s)", id[j], id)
}
}
}
...@@ -54,29 +54,34 @@ func normalizeIP(ip string) string { ...@@ -54,29 +54,34 @@ func normalizeIP(ip string) string {
return ip return ip
} }
// isPrivateIP 检查 IP 是否为私有地址。 // privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
func isPrivateIP(ipStr string) bool { var privateNets []*net.IPNet
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
// 私有 IP 范围 func init() {
privateBlocks := []string{ for _, cidr := range []string{
"10.0.0.0/8", "10.0.0.0/8",
"172.16.0.0/12", "172.16.0.0/12",
"192.168.0.0/16", "192.168.0.0/16",
"127.0.0.0/8", "127.0.0.0/8",
"::1/128", "::1/128",
"fc00::/7", "fc00::/7",
} {
_, block, err := net.ParseCIDR(cidr)
if err != nil {
panic("invalid CIDR: " + cidr)
}
privateNets = append(privateNets, block)
} }
}
for _, block := range privateBlocks { // isPrivateIP 检查 IP 是否为私有地址。
_, cidr, err := net.ParseCIDR(block) func isPrivateIP(ipStr string) bool {
if err != nil { ip := net.ParseIP(ipStr)
continue if ip == nil {
return false
} }
if cidr.Contains(ip) { for _, block := range privateNets {
if block.Contains(ip) {
return true return true
} }
} }
......
//go:build unit
package ip
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIsPrivateIP(t *testing.T) {
tests := []struct {
name string
ip string
expected bool
}{
// 私有 IPv4
{"10.x 私有地址", "10.0.0.1", true},
{"10.x 私有地址段末", "10.255.255.255", true},
{"172.16.x 私有地址", "172.16.0.1", true},
{"172.31.x 私有地址", "172.31.255.255", true},
{"192.168.x 私有地址", "192.168.1.1", true},
{"127.0.0.1 本地回环", "127.0.0.1", true},
{"127.x 回环段", "127.255.255.255", true},
// 公网 IPv4
{"8.8.8.8 公网 DNS", "8.8.8.8", false},
{"1.1.1.1 公网", "1.1.1.1", false},
{"172.15.255.255 非私有", "172.15.255.255", false},
{"172.32.0.0 非私有", "172.32.0.0", false},
{"11.0.0.1 公网", "11.0.0.1", false},
// IPv6
{"::1 IPv6 回环", "::1", true},
{"fc00:: IPv6 私有", "fc00::1", true},
{"fd00:: IPv6 私有", "fd00::1", true},
{"2001:db8::1 IPv6 公网", "2001:db8::1", false},
// 无效输入
{"空字符串", "", false},
{"非法字符串", "not-an-ip", false},
{"不完整 IP", "192.168", false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := isPrivateIP(tc.ip)
require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip)
})
}
}
...@@ -50,6 +50,7 @@ type OAuthSession struct { ...@@ -50,6 +50,7 @@ type OAuthSession struct {
type SessionStore struct { type SessionStore struct {
mu sync.RWMutex mu sync.RWMutex
sessions map[string]*OAuthSession sessions map[string]*OAuthSession
stopOnce sync.Once
stopCh chan struct{} stopCh chan struct{}
} }
...@@ -65,7 +66,9 @@ func NewSessionStore() *SessionStore { ...@@ -65,7 +66,9 @@ func NewSessionStore() *SessionStore {
// Stop stops the cleanup goroutine // Stop stops the cleanup goroutine
func (s *SessionStore) Stop() { func (s *SessionStore) Stop() {
s.stopOnce.Do(func() {
close(s.stopCh) close(s.stopCh)
})
} }
// Set stores a session // Set stores a session
......
package oauth
import (
"sync"
"testing"
"time"
)
func TestSessionStore_Stop_Idempotent(t *testing.T) {
store := NewSessionStore()
store.Stop()
store.Stop()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}
func TestSessionStore_Stop_Concurrent(t *testing.T) {
store := NewSessionStore()
var wg sync.WaitGroup
for range 50 {
wg.Add(1)
go func() {
defer wg.Done()
store.Stop()
}()
}
wg.Wait()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}
...@@ -47,6 +47,7 @@ type OAuthSession struct { ...@@ -47,6 +47,7 @@ type OAuthSession struct {
type SessionStore struct { type SessionStore struct {
mu sync.RWMutex mu sync.RWMutex
sessions map[string]*OAuthSession sessions map[string]*OAuthSession
stopOnce sync.Once
stopCh chan struct{} stopCh chan struct{}
} }
...@@ -92,7 +93,9 @@ func (s *SessionStore) Delete(sessionID string) { ...@@ -92,7 +93,9 @@ func (s *SessionStore) Delete(sessionID string) {
// Stop stops the cleanup goroutine // Stop stops the cleanup goroutine
func (s *SessionStore) Stop() { func (s *SessionStore) Stop() {
s.stopOnce.Do(func() {
close(s.stopCh) close(s.stopCh)
})
} }
// cleanup removes expired sessions periodically // cleanup removes expired sessions periodically
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment