Unverified Commit 7abec188 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #712 from DaydreamCoding/feat/proxy-failfast-proxyurl

feat(proxy): 集中代理 URL 验证并实现全局 fail-fast
parents 445bfdf2 fdcbf7aa
......@@ -66,7 +66,6 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
ProxyURL: proxyURL,
Timeout: defaultProxyProbeTimeout,
InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: s.validateResolvedIP,
AllowPrivateHosts: s.allowPrivateHosts,
})
......
......@@ -6,6 +6,8 @@ import (
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/imroc/req/v3"
)
......@@ -33,11 +35,11 @@ var sharedReqClients sync.Map
// getSharedReqClient 获取共享的 req 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建
func getSharedReqClient(opts reqClientOptions) *req.Client {
func getSharedReqClient(opts reqClientOptions) (*req.Client, error) {
key := buildReqClientKey(opts)
if cached, ok := sharedReqClients.Load(key); ok {
if c, ok := cached.(*req.Client); ok {
return c
return c, nil
}
}
......@@ -48,15 +50,19 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
if opts.Impersonate {
client = client.ImpersonateChrome()
}
if strings.TrimSpace(opts.ProxyURL) != "" {
client.SetProxyURL(strings.TrimSpace(opts.ProxyURL))
trimmed, _, err := proxyurl.Parse(opts.ProxyURL)
if err != nil {
return nil, err
}
if trimmed != "" {
client.SetProxyURL(trimmed)
}
actual, _ := sharedReqClients.LoadOrStore(key, client)
if c, ok := actual.(*req.Client); ok {
return c
return c, nil
}
return client
return client, nil
}
func buildReqClientKey(opts reqClientOptions) string {
......
......@@ -26,11 +26,13 @@ func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
ProxyURL: "http://proxy.local:8080",
Timeout: time.Second,
}
clientDefault := getSharedReqClient(base)
clientDefault, err := getSharedReqClient(base)
require.NoError(t, err)
force := base
force.ForceHTTP2 = true
clientForce := getSharedReqClient(force)
clientForce, err := getSharedReqClient(force)
require.NoError(t, err)
require.NotSame(t, clientDefault, clientForce)
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
......@@ -42,8 +44,10 @@ func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
ProxyURL: "http://proxy.local:8080",
Timeout: 2 * time.Second,
}
first := getSharedReqClient(opts)
second := getSharedReqClient(opts)
first, err := getSharedReqClient(opts)
require.NoError(t, err)
second, err := getSharedReqClient(opts)
require.NoError(t, err)
require.Same(t, first, second)
}
......@@ -56,7 +60,8 @@ func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
key := buildReqClientKey(opts)
sharedReqClients.Store(key, "invalid")
client := getSharedReqClient(opts)
client, err := getSharedReqClient(opts)
require.NoError(t, err)
require.NotNil(t, client)
loaded, ok := sharedReqClients.Load(key)
......@@ -71,20 +76,45 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
Timeout: 4 * time.Second,
Impersonate: true,
}
client := getSharedReqClient(opts)
client, err := getSharedReqClient(opts)
require.NoError(t, err)
require.NotNil(t, client)
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
}
func TestGetSharedReqClient_InvalidProxyURL(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: "://missing-scheme",
Timeout: time.Second,
}
_, err := getSharedReqClient(opts)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid proxy URL")
}
func TestGetSharedReqClient_ProxyURLMissingHost(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: "http://",
Timeout: time.Second,
}
_, err := getSharedReqClient(opts)
require.Error(t, err)
require.Contains(t, err.Error(), "proxy URL missing host")
}
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("http://proxy.local:8080")
client, err := createOpenAIReqClient("http://proxy.local:8080")
require.NoError(t, err)
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
}
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createGeminiReqClient("http://proxy.local:8080")
client, err := createGeminiReqClient("http://proxy.local:8080")
require.NoError(t, err)
require.Equal(t, "", forceHTTPVersion(t, client))
}
......@@ -34,7 +34,7 @@ func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient
// ProvidePricingRemoteClient 创建定价数据远程客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据
func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
return NewPricingRemoteClient(cfg.Update.ProxyURL)
return NewPricingRemoteClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError)
}
// ProvideSessionLimitCache 创建会话限制缓存
......
......@@ -2028,7 +2028,6 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr
ProxyURL: proxyURL,
Timeout: proxyQualityRequestTimeout,
ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout,
ProxyStrict: true,
})
if err != nil {
result.Items = append(result.Items, ProxyQualityCheckItem{
......
......@@ -112,7 +112,10 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
}
}
client := antigravity.NewClient(proxyURL)
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
// 交换 token
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
......@@ -167,7 +170,10 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
time.Sleep(backoff)
}
client := antigravity.NewClient(proxyURL)
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
tokenResp, err := client.RefreshToken(ctx, refreshToken)
if err == nil {
now := time.Now()
......@@ -209,7 +215,10 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr
}
// 获取用户信息(email)
client := antigravity.NewClient(proxyURL)
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken)
if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
......@@ -309,7 +318,10 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
time.Sleep(backoff)
}
client := antigravity.NewClient(proxyURL)
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return "", fmt.Errorf("create antigravity client failed: %w", err)
}
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
......
......@@ -2,6 +2,7 @@ package service
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
......@@ -31,7 +32,10 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id")
client := antigravity.NewClient(proxyURL)
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
// 调用 API 获取配额
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
......
......@@ -221,7 +221,7 @@ func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username,
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 20 * time.Second}
return nil, fmt.Errorf("create http client failed: %w", err)
}
adminToken, err := crsLogin(ctx, client, normalizedURL, username, password)
......
......@@ -1045,7 +1045,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
ValidateResolvedIP: true,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
return "", fmt.Errorf("create http client failed: %w", err)
}
resp, err := client.Do(req)
......
......@@ -7,7 +7,6 @@ import (
"io"
"log/slog"
"net/http"
"net/url"
"regexp"
"sort"
"strconv"
......@@ -15,6 +14,7 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
......@@ -273,7 +273,13 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi
req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
client := newOpenAIOAuthHTTPClient(proxyURL)
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
})
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
......@@ -530,19 +536,6 @@ func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64
return proxy.URL(), nil
}
func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
transport := &http.Transport{}
if strings.TrimSpace(proxyURL) != "" {
if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" {
transport.Proxy = http.ProxyURL(parsed)
}
}
return &http.Client{
Timeout: 120 * time.Second,
Transport: transport,
}
}
func normalizeOpenAIOAuthPlatform(platform string) string {
switch strings.ToLower(strings.TrimSpace(platform)) {
case PlatformSora:
......
......@@ -134,6 +134,12 @@ security:
# Allow skipping TLS verification for proxy probe (debug only)
# 允许代理探测时跳过 TLS 证书验证(仅用于调试)
insecure_skip_verify: false
proxy_fallback:
# Allow auxiliary services (update check, pricing data) to fallback to direct
# connection when proxy initialization fails. Does NOT affect AI gateway connections.
# 辅助服务(更新检查、定价数据拉取)代理初始化失败时是否允许回退直连。
# 不影响 AI 账号网关连接。默认 false:fail-fast 防止 IP 泄露。
allow_direct_on_error: false
# =============================================================================
# Gateway Configuration
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment