Unverified Commit 7abec188 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #712 from DaydreamCoding/feat/proxy-failfast-proxyurl

feat(proxy): 集中代理 URL 验证并实现全局 fail-fast
parents 445bfdf2 fdcbf7aa
...@@ -66,7 +66,6 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s ...@@ -66,7 +66,6 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
ProxyURL: proxyURL, ProxyURL: proxyURL,
Timeout: defaultProxyProbeTimeout, Timeout: defaultProxyProbeTimeout,
InsecureSkipVerify: s.insecureSkipVerify, InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: s.validateResolvedIP, ValidateResolvedIP: s.validateResolvedIP,
AllowPrivateHosts: s.allowPrivateHosts, AllowPrivateHosts: s.allowPrivateHosts,
}) })
......
...@@ -6,6 +6,8 @@ import ( ...@@ -6,6 +6,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
) )
...@@ -33,11 +35,11 @@ var sharedReqClients sync.Map ...@@ -33,11 +35,11 @@ var sharedReqClients sync.Map
// getSharedReqClient 获取共享的 req 客户端实例 // getSharedReqClient 获取共享的 req 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建 // 性能优化:相同配置复用同一客户端,避免重复创建
func getSharedReqClient(opts reqClientOptions) *req.Client { func getSharedReqClient(opts reqClientOptions) (*req.Client, error) {
key := buildReqClientKey(opts) key := buildReqClientKey(opts)
if cached, ok := sharedReqClients.Load(key); ok { if cached, ok := sharedReqClients.Load(key); ok {
if c, ok := cached.(*req.Client); ok { if c, ok := cached.(*req.Client); ok {
return c return c, nil
} }
} }
...@@ -48,15 +50,19 @@ func getSharedReqClient(opts reqClientOptions) *req.Client { ...@@ -48,15 +50,19 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
if opts.Impersonate { if opts.Impersonate {
client = client.ImpersonateChrome() client = client.ImpersonateChrome()
} }
if strings.TrimSpace(opts.ProxyURL) != "" { trimmed, _, err := proxyurl.Parse(opts.ProxyURL)
client.SetProxyURL(strings.TrimSpace(opts.ProxyURL)) if err != nil {
return nil, err
}
if trimmed != "" {
client.SetProxyURL(trimmed)
} }
actual, _ := sharedReqClients.LoadOrStore(key, client) actual, _ := sharedReqClients.LoadOrStore(key, client)
if c, ok := actual.(*req.Client); ok { if c, ok := actual.(*req.Client); ok {
return c return c, nil
} }
return client return client, nil
} }
func buildReqClientKey(opts reqClientOptions) string { func buildReqClientKey(opts reqClientOptions) string {
......
...@@ -26,11 +26,13 @@ func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) { ...@@ -26,11 +26,13 @@ func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
ProxyURL: "http://proxy.local:8080", ProxyURL: "http://proxy.local:8080",
Timeout: time.Second, Timeout: time.Second,
} }
clientDefault := getSharedReqClient(base) clientDefault, err := getSharedReqClient(base)
require.NoError(t, err)
force := base force := base
force.ForceHTTP2 = true force.ForceHTTP2 = true
clientForce := getSharedReqClient(force) clientForce, err := getSharedReqClient(force)
require.NoError(t, err)
require.NotSame(t, clientDefault, clientForce) require.NotSame(t, clientDefault, clientForce)
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force)) require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
...@@ -42,8 +44,10 @@ func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) { ...@@ -42,8 +44,10 @@ func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
ProxyURL: "http://proxy.local:8080", ProxyURL: "http://proxy.local:8080",
Timeout: 2 * time.Second, Timeout: 2 * time.Second,
} }
first := getSharedReqClient(opts) first, err := getSharedReqClient(opts)
second := getSharedReqClient(opts) require.NoError(t, err)
second, err := getSharedReqClient(opts)
require.NoError(t, err)
require.Same(t, first, second) require.Same(t, first, second)
} }
...@@ -56,7 +60,8 @@ func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) { ...@@ -56,7 +60,8 @@ func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
key := buildReqClientKey(opts) key := buildReqClientKey(opts)
sharedReqClients.Store(key, "invalid") sharedReqClients.Store(key, "invalid")
client := getSharedReqClient(opts) client, err := getSharedReqClient(opts)
require.NoError(t, err)
require.NotNil(t, client) require.NotNil(t, client)
loaded, ok := sharedReqClients.Load(key) loaded, ok := sharedReqClients.Load(key)
...@@ -71,20 +76,45 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) { ...@@ -71,20 +76,45 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
Timeout: 4 * time.Second, Timeout: 4 * time.Second,
Impersonate: true, Impersonate: true,
} }
client := getSharedReqClient(opts) client, err := getSharedReqClient(opts)
require.NoError(t, err)
require.NotNil(t, client) require.NotNil(t, client)
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts)) require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
} }
func TestGetSharedReqClient_InvalidProxyURL(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: "://missing-scheme",
Timeout: time.Second,
}
_, err := getSharedReqClient(opts)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid proxy URL")
}
func TestGetSharedReqClient_ProxyURLMissingHost(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: "http://",
Timeout: time.Second,
}
_, err := getSharedReqClient(opts)
require.Error(t, err)
require.Contains(t, err.Error(), "proxy URL missing host")
}
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) { func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
sharedReqClients = sync.Map{} sharedReqClients = sync.Map{}
client := createOpenAIReqClient("http://proxy.local:8080") client, err := createOpenAIReqClient("http://proxy.local:8080")
require.NoError(t, err)
require.Equal(t, 120*time.Second, client.GetClient().Timeout) require.Equal(t, 120*time.Second, client.GetClient().Timeout)
} }
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) { func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
sharedReqClients = sync.Map{} sharedReqClients = sync.Map{}
client := createGeminiReqClient("http://proxy.local:8080") client, err := createGeminiReqClient("http://proxy.local:8080")
require.NoError(t, err)
require.Equal(t, "", forceHTTPVersion(t, client)) require.Equal(t, "", forceHTTPVersion(t, client))
} }
...@@ -34,7 +34,7 @@ func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient ...@@ -34,7 +34,7 @@ func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient
// ProvidePricingRemoteClient 创建定价数据远程客户端 // ProvidePricingRemoteClient 创建定价数据远程客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据 // 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据
func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient { func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
return NewPricingRemoteClient(cfg.Update.ProxyURL) return NewPricingRemoteClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError)
} }
// ProvideSessionLimitCache 创建会话限制缓存 // ProvideSessionLimitCache 创建会话限制缓存
......
...@@ -2028,7 +2028,6 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr ...@@ -2028,7 +2028,6 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr
ProxyURL: proxyURL, ProxyURL: proxyURL,
Timeout: proxyQualityRequestTimeout, Timeout: proxyQualityRequestTimeout,
ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout, ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout,
ProxyStrict: true,
}) })
if err != nil { if err != nil {
result.Items = append(result.Items, ProxyQualityCheckItem{ result.Items = append(result.Items, ProxyQualityCheckItem{
......
...@@ -112,7 +112,10 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig ...@@ -112,7 +112,10 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
} }
} }
client := antigravity.NewClient(proxyURL) client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
// 交换 token // 交换 token
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier) tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
...@@ -167,7 +170,10 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken ...@@ -167,7 +170,10 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
time.Sleep(backoff) time.Sleep(backoff)
} }
client := antigravity.NewClient(proxyURL) client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
tokenResp, err := client.RefreshToken(ctx, refreshToken) tokenResp, err := client.RefreshToken(ctx, refreshToken)
if err == nil { if err == nil {
now := time.Now() now := time.Now()
...@@ -209,7 +215,10 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr ...@@ -209,7 +215,10 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr
} }
// 获取用户信息(email) // 获取用户信息(email)
client := antigravity.NewClient(proxyURL) client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken) userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken)
if err != nil { if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err) fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
...@@ -309,7 +318,10 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac ...@@ -309,7 +318,10 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
time.Sleep(backoff) time.Sleep(backoff)
} }
client := antigravity.NewClient(proxyURL) client, err := antigravity.NewClient(proxyURL)
if err != nil {
return "", fmt.Errorf("create antigravity client failed: %w", err)
}
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken) loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" { if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
......
...@@ -2,6 +2,7 @@ package service ...@@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"fmt"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
...@@ -31,7 +32,10 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou ...@@ -31,7 +32,10 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id") projectID := account.GetCredential("project_id")
client := antigravity.NewClient(proxyURL) client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
// 调用 API 获取配额 // 调用 API 获取配额
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID) modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
......
...@@ -221,7 +221,7 @@ func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username, ...@@ -221,7 +221,7 @@ func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username,
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts, AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
}) })
if err != nil { if err != nil {
client = &http.Client{Timeout: 20 * time.Second} return nil, fmt.Errorf("create http client failed: %w", err)
} }
adminToken, err := crsLogin(ctx, client, normalizedURL, username, password) adminToken, err := crsLogin(ctx, client, normalizedURL, username, password)
......
...@@ -1045,7 +1045,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR ...@@ -1045,7 +1045,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
ValidateResolvedIP: true, ValidateResolvedIP: true,
}) })
if err != nil { if err != nil {
client = &http.Client{Timeout: 30 * time.Second} return "", fmt.Errorf("create http client failed: %w", err)
} }
resp, err := client.Do(req) resp, err := client.Do(req)
......
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"io" "io"
"log/slog" "log/slog"
"net/http" "net/http"
"net/url"
"regexp" "regexp"
"sort" "sort"
"strconv" "strconv"
...@@ -15,6 +14,7 @@ import ( ...@@ -15,6 +14,7 @@ import (
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
) )
...@@ -273,7 +273,13 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi ...@@ -273,7 +273,13 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi
req.Header.Set("Referer", "https://sora.chatgpt.com/") req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
client := newOpenAIOAuthHTTPClient(proxyURL) client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
})
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err)
}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err) return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
...@@ -530,19 +536,6 @@ func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64 ...@@ -530,19 +536,6 @@ func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64
return proxy.URL(), nil return proxy.URL(), nil
} }
func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
transport := &http.Transport{}
if strings.TrimSpace(proxyURL) != "" {
if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" {
transport.Proxy = http.ProxyURL(parsed)
}
}
return &http.Client{
Timeout: 120 * time.Second,
Transport: transport,
}
}
func normalizeOpenAIOAuthPlatform(platform string) string { func normalizeOpenAIOAuthPlatform(platform string) string {
switch strings.ToLower(strings.TrimSpace(platform)) { switch strings.ToLower(strings.TrimSpace(platform)) {
case PlatformSora: case PlatformSora:
......
...@@ -134,6 +134,12 @@ security: ...@@ -134,6 +134,12 @@ security:
# Allow skipping TLS verification for proxy probe (debug only) # Allow skipping TLS verification for proxy probe (debug only)
# 允许代理探测时跳过 TLS 证书验证(仅用于调试) # 允许代理探测时跳过 TLS 证书验证(仅用于调试)
insecure_skip_verify: false insecure_skip_verify: false
proxy_fallback:
# Allow auxiliary services (update check, pricing data) to fallback to direct
# connection when proxy initialization fails. Does NOT affect AI gateway connections.
# 辅助服务(更新检查、定价数据拉取)代理初始化失败时是否允许回退直连。
# 不影响 AI 账号网关连接。默认 false:fail-fast 防止 IP 泄露。
allow_direct_on_error: false
# ============================================================================= # =============================================================================
# Gateway Configuration # Gateway Configuration
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment