Commit 195e227c authored by song's avatar song
Browse files

merge: 合并 upstream/main 并保留本地图片计费功能

parents 6fa704d6 752882a0
...@@ -22,7 +22,13 @@ type HTTPUpstreamSuite struct { ...@@ -22,7 +22,13 @@ type HTTPUpstreamSuite struct {
// SetupTest 每个测试用例执行前的初始化 // SetupTest 每个测试用例执行前的初始化
// 创建空配置,各测试用例可按需覆盖 // 创建空配置,各测试用例可按需覆盖
func (s *HTTPUpstreamSuite) SetupTest() { func (s *HTTPUpstreamSuite) SetupTest() {
s.cfg = &config.Config{} s.cfg = &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}
} }
// newService 创建测试用的 httpUpstreamService 实例 // newService 创建测试用的 httpUpstreamService 实例
......
...@@ -26,6 +26,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { ...@@ -26,6 +26,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "users", "notes", "text", 0, false) requireColumn(t, tx, "users", "notes", "text", 0, false)
// accounts: schedulable and rate-limit fields // accounts: schedulable and rate-limit fields
requireColumn(t, tx, "accounts", "notes", "text", 0, true)
requireColumn(t, tx, "accounts", "schedulable", "boolean", 0, false) requireColumn(t, tx, "accounts", "schedulable", "boolean", 0, false)
requireColumn(t, tx, "accounts", "rate_limited_at", "timestamp with time zone", 0, true) requireColumn(t, tx, "accounts", "rate_limited_at", "timestamp with time zone", 0, true)
requireColumn(t, tx, "accounts", "rate_limit_reset_at", "timestamp with time zone", 0, true) requireColumn(t, tx, "accounts", "rate_limit_reset_at", "timestamp with time zone", 0, true)
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
...@@ -16,9 +17,17 @@ type pricingRemoteClient struct { ...@@ -16,9 +17,17 @@ type pricingRemoteClient struct {
httpClient *http.Client httpClient *http.Client
} }
func NewPricingRemoteClient() service.PricingRemoteClient { func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
allowPrivate := false
validateResolvedIP := true
if cfg != nil {
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
}
sharedClient, err := httpclient.GetClient(httpclient.Options{ sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
ValidateResolvedIP: validateResolvedIP,
AllowPrivateHosts: allowPrivate,
}) })
if err != nil { if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second} sharedClient = &http.Client{Timeout: 30 * time.Second}
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
...@@ -19,7 +20,13 @@ type PricingServiceSuite struct { ...@@ -19,7 +20,13 @@ type PricingServiceSuite struct {
func (s *PricingServiceSuite) SetupTest() { func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
client, ok := NewPricingRemoteClient().(*pricingRemoteClient) client, ok := NewPricingRemoteClient(&config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}).(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed") require.True(s.T(), ok, "type assertion failed")
s.client = client s.client = client
} }
......
...@@ -5,28 +5,52 @@ import ( ...@@ -5,28 +5,52 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
func NewProxyExitInfoProber() service.ProxyExitInfoProber { func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
return &proxyProbeService{ipInfoURL: defaultIPInfoURL} insecure := false
allowPrivate := false
validateResolvedIP := true
if cfg != nil {
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
}
if insecure {
log.Printf("[ProxyProbe] Warning: TLS verification is disabled for proxy probing.")
}
return &proxyProbeService{
ipInfoURL: defaultIPInfoURL,
insecureSkipVerify: insecure,
allowPrivateHosts: allowPrivate,
validateResolvedIP: validateResolvedIP,
}
} }
const defaultIPInfoURL = "https://ipinfo.io/json" const defaultIPInfoURL = "https://ipinfo.io/json"
type proxyProbeService struct { type proxyProbeService struct {
ipInfoURL string ipInfoURL string
insecureSkipVerify bool
allowPrivateHosts bool
validateResolvedIP bool
} }
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
client, err := httpclient.GetClient(httpclient.Options{ client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL, ProxyURL: proxyURL,
Timeout: 15 * time.Second, Timeout: 15 * time.Second,
InsecureSkipVerify: true, InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: s.validateResolvedIP,
AllowPrivateHosts: s.allowPrivateHosts,
}) })
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err) return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
......
...@@ -20,7 +20,10 @@ type ProxyProbeServiceSuite struct { ...@@ -20,7 +20,10 @@ type ProxyProbeServiceSuite struct {
func (s *ProxyProbeServiceSuite) SetupTest() { func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"} s.prober = &proxyProbeService{
ipInfoURL: "http://ipinfo.test/json",
allowPrivateHosts: true,
}
} }
func (s *ProxyProbeServiceSuite) TearDownTest() { func (s *ProxyProbeServiceSuite) TearDownTest() {
......
...@@ -23,6 +23,7 @@ type turnstileVerifier struct { ...@@ -23,6 +23,7 @@ type turnstileVerifier struct {
func NewTurnstileVerifier() service.TurnstileVerifier { func NewTurnstileVerifier() service.TurnstileVerifier {
sharedClient, err := httpclient.GetClient(httpclient.Options{ sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
ValidateResolvedIP: true,
}) })
if err != nil { if err != nil {
sharedClient = &http.Client{Timeout: 10 * time.Second} sharedClient = &http.Client{Timeout: 10 * time.Second}
......
...@@ -329,17 +329,20 @@ func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount flo ...@@ -329,17 +329,20 @@ func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount flo
return nil return nil
} }
// DeductBalance 扣除用户余额
// 透支策略:允许余额变为负数,确保当前请求能够完成
// 中间件会阻止余额 <= 0 的用户发起后续请求
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
n, err := client.User.Update(). n, err := client.User.Update().
Where(dbuser.IDEQ(id), dbuser.BalanceGTE(amount)). Where(dbuser.IDEQ(id)).
AddBalance(-amount). AddBalance(-amount).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
return err return err
} }
if n == 0 { if n == 0 {
return service.ErrInsufficientBalance return service.ErrUserNotFound
} }
return nil return nil
} }
......
...@@ -290,9 +290,14 @@ func (s *UserRepoSuite) TestDeductBalance() { ...@@ -290,9 +290,14 @@ func (s *UserRepoSuite) TestDeductBalance() {
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() { func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
user := s.mustCreateUser(&service.User{Email: "insuf@test.com", Balance: 5}) user := s.mustCreateUser(&service.User{Email: "insuf@test.com", Balance: 5})
// 透支策略:允许扣除超过余额的金额
err := s.repo.DeductBalance(s.ctx, user.ID, 999) err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance") s.Require().NoError(err, "DeductBalance should allow overdraft")
s.Require().ErrorIs(err, service.ErrInsufficientBalance)
// 验证余额变为负数
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(-994.0, got.Balance, 1e-6, "Balance should be negative after overdraft")
} }
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
...@@ -306,6 +311,19 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { ...@@ -306,6 +311,19 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
s.Require().InDelta(0.0, got.Balance, 1e-6) s.Require().InDelta(0.0, got.Balance, 1e-6)
} }
func (s *UserRepoSuite) TestDeductBalance_AllowsOverdraft() {
user := s.mustCreateUser(&service.User{Email: "overdraft@test.com", Balance: 5.0})
// 扣除超过余额的金额 - 应该成功
err := s.repo.DeductBalance(s.ctx, user.ID, 10.0)
s.Require().NoError(err, "DeductBalance should allow overdraft")
// 验证余额为负
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(-5.0, got.Balance, 1e-6, "Balance should be -5.0 after overdraft")
}
// --- Concurrency --- // --- Concurrency ---
func (s *UserRepoSuite) TestUpdateConcurrency() { func (s *UserRepoSuite) TestUpdateConcurrency() {
...@@ -477,9 +495,12 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { ...@@ -477,9 +495,12 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().NoError(err, "GetByID after DeductBalance") s.Require().NoError(err, "GetByID after DeductBalance")
s.Require().InDelta(7.5, got4.Balance, 1e-6) s.Require().InDelta(7.5, got4.Balance, 1e-6)
// 透支策略:允许扣除超过余额的金额
err = s.repo.DeductBalance(s.ctx, user1.ID, 999) err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
s.Require().Error(err, "DeductBalance expected error for insufficient balance") s.Require().NoError(err, "DeductBalance should allow overdraft")
s.Require().ErrorIs(err, service.ErrInsufficientBalance, "DeductBalance unexpected error") gotOverdraft, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after overdraft")
s.Require().Less(gotOverdraft.Balance, 0.0, "Balance should be negative after overdraft")
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency") s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
got5, err := s.repo.GetByID(s.ctx, user1.ID) got5, err := s.repo.GetByID(s.ctx, user1.ID)
...@@ -511,6 +532,6 @@ func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() { ...@@ -511,6 +532,6 @@ func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() {
func (s *UserRepoSuite) TestDeductBalance_NotFound() { func (s *UserRepoSuite) TestDeductBalance_NotFound() {
err := s.repo.DeductBalance(s.ctx, 999999, 5) err := s.repo.DeductBalance(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user") s.Require().Error(err, "expected error for non-existent user")
// DeductBalance 在用户不存在时返回 ErrInsufficientBalance 因为 WHERE 条件不匹配 // DeductBalance 在用户不存在时返回 ErrUserNotFound
s.Require().ErrorIs(err, service.ErrInsufficientBalance) s.Require().ErrorIs(err, service.ErrUserNotFound)
} }
...@@ -296,13 +296,13 @@ func TestAPIContracts(t *testing.T) { ...@@ -296,13 +296,13 @@ func TestAPIContracts(t *testing.T) {
"smtp_host": "smtp.example.com", "smtp_host": "smtp.example.com",
"smtp_port": 587, "smtp_port": 587,
"smtp_username": "user", "smtp_username": "user",
"smtp_password": "secret", "smtp_password_configured": true,
"smtp_from_email": "no-reply@example.com", "smtp_from_email": "no-reply@example.com",
"smtp_from_name": "Sub2API", "smtp_from_name": "Sub2API",
"smtp_use_tls": true, "smtp_use_tls": true,
"turnstile_enabled": true, "turnstile_enabled": true,
"turnstile_site_key": "site-key", "turnstile_site_key": "site-key",
"turnstile_secret_key": "secret-key", "turnstile_secret_key_configured": true,
"site_name": "Sub2API", "site_name": "Sub2API",
"site_logo": "", "site_logo": "",
"site_subtitle": "Subtitle", "site_subtitle": "Subtitle",
...@@ -315,7 +315,9 @@ func TestAPIContracts(t *testing.T) { ...@@ -315,7 +315,9 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_anthropic": "claude-3-5-sonnet-20241022", "fallback_model_anthropic": "claude-3-5-sonnet-20241022",
"fallback_model_antigravity": "gemini-2.5-pro", "fallback_model_antigravity": "gemini-2.5-pro",
"fallback_model_gemini": "gemini-2.5-pro", "fallback_model_gemini": "gemini-2.5-pro",
"fallback_model_openai": "gpt-4o" "fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": ""
} }
}`, }`,
}, },
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
package server package server
import ( import (
"log"
"net/http" "net/http"
"time" "time"
...@@ -36,6 +37,15 @@ func ProvideRouter( ...@@ -36,6 +37,15 @@ func ProvideRouter(
r := gin.New() r := gin.New()
r.Use(middleware2.Recovery()) r.Use(middleware2.Recovery())
if len(cfg.Server.TrustedProxies) > 0 {
if err := r.SetTrustedProxies(cfg.Server.TrustedProxies); err != nil {
log.Printf("Failed to set trusted proxies: %v", err)
}
} else {
if err := r.SetTrustedProxies(nil); err != nil {
log.Printf("Failed to disable trusted proxies: %v", err)
}
}
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg) return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
} }
......
...@@ -19,6 +19,13 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS ...@@ -19,6 +19,13 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证) // apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
queryKey := strings.TrimSpace(c.Query("key"))
queryApiKey := strings.TrimSpace(c.Query("api_key"))
if queryKey != "" || queryApiKey != "" {
AbortWithError(c, 400, "api_key_in_query_deprecated", "API key in query parameter is deprecated. Please use Authorization header instead.")
return
}
// 尝试从Authorization header中提取API key (Bearer scheme) // 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
var apiKeyString string var apiKeyString string
...@@ -41,19 +48,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti ...@@ -41,19 +48,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
apiKeyString = c.GetHeader("x-goog-api-key") apiKeyString = c.GetHeader("x-goog-api-key")
} }
// 如果header中没有,尝试从query参数中提取(Google API key风格)
if apiKeyString == "" {
apiKeyString = c.Query("key")
}
// 兼容常见别名
if apiKeyString == "" {
apiKeyString = c.Query("api_key")
}
// 如果所有header都没有API key // 如果所有header都没有API key
if apiKeyString == "" { if apiKeyString == "" {
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter") AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header")
return return
} }
......
...@@ -22,6 +22,10 @@ func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) ...@@ -22,6 +22,10 @@ func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config)
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations. // It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.")
return
}
apiKeyString := extractAPIKeyFromRequest(c) apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" { if apiKeyString == "" {
abortWithGoogleError(c, 401, "API key is required") abortWithGoogleError(c, 401, "API key is required")
...@@ -116,15 +120,18 @@ func extractAPIKeyFromRequest(c *gin.Context) string { ...@@ -116,15 +120,18 @@ func extractAPIKeyFromRequest(c *gin.Context) string {
if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" { if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" {
return v return v
} }
if allowGoogleQueryKey(c.Request.URL.Path) {
if v := strings.TrimSpace(c.Query("key")); v != "" { if v := strings.TrimSpace(c.Query("key")); v != "" {
return v return v
} }
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
return v
} }
return "" return ""
} }
func allowGoogleQueryKey(path string) bool {
return strings.HasPrefix(path, "/v1beta") || strings.HasPrefix(path, "/antigravity/v1beta")
}
func abortWithGoogleError(c *gin.Context, status int, message string) { func abortWithGoogleError(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{ c.JSON(status, gin.H{
"error": gin.H{ "error": gin.H{
......
...@@ -109,6 +109,58 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) { ...@@ -109,6 +109,58 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
} }
func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("should not be called")
},
})
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?api_key=legacy", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusBadRequest, resp.Error.Code)
require.Equal(t, "Query parameter api_key is deprecated. Use Authorization header or key instead.", resp.Error.Message)
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.APIKey{
ID: 1,
Key: key,
Status: service.StatusActive,
User: &service.User{
ID: 123,
Status: service.StatusActive,
},
}, nil
},
})
cfg := &config.Config{RunMode: config.RunModeSimple}
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?key=valid", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) { func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
......
package middleware package middleware
import ( import (
"log"
"net/http"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
var corsWarningOnce sync.Once
// CORS 跨域中间件 // CORS 跨域中间件
func CORS() gin.HandlerFunc { func CORS(cfg config.CORSConfig) gin.HandlerFunc {
allowedOrigins := normalizeOrigins(cfg.AllowedOrigins)
allowAll := false
for _, origin := range allowedOrigins {
if origin == "*" {
allowAll = true
break
}
}
wildcardWithSpecific := allowAll && len(allowedOrigins) > 1
if wildcardWithSpecific {
allowedOrigins = []string{"*"}
}
allowCredentials := cfg.AllowCredentials
corsWarningOnce.Do(func() {
if len(allowedOrigins) == 0 {
log.Println("Warning: CORS allowed_origins not configured; cross-origin requests will be rejected.")
}
if wildcardWithSpecific {
log.Println("Warning: CORS allowed_origins includes '*'; wildcard will take precedence over explicit origins.")
}
if allowAll && allowCredentials {
log.Println("Warning: CORS allowed_origins set to '*', disabling allow_credentials.")
}
})
if allowAll && allowCredentials {
allowCredentials = false
}
allowedSet := make(map[string]struct{}, len(allowedOrigins))
for _, origin := range allowedOrigins {
if origin == "" || origin == "*" {
continue
}
allowedSet[origin] = struct{}{}
}
return func(c *gin.Context) { return func(c *gin.Context) {
// 设置允许跨域的响应头 origin := strings.TrimSpace(c.GetHeader("Origin"))
originAllowed := allowAll
if origin != "" && !allowAll {
_, originAllowed = allowedSet[origin]
}
if originAllowed {
if allowAll {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*") c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
} else if origin != "" {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
c.Writer.Header().Add("Vary", "Origin")
}
if allowCredentials {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key") c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 处理预检请求 // 处理预检请求
if c.Request.Method == "OPTIONS" { if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(204) if originAllowed {
c.AbortWithStatus(http.StatusNoContent)
} else {
c.AbortWithStatus(http.StatusForbidden)
}
return return
} }
c.Next() c.Next()
} }
} }
func normalizeOrigins(values []string) []string {
if len(values) == 0 {
return nil
}
normalized := make([]string, 0, len(values))
for _, value := range values {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
continue
}
normalized = append(normalized, trimmed)
}
return normalized
}
package middleware
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
// SecurityHeaders sets baseline security headers for all responses.
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy := strings.TrimSpace(cfg.Policy)
if policy == "" {
policy = config.DefaultCSPPolicy
}
return func(c *gin.Context) {
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
if cfg.Enabled {
c.Header("Content-Security-Policy", policy)
}
c.Next()
}
}
...@@ -24,7 +24,8 @@ func SetupRouter( ...@@ -24,7 +24,8 @@ func SetupRouter(
) *gin.Engine { ) *gin.Engine {
// 应用中间件 // 应用中间件
r.Use(middleware2.Logger()) r.Use(middleware2.Logger())
r.Use(middleware2.CORS()) r.Use(middleware2.CORS(cfg.CORS))
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
// Serve embedded frontend if available // Serve embedded frontend if available
if web.HasEmbeddedFrontend() { if web.HasEmbeddedFrontend() {
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
type Account struct { type Account struct {
ID int64 ID int64
Name string Name string
Notes *string
Platform string Platform string
Type string Type string
Credentials map[string]any Credentials map[string]any
...@@ -262,6 +263,17 @@ func parseTempUnschedStrings(value any) []string { ...@@ -262,6 +263,17 @@ func parseTempUnschedStrings(value any) []string {
return out return out
} }
func normalizeAccountNotes(value *string) *string {
if value == nil {
return nil
}
trimmed := strings.TrimSpace(*value)
if trimmed == "" {
return nil
}
return &trimmed
}
func parseTempUnschedInt(value any) int { func parseTempUnschedInt(value any) int {
switch v := value.(type) { switch v := value.(type) {
case int: case int:
......
...@@ -72,6 +72,7 @@ type AccountBulkUpdate struct { ...@@ -72,6 +72,7 @@ type AccountBulkUpdate struct {
// CreateAccountRequest 创建账号请求 // CreateAccountRequest 创建账号请求
type CreateAccountRequest struct { type CreateAccountRequest struct {
Name string `json:"name"` Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"` Platform string `json:"platform"`
Type string `json:"type"` Type string `json:"type"`
Credentials map[string]any `json:"credentials"` Credentials map[string]any `json:"credentials"`
...@@ -85,6 +86,7 @@ type CreateAccountRequest struct { ...@@ -85,6 +86,7 @@ type CreateAccountRequest struct {
// UpdateAccountRequest 更新账号请求 // UpdateAccountRequest 更新账号请求
type UpdateAccountRequest struct { type UpdateAccountRequest struct {
Name *string `json:"name"` Name *string `json:"name"`
Notes *string `json:"notes"`
Credentials *map[string]any `json:"credentials"` Credentials *map[string]any `json:"credentials"`
Extra *map[string]any `json:"extra"` Extra *map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
...@@ -123,6 +125,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( ...@@ -123,6 +125,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
// 创建账号 // 创建账号
account := &Account{ account := &Account{
Name: req.Name, Name: req.Name,
Notes: normalizeAccountNotes(req.Notes),
Platform: req.Platform, Platform: req.Platform,
Type: req.Type, Type: req.Type,
Credentials: req.Credentials, Credentials: req.Credentials,
...@@ -194,6 +197,9 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount ...@@ -194,6 +197,9 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
if req.Name != nil { if req.Name != nil {
account.Name = *req.Name account.Name = *req.Name
} }
if req.Notes != nil {
account.Notes = normalizeAccountNotes(req.Notes)
}
if req.Credentials != nil { if req.Credentials != nil {
account.Credentials = *req.Credentials account.Credentials = *req.Credentials
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
...@@ -14,9 +15,11 @@ import ( ...@@ -14,9 +15,11 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
) )
...@@ -45,6 +48,7 @@ type AccountTestService struct { ...@@ -45,6 +48,7 @@ type AccountTestService struct {
geminiTokenProvider *GeminiTokenProvider geminiTokenProvider *GeminiTokenProvider
antigravityGatewayService *AntigravityGatewayService antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
cfg *config.Config
} }
// NewAccountTestService creates a new AccountTestService // NewAccountTestService creates a new AccountTestService
...@@ -53,15 +57,35 @@ func NewAccountTestService( ...@@ -53,15 +57,35 @@ func NewAccountTestService(
geminiTokenProvider *GeminiTokenProvider, geminiTokenProvider *GeminiTokenProvider,
antigravityGatewayService *AntigravityGatewayService, antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
cfg *config.Config,
) *AccountTestService { ) *AccountTestService {
return &AccountTestService{ return &AccountTestService{
accountRepo: accountRepo, accountRepo: accountRepo,
geminiTokenProvider: geminiTokenProvider, geminiTokenProvider: geminiTokenProvider,
antigravityGatewayService: antigravityGatewayService, antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
cfg: cfg,
} }
} }
func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg == nil {
return "", errors.New("config is not available")
}
if !s.cfg.Security.URLAllowlist.Enabled {
return urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", err
}
return normalized, nil
}
// generateSessionString generates a Claude Code style session string // generateSessionString generates a Claude Code style session string
func generateSessionString() (string, error) { func generateSessionString() (string, error) {
bytes := make([]byte, 32) bytes := make([]byte, 32)
...@@ -183,11 +207,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account ...@@ -183,11 +207,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.sendErrorAndEnd(c, "No API key available") return s.sendErrorAndEnd(c, "No API key available")
} }
apiURL = account.GetBaseURL() baseURL := account.GetBaseURL()
if apiURL == "" { if baseURL == "" {
apiURL = "https://api.anthropic.com" baseURL = "https://api.anthropic.com"
} }
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages" normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
} else { } else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
} }
...@@ -300,7 +328,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account ...@@ -300,7 +328,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if baseURL == "" { if baseURL == "" {
baseURL = "https://api.openai.com" baseURL = "https://api.openai.com"
} }
apiURL = strings.TrimSuffix(baseURL, "/") + "/responses" normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
} else { } else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
} }
...@@ -480,10 +512,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou ...@@ -480,10 +512,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
// Use streamGenerateContent for real-time feedback // Use streamGenerateContent for real-time feedback
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
strings.TrimRight(baseURL, "/"), modelID) strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload)) req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
if err != nil { if err != nil {
...@@ -515,7 +551,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun ...@@ -515,7 +551,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
if strings.TrimSpace(baseURL) == "" { if strings.TrimSpace(baseURL) == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
if err != nil { if err != nil {
...@@ -544,7 +584,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT ...@@ -544,7 +584,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
} }
wrappedBytes, _ := json.Marshal(wrapped) wrappedBytes, _ := json.Marshal(wrapped)
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, err
}
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", normalizedBaseURL)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes)) req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
if err != nil { if err != nil {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment