Commit 46dda583 authored by shaw's avatar shaw
Browse files

Merge PR #146: feat: 提升流式网关稳定性与安全策略强化

parents 27ed042c f8e7255c
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
...@@ -19,7 +20,13 @@ type PricingServiceSuite struct { ...@@ -19,7 +20,13 @@ type PricingServiceSuite struct {
func (s *PricingServiceSuite) SetupTest() { func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
client, ok := NewPricingRemoteClient().(*pricingRemoteClient) client, ok := NewPricingRemoteClient(&config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}).(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed") require.True(s.T(), ok, "type assertion failed")
s.client = client s.client = client
} }
......
...@@ -5,28 +5,48 @@ import ( ...@@ -5,28 +5,48 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
func NewProxyExitInfoProber() service.ProxyExitInfoProber { func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
return &proxyProbeService{ipInfoURL: defaultIPInfoURL} insecure := false
allowPrivate := false
if cfg != nil {
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
}
if insecure {
log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
}
return &proxyProbeService{
ipInfoURL: defaultIPInfoURL,
insecureSkipVerify: insecure,
allowPrivateHosts: allowPrivate,
}
} }
const defaultIPInfoURL = "https://ipinfo.io/json" const defaultIPInfoURL = "https://ipinfo.io/json"
type proxyProbeService struct { type proxyProbeService struct {
ipInfoURL string ipInfoURL string
insecureSkipVerify bool
allowPrivateHosts bool
} }
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
client, err := httpclient.GetClient(httpclient.Options{ client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL, ProxyURL: proxyURL,
Timeout: 15 * time.Second, Timeout: 15 * time.Second,
InsecureSkipVerify: true, InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
}) })
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err) return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
......
...@@ -20,7 +20,10 @@ type ProxyProbeServiceSuite struct { ...@@ -20,7 +20,10 @@ type ProxyProbeServiceSuite struct {
func (s *ProxyProbeServiceSuite) SetupTest() { func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"} s.prober = &proxyProbeService{
ipInfoURL: "http://ipinfo.test/json",
allowPrivateHosts: true,
}
} }
func (s *ProxyProbeServiceSuite) TearDownTest() { func (s *ProxyProbeServiceSuite) TearDownTest() {
......
...@@ -23,6 +23,7 @@ type turnstileVerifier struct { ...@@ -23,6 +23,7 @@ type turnstileVerifier struct {
func NewTurnstileVerifier() service.TurnstileVerifier { func NewTurnstileVerifier() service.TurnstileVerifier {
sharedClient, err := httpclient.GetClient(httpclient.Options{ sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
ValidateResolvedIP: true,
}) })
if err != nil { if err != nil {
sharedClient = &http.Client{Timeout: 10 * time.Second} sharedClient = &http.Client{Timeout: 10 * time.Second}
......
...@@ -294,13 +294,13 @@ func TestAPIContracts(t *testing.T) { ...@@ -294,13 +294,13 @@ func TestAPIContracts(t *testing.T) {
"smtp_host": "smtp.example.com", "smtp_host": "smtp.example.com",
"smtp_port": 587, "smtp_port": 587,
"smtp_username": "user", "smtp_username": "user",
"smtp_password": "secret", "smtp_password_configured": true,
"smtp_from_email": "no-reply@example.com", "smtp_from_email": "no-reply@example.com",
"smtp_from_name": "Sub2API", "smtp_from_name": "Sub2API",
"smtp_use_tls": true, "smtp_use_tls": true,
"turnstile_enabled": true, "turnstile_enabled": true,
"turnstile_site_key": "site-key", "turnstile_site_key": "site-key",
"turnstile_secret_key": "secret-key", "turnstile_secret_key_configured": true,
"site_name": "Sub2API", "site_name": "Sub2API",
"site_logo": "", "site_logo": "",
"site_subtitle": "Subtitle", "site_subtitle": "Subtitle",
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
package server package server
import ( import (
"log"
"net/http" "net/http"
"time" "time"
...@@ -36,6 +37,15 @@ func ProvideRouter( ...@@ -36,6 +37,15 @@ func ProvideRouter(
r := gin.New() r := gin.New()
r.Use(middleware2.Recovery()) r.Use(middleware2.Recovery())
if len(cfg.Server.TrustedProxies) > 0 {
if err := r.SetTrustedProxies(cfg.Server.TrustedProxies); err != nil {
log.Printf("Failed to set trusted proxies: %v", err)
}
} else {
if err := r.SetTrustedProxies(nil); err != nil {
log.Printf("Failed to disable trusted proxies: %v", err)
}
}
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg) return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
} }
......
...@@ -19,6 +19,13 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS ...@@ -19,6 +19,13 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证) // apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
queryKey := strings.TrimSpace(c.Query("key"))
queryApiKey := strings.TrimSpace(c.Query("api_key"))
if queryKey != "" || queryApiKey != "" {
AbortWithError(c, 400, "api_key_in_query_deprecated", "API key in query parameter is deprecated. Please use Authorization header instead.")
return
}
// 尝试从Authorization header中提取API key (Bearer scheme) // 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
var apiKeyString string var apiKeyString string
...@@ -41,19 +48,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -41,19 +48,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
apiKeyString = c.GetHeader("x-goog-api-key") apiKeyString = c.GetHeader("x-goog-api-key")
} }
// 如果header中没有,尝试从query参数中提取(Google API key风格)
if apiKeyString == "" {
apiKeyString = c.Query("key")
}
// 兼容常见别名
if apiKeyString == "" {
apiKeyString = c.Query("api_key")
}
// 如果所有header都没有API key // 如果所有header都没有API key
if apiKeyString == "" { if apiKeyString == "" {
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter") AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header")
return return
} }
......
...@@ -22,6 +22,10 @@ func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) ...@@ -22,6 +22,10 @@ func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config)
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations. // It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.")
return
}
apiKeyString := extractAPIKeyFromRequest(c) apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" { if apiKeyString == "" {
abortWithGoogleError(c, 401, "API key is required") abortWithGoogleError(c, 401, "API key is required")
...@@ -116,15 +120,18 @@ func extractAPIKeyFromRequest(c *gin.Context) string { ...@@ -116,15 +120,18 @@ func extractAPIKeyFromRequest(c *gin.Context) string {
if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" { if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" {
return v return v
} }
if allowGoogleQueryKey(c.Request.URL.Path) {
if v := strings.TrimSpace(c.Query("key")); v != "" { if v := strings.TrimSpace(c.Query("key")); v != "" {
return v return v
} }
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
return v
} }
return "" return ""
} }
func allowGoogleQueryKey(path string) bool {
return strings.HasPrefix(path, "/v1beta") || strings.HasPrefix(path, "/antigravity/v1beta")
}
func abortWithGoogleError(c *gin.Context, status int, message string) { func abortWithGoogleError(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{ c.JSON(status, gin.H{
"error": gin.H{ "error": gin.H{
......
...@@ -109,6 +109,58 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) { ...@@ -109,6 +109,58 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
} }
func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("should not be called")
},
})
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?api_key=legacy", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusBadRequest, resp.Error.Code)
require.Equal(t, "Query parameter api_key is deprecated. Use Authorization header or key instead.", resp.Error.Message)
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.APIKey{
ID: 1,
Key: key,
Status: service.StatusActive,
User: &service.User{
ID: 123,
Status: service.StatusActive,
},
}, nil
},
})
cfg := &config.Config{RunMode: config.RunModeSimple}
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?key=valid", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) { func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
......
package middleware package middleware
import ( import (
"log"
"net/http"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
var corsWarningOnce sync.Once
// CORS 跨域中间件 // CORS 跨域中间件
func CORS() gin.HandlerFunc { func CORS(cfg config.CORSConfig) gin.HandlerFunc {
allowedOrigins := normalizeOrigins(cfg.AllowedOrigins)
allowAll := false
for _, origin := range allowedOrigins {
if origin == "*" {
allowAll = true
break
}
}
wildcardWithSpecific := allowAll && len(allowedOrigins) > 1
if wildcardWithSpecific {
allowedOrigins = []string{"*"}
}
allowCredentials := cfg.AllowCredentials
corsWarningOnce.Do(func() {
if len(allowedOrigins) == 0 {
log.Println("Warning: CORS allowed_origins not configured; cross-origin requests will be rejected.")
}
if wildcardWithSpecific {
log.Println("Warning: CORS allowed_origins includes '*'; wildcard will take precedence over explicit origins.")
}
if allowAll && allowCredentials {
log.Println("Warning: CORS allowed_origins set to '*', disabling allow_credentials.")
}
})
if allowAll && allowCredentials {
allowCredentials = false
}
allowedSet := make(map[string]struct{}, len(allowedOrigins))
for _, origin := range allowedOrigins {
if origin == "" || origin == "*" {
continue
}
allowedSet[origin] = struct{}{}
}
return func(c *gin.Context) { return func(c *gin.Context) {
// 设置允许跨域的响应头 origin := strings.TrimSpace(c.GetHeader("Origin"))
originAllowed := allowAll
if origin != "" && !allowAll {
_, originAllowed = allowedSet[origin]
}
if originAllowed {
if allowAll {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*") c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
} else if origin != "" {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
c.Writer.Header().Add("Vary", "Origin")
}
if allowCredentials {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key") c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 处理预检请求 // 处理预检请求
if c.Request.Method == "OPTIONS" { if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(204) if originAllowed {
c.AbortWithStatus(http.StatusNoContent)
} else {
c.AbortWithStatus(http.StatusForbidden)
}
return return
} }
c.Next() c.Next()
} }
} }
func normalizeOrigins(values []string) []string {
if len(values) == 0 {
return nil
}
normalized := make([]string, 0, len(values))
for _, value := range values {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
continue
}
normalized = append(normalized, trimmed)
}
return normalized
}
package middleware
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
// SecurityHeaders sets baseline security headers for all responses.
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy := strings.TrimSpace(cfg.Policy)
if policy == "" {
policy = config.DefaultCSPPolicy
}
return func(c *gin.Context) {
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
if cfg.Enabled {
c.Header("Content-Security-Policy", policy)
}
c.Next()
}
}
...@@ -24,7 +24,8 @@ func SetupRouter( ...@@ -24,7 +24,8 @@ func SetupRouter(
) *gin.Engine { ) *gin.Engine {
// 应用中间件 // 应用中间件
r.Use(middleware2.Logger()) r.Use(middleware2.Logger())
r.Use(middleware2.CORS()) r.Use(middleware2.CORS(cfg.CORS))
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
// Serve embedded frontend if available // Serve embedded frontend if available
if web.HasEmbeddedFrontend() { if web.HasEmbeddedFrontend() {
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
...@@ -14,9 +15,11 @@ import ( ...@@ -14,9 +15,11 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
) )
...@@ -45,6 +48,7 @@ type AccountTestService struct { ...@@ -45,6 +48,7 @@ type AccountTestService struct {
geminiTokenProvider *GeminiTokenProvider geminiTokenProvider *GeminiTokenProvider
antigravityGatewayService *AntigravityGatewayService antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
cfg *config.Config
} }
// NewAccountTestService creates a new AccountTestService // NewAccountTestService creates a new AccountTestService
...@@ -53,15 +57,32 @@ func NewAccountTestService( ...@@ -53,15 +57,32 @@ func NewAccountTestService(
geminiTokenProvider *GeminiTokenProvider, geminiTokenProvider *GeminiTokenProvider,
antigravityGatewayService *AntigravityGatewayService, antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
cfg *config.Config,
) *AccountTestService { ) *AccountTestService {
return &AccountTestService{ return &AccountTestService{
accountRepo: accountRepo, accountRepo: accountRepo,
geminiTokenProvider: geminiTokenProvider, geminiTokenProvider: geminiTokenProvider,
antigravityGatewayService: antigravityGatewayService, antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
cfg: cfg,
} }
} }
func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg == nil {
return "", errors.New("config is not available")
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", err
}
return normalized, nil
}
// generateSessionString generates a Claude Code style session string // generateSessionString generates a Claude Code style session string
func generateSessionString() (string, error) { func generateSessionString() (string, error) {
bytes := make([]byte, 32) bytes := make([]byte, 32)
...@@ -183,11 +204,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account ...@@ -183,11 +204,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.sendErrorAndEnd(c, "No API key available") return s.sendErrorAndEnd(c, "No API key available")
} }
apiURL = account.GetBaseURL() baseURL := account.GetBaseURL()
if apiURL == "" { if baseURL == "" {
apiURL = "https://api.anthropic.com" baseURL = "https://api.anthropic.com"
} }
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages" normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
} else { } else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
} }
...@@ -300,7 +325,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account ...@@ -300,7 +325,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if baseURL == "" { if baseURL == "" {
baseURL = "https://api.openai.com" baseURL = "https://api.openai.com"
} }
apiURL = strings.TrimSuffix(baseURL, "/") + "/responses" normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
} else { } else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
} }
...@@ -480,10 +509,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou ...@@ -480,10 +509,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
// Use streamGenerateContent for real-time feedback // Use streamGenerateContent for real-time feedback
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
strings.TrimRight(baseURL, "/"), modelID) strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload)) req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
if err != nil { if err != nil {
...@@ -515,7 +548,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun ...@@ -515,7 +548,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
if strings.TrimSpace(baseURL) == "" { if strings.TrimSpace(baseURL) == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
if err != nil { if err != nil {
...@@ -544,7 +581,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT ...@@ -544,7 +581,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
} }
wrappedBytes, _ := json.Marshal(wrapped) wrappedBytes, _ := json.Marshal(wrapped)
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, err
}
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", normalizedBaseURL)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes)) req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
if err != nil { if err != nil {
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"log" "log"
"net/http" "net/http"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
...@@ -916,20 +917,102 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context ...@@ -916,20 +917,102 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
return nil, errors.New("streaming not supported") return nil, errors.New("streaming not supported")
} }
reader := bufio.NewReader(resp.Body) // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
usage := &ClaudeUsage{} usage := &ClaudeUsage{}
var firstTokenMs *int var firstTokenMs *int
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
for { for {
line, err := reader.ReadString('\n') select {
if len(line) > 0 { case ev, ok := <-events:
if !ok {
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return nil, ev.err
}
line := ev.line
trimmed := strings.TrimRight(line, "\r\n") trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") { if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" { if payload == "" || payload == "[DONE]" {
_, _ = io.WriteString(c.Writer, line) if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush() flusher.Flush()
} else { continue
}
// 解包 v1internal 响应 // 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload)) inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr == nil && inner != nil { if parseErr == nil && inner != nil {
...@@ -949,24 +1032,30 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context ...@@ -949,24 +1032,30 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
firstTokenMs = &ms firstTokenMs = &ms
} }
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload) if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
flusher.Flush() sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
} }
} else {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush() flusher.Flush()
continue
} }
if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err
} }
flusher.Flush()
if errors.Is(err, io.EOF) { case <-intervalCh:
break lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
} }
if err != nil { log.Printf("Stream data interval timeout (antigravity)")
return nil, err sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
} }
} }
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
} }
func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) { func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) {
...@@ -1105,7 +1194,13 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context ...@@ -1105,7 +1194,13 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
processor := antigravity.NewStreamingProcessor(originalModel) processor := antigravity.NewStreamingProcessor(originalModel)
var firstTokenMs *int var firstTokenMs *int
reader := bufio.NewReader(resp.Body) // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage // 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage { convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
...@@ -1120,13 +1215,85 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context ...@@ -1120,13 +1215,85 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
} }
} }
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
for { for {
line, err := reader.ReadString('\n') select {
if err != nil && !errors.Is(err, io.EOF) { case ev, ok := <-events:
return nil, fmt.Errorf("stream read error: %w", err) if !ok {
// 发送结束事件
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
flusher.Flush()
}
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return nil, fmt.Errorf("stream read error: %w", ev.err)
} }
if len(line) > 0 { line := ev.line
// 处理 SSE 行,转换为 Claude 格式 // 处理 SSE 行,转换为 Claude 格式
claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n")) claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n"))
...@@ -1141,23 +1308,21 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context ...@@ -1141,23 +1308,21 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if len(finalEvents) > 0 { if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents) _, _ = c.Writer.Write(finalEvents)
} }
sendErrorEvent("write_failed")
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr
} }
flusher.Flush() flusher.Flush()
} }
}
if errors.Is(err, io.EOF) { case <-intervalCh:
break lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
} }
log.Printf("Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
} }
// 发送结束事件
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
flusher.Flush()
} }
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
} }
...@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S ...@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// VerifyTurnstile 验证Turnstile token // VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error { func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
if required {
if s.settingService == nil {
log.Println("[Auth] Turnstile required but settings service is not configured")
return ErrTurnstileNotConfigured
}
enabled := s.settingService.IsTurnstileEnabled(ctx)
secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
if !enabled || !secretConfigured {
log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
return ErrTurnstileNotConfigured
}
}
if s.turnstileService == nil { if s.turnstileService == nil {
if required {
log.Println("[Auth] Turnstile required but service not configured")
return ErrTurnstileNotConfigured
}
return nil // 服务未配置则跳过验证 return nil // 服务未配置则跳过验证
} }
if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
log.Println("[Auth] Turnstile enabled but secret key not configured")
}
return s.turnstileService.VerifyToken(ctx, token, remoteIP) return s.turnstileService.VerifyToken(ctx, token, remoteIP)
} }
......
...@@ -17,6 +17,7 @@ import ( ...@@ -17,6 +17,7 @@ import (
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义 // 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var ( var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired") ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
) )
// subscriptionCacheData 订阅缓存数据结构(内部使用) // subscriptionCacheData 订阅缓存数据结构(内部使用)
...@@ -76,6 +77,7 @@ type BillingCacheService struct { ...@@ -76,6 +77,7 @@ type BillingCacheService struct {
userRepo UserRepository userRepo UserRepository
subRepo UserSubscriptionRepository subRepo UserSubscriptionRepository
cfg *config.Config cfg *config.Config
circuitBreaker *billingCircuitBreaker
cacheWriteChan chan cacheWriteTask cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup cacheWriteWg sync.WaitGroup
...@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo ...@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
subRepo: subRepo, subRepo: subRepo,
cfg: cfg, cfg: cfg,
} }
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
svc.startCacheWriteWorkers() svc.startCacheWriteWorkers()
return svc return svc
} }
...@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user ...@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
if s.cfg.RunMode == config.RunModeSimple { if s.cfg.RunMode == config.RunModeSimple {
return nil return nil
} }
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
return ErrBillingServiceUnavailable
}
// 判断计费模式 // 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
...@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user ...@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error { func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
balance, err := s.GetUserBalance(ctx, userID) balance, err := s.GetUserBalance(ctx, userID)
if err != nil { if err != nil {
// 缓存/数据库错误,允许通过(降级处理) if s.circuitBreaker != nil {
log.Printf("Warning: get user balance failed, allowing request: %v", err) s.circuitBreaker.OnFailure(err)
return nil }
log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
} }
if balance <= 0 { if balance <= 0 {
...@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, ...@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
// 获取订阅缓存数据 // 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID) subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil { if err != nil {
// 缓存/数据库错误,降级使用传入的subscription进行检查 if s.circuitBreaker != nil {
log.Printf("Warning: get subscription cache failed, using fallback: %v", err) s.circuitBreaker.OnFailure(err)
return s.checkSubscriptionLimitsFallback(subscription, group) }
log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
} }
// 检查订阅状态 // 检查订阅状态
...@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, ...@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
return nil return nil
} }
// checkSubscriptionLimitsFallback 降级检查订阅限额 type billingCircuitBreakerState int
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
if subscription == nil { const (
return ErrSubscriptionInvalid billingCircuitClosed billingCircuitBreakerState = iota
billingCircuitOpen
billingCircuitHalfOpen
)
type billingCircuitBreaker struct {
mu sync.Mutex
state billingCircuitBreakerState
failures int
openedAt time.Time
failureThreshold int
resetTimeout time.Duration
halfOpenRequests int
halfOpenRemaining int
}
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
if !cfg.Enabled {
return nil
}
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
if resetTimeout <= 0 {
resetTimeout = 30 * time.Second
}
halfOpen := cfg.HalfOpenRequests
if halfOpen <= 0 {
halfOpen = 1
} }
threshold := cfg.FailureThreshold
if threshold <= 0 {
threshold = 5
}
return &billingCircuitBreaker{
state: billingCircuitClosed,
failureThreshold: threshold,
resetTimeout: resetTimeout,
halfOpenRequests: halfOpen,
}
}
if !subscription.IsActive() { func (b *billingCircuitBreaker) Allow() bool {
return ErrSubscriptionInvalid b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitClosed:
return true
case billingCircuitOpen:
if time.Since(b.openedAt) < b.resetTimeout {
return false
}
b.state = billingCircuitHalfOpen
b.halfOpenRemaining = b.halfOpenRequests
log.Printf("ALERT: billing circuit breaker entering half-open state")
fallthrough
case billingCircuitHalfOpen:
if b.halfOpenRemaining <= 0 {
return false
} }
b.halfOpenRemaining--
return true
default:
return false
}
}
if !subscription.CheckDailyLimit(group, 0) { func (b *billingCircuitBreaker) OnFailure(err error) {
return ErrDailyLimitExceeded if b == nil {
return
} }
b.mu.Lock()
defer b.mu.Unlock()
if !subscription.CheckWeeklyLimit(group, 0) { switch b.state {
return ErrWeeklyLimitExceeded case billingCircuitOpen:
return
case billingCircuitHalfOpen:
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
return
default:
b.failures++
if b.failures >= b.failureThreshold {
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
}
} }
}
if !subscription.CheckMonthlyLimit(group, 0) { func (b *billingCircuitBreaker) OnSuccess() {
return ErrMonthlyLimitExceeded if b == nil {
return
} }
b.mu.Lock()
defer b.mu.Unlock()
return nil previousState := b.state
previousFailures := b.failures
b.state = billingCircuitClosed
b.failures = 0
b.halfOpenRemaining = 0
// 只有状态真正发生变化时才记录日志
if previousState != billingCircuitClosed {
log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
} else if previousFailures > 0 {
log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
}
}
func circuitStateString(state billingCircuitBreakerState) string {
switch state {
case billingCircuitClosed:
return "closed"
case billingCircuitOpen:
return "open"
case billingCircuitHalfOpen:
return "half-open"
default:
return "unknown"
}
} }
...@@ -8,12 +8,13 @@ import ( ...@@ -8,12 +8,13 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
) )
type CRSSyncService struct { type CRSSyncService struct {
...@@ -22,6 +23,7 @@ type CRSSyncService struct { ...@@ -22,6 +23,7 @@ type CRSSyncService struct {
oauthService *OAuthService oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService openaiOAuthService *OpenAIOAuthService
geminiOAuthService *GeminiOAuthService geminiOAuthService *GeminiOAuthService
cfg *config.Config
} }
func NewCRSSyncService( func NewCRSSyncService(
...@@ -30,6 +32,7 @@ func NewCRSSyncService( ...@@ -30,6 +32,7 @@ func NewCRSSyncService(
oauthService *OAuthService, oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService, openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService, geminiOAuthService *GeminiOAuthService,
cfg *config.Config,
) *CRSSyncService { ) *CRSSyncService {
return &CRSSyncService{ return &CRSSyncService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -37,6 +40,7 @@ func NewCRSSyncService( ...@@ -37,6 +40,7 @@ func NewCRSSyncService(
oauthService: oauthService, oauthService: oauthService,
openaiOAuthService: openaiOAuthService, openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService, geminiOAuthService: geminiOAuthService,
cfg: cfg,
} }
} }
...@@ -187,7 +191,10 @@ type crsGeminiAPIKeyAccount struct { ...@@ -187,7 +191,10 @@ type crsGeminiAPIKeyAccount struct {
} }
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) { func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
baseURL, err := normalizeBaseURL(input.BaseURL) if s.cfg == nil {
return nil, errors.New("config is not available")
}
baseURL, err := normalizeBaseURL(input.BaseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -197,6 +204,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -197,6 +204,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
client, err := httpclient.GetClient(httpclient.Options{ client, err := httpclient.GetClient(httpclient.Options{
Timeout: 20 * time.Second, Timeout: 20 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
}) })
if err != nil { if err != nil {
client = &http.Client{Timeout: 20 * time.Second} client = &http.Client{Timeout: 20 * time.Second}
...@@ -1055,17 +1064,18 @@ func mapCRSStatus(isActive bool, status string) string { ...@@ -1055,17 +1064,18 @@ func mapCRSStatus(isActive bool, status string) string {
return "active" return "active"
} }
func normalizeBaseURL(raw string) (string, error) { func normalizeBaseURL(raw string, allowlist []string, allowPrivate bool) (string, error) {
trimmed := strings.TrimSpace(raw) // 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
if trimmed == "" { requireAllowlist := len(allowlist) > 0
return "", errors.New("base_url is required") normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
} AllowedHosts: allowlist,
u, err := url.Parse(trimmed) RequireAllowlist: requireAllowlist,
if err != nil || u.Scheme == "" || u.Host == "" { AllowPrivate: allowPrivate,
return "", fmt.Errorf("invalid base_url: %s", trimmed) })
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
} }
u.Path = strings.TrimRight(u.Path, "/") return normalized, nil
return strings.TrimRight(u.String(), "/"), nil
} }
// cleanBaseURL removes trailing suffix from base_url in credentials // cleanBaseURL removes trailing suffix from base_url in credentials
......
...@@ -15,11 +15,14 @@ import ( ...@@ -15,11 +15,14 @@ import (
"regexp" "regexp"
"sort" "sort"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
...@@ -30,6 +33,7 @@ const ( ...@@ -30,6 +33,7 @@ const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 10 * 1024 * 1024
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
) )
...@@ -1225,7 +1229,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -1225,7 +1229,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages" if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages"
}
} }
// OAuth账号:应用统一指纹 // OAuth账号:应用统一指纹
...@@ -1594,12 +1604,87 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -1594,12 +1604,87 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
var firstTokenMs *int var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
// 设置更大的buffer以处理长行 // 设置更大的buffer以处理长行
scanner.Buffer(make([]byte, 64*1024), 1024*1024) maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
// 仅监控上游数据间隔超时,避免下游写入阻塞导致误判
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
needModelReplace := originalModel != mappedModel needModelReplace := originalModel != mappedModel
for scanner.Scan() { for {
line := scanner.Text() select {
case ev, ok := <-events:
if !ok {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
if line == "event: error" { if line == "event: error" {
return nil, errors.New("have error in stream") return nil, errors.New("have error in stream")
} }
...@@ -1615,6 +1700,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -1615,6 +1700,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// 转发行 // 转发行
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
} }
flusher.Flush() flusher.Flush()
...@@ -1628,17 +1714,23 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -1628,17 +1714,23 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
} else { } else {
// 非 data 行直接转发 // 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
} }
flusher.Flush() flusher.Flush()
} }
}
if err := scanner.Err(); err != nil { case <-intervalCh:
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
} }
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
} }
// replaceModelInSSELine 替换SSE数据行中的model字段 // replaceModelInSSELine 替换SSE数据行中的model字段
...@@ -1743,12 +1835,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h ...@@ -1743,12 +1835,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
body = s.replaceModelInResponseBody(body, mappedModel, originalModel) body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
} }
// 透传响应头 responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 写入响应 // 写入响应
c.Data(resp.StatusCode, "application/json", body) c.Data(resp.StatusCode, "application/json", body)
...@@ -2020,7 +2107,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -2020,7 +2107,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
targetURL := claudeAPICountTokensURL targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens" if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages/count_tokens"
}
} }
// OAuth 账号:应用统一指纹和重写 userID // OAuth 账号:应用统一指纹和重写 userID
...@@ -2100,6 +2193,18 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m ...@@ -2100,6 +2193,18 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
}) })
} }
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
// GetAvailableModels returns the list of models available for a group // GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group // It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
......
...@@ -18,9 +18,12 @@ import ( ...@@ -18,9 +18,12 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct { ...@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
rateLimitService *RateLimitService rateLimitService *RateLimitService
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService antigravityGatewayService *AntigravityGatewayService
cfg *config.Config
} }
func NewGeminiMessagesCompatService( func NewGeminiMessagesCompatService(
...@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService( ...@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
rateLimitService *RateLimitService, rateLimitService *RateLimitService,
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
antigravityGatewayService *AntigravityGatewayService, antigravityGatewayService *AntigravityGatewayService,
cfg *config.Config,
) *GeminiMessagesCompatService { ) *GeminiMessagesCompatService {
return &GeminiMessagesCompatService{ return &GeminiMessagesCompatService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService( ...@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
rateLimitService: rateLimitService, rateLimitService: rateLimitService,
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService, antigravityGatewayService: antigravityGatewayService,
cfg: cfg,
} }
} }
...@@ -230,6 +236,18 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit ...@@ -230,6 +236,18 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
return s.antigravityGatewayService return s.antigravityGatewayService
} }
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户 // HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) { func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
var accounts []Account var accounts []Account
...@@ -381,16 +399,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -381,16 +399,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, "", errors.New("gemini api_key not configured") return nil, "", errors.New("gemini api_key not configured")
} }
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
action := "generateContent" action := "generateContent"
if req.Stream { if req.Stream {
action = "streamGenerateContent" action = "streamGenerateContent"
} }
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action) fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if req.Stream { if req.Stream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -427,7 +449,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -427,7 +449,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token) // 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" { if projectID != "" {
// Mode 1: Code Assist API // Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action) baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -453,12 +479,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -453,12 +479,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return upstreamReq, "x-request-id", nil return upstreamReq, "x-request-id", nil
} else { } else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action) fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -650,12 +680,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -650,12 +680,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, "", errors.New("gemini api_key not configured") return nil, "", errors.New("gemini api_key not configured")
} }
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction) fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -687,7 +721,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -687,7 +721,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token) // 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" && !forceAIStudio { if projectID != "" && !forceAIStudio {
// Mode 1: Code Assist API // Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction) baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -713,12 +751,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -713,12 +751,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return upstreamReq, "x-request-id", nil return upstreamReq, "x-request-id", nil
} else { } else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction) fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -1652,6 +1694,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co ...@@ -1652,6 +1694,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
_ = json.Unmarshal(respBody, &parsed) _ = json.Unmarshal(respBody, &parsed)
} }
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
if contentType == "" { if contentType == "" {
contentType = "application/json" contentType = "application/json"
...@@ -1773,11 +1817,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac ...@@ -1773,11 +1817,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return nil, errors.New("invalid path") return nil, errors.New("invalid path")
} }
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
fullURL := strings.TrimRight(baseURL, "/") + path normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := strings.TrimRight(normalizedBaseURL, "/") + path
var proxyURL string var proxyURL string
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
...@@ -1816,9 +1864,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac ...@@ -1816,9 +1864,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
if wwwAuthenticate != "" {
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
}
return &UpstreamHTTPResult{ return &UpstreamHTTPResult{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
Headers: resp.Header.Clone(), Headers: filteredHeaders,
Body: body, Body: body,
}, nil }, nil
} }
......
...@@ -1002,6 +1002,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR ...@@ -1002,6 +1002,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
client, err := httpclient.GetClient(httpclient.Options{ client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: strings.TrimSpace(proxyURL), ProxyURL: strings.TrimSpace(proxyURL),
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
ValidateResolvedIP: true,
}) })
if err != nil { if err != nil {
client = &http.Client{Timeout: 30 * time.Second} client = &http.Client{Timeout: 30 * time.Second}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment