Commit bd4bf008 authored by yangjianbo's avatar yangjianbo
Browse files

feat(安全): 强化安全策略与配置校验

- 增加 CORS/CSP/安全响应头与代理信任配置

- 引入 URL 白名单与私网开关,校验上游与价格源

- 改善 API Key 处理与网关错误返回

- 管理端设置隐藏敏感字段并优化前端提示

- 增加计费熔断与相关配置示例

测试: go test ./...
parent 3fd9bd4a
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
...@@ -15,9 +16,11 @@ import ( ...@@ -15,9 +16,11 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
) )
...@@ -49,6 +52,7 @@ type AccountTestService struct { ...@@ -49,6 +52,7 @@ type AccountTestService struct {
geminiTokenProvider *GeminiTokenProvider geminiTokenProvider *GeminiTokenProvider
antigravityGatewayService *AntigravityGatewayService antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
cfg *config.Config
} }
// NewAccountTestService creates a new AccountTestService // NewAccountTestService creates a new AccountTestService
...@@ -59,6 +63,7 @@ func NewAccountTestService( ...@@ -59,6 +63,7 @@ func NewAccountTestService(
geminiTokenProvider *GeminiTokenProvider, geminiTokenProvider *GeminiTokenProvider,
antigravityGatewayService *AntigravityGatewayService, antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
cfg *config.Config,
) *AccountTestService { ) *AccountTestService {
return &AccountTestService{ return &AccountTestService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -67,9 +72,25 @@ func NewAccountTestService( ...@@ -67,9 +72,25 @@ func NewAccountTestService(
geminiTokenProvider: geminiTokenProvider, geminiTokenProvider: geminiTokenProvider,
antigravityGatewayService: antigravityGatewayService, antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
cfg: cfg,
} }
} }
func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg == nil {
return "", errors.New("config is not available")
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", err
}
return normalized, nil
}
// generateSessionString generates a Claude Code style session string // generateSessionString generates a Claude Code style session string
func generateSessionString() (string, error) { func generateSessionString() (string, error) {
bytes := make([]byte, 32) bytes := make([]byte, 32)
...@@ -207,11 +228,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account ...@@ -207,11 +228,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.sendErrorAndEnd(c, "No API key available") return s.sendErrorAndEnd(c, "No API key available")
} }
apiURL = account.GetBaseURL() baseURL := account.GetBaseURL()
if apiURL == "" { if baseURL == "" {
apiURL = "https://api.anthropic.com" baseURL = "https://api.anthropic.com"
} }
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages" normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
} else { } else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
} }
...@@ -333,7 +358,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account ...@@ -333,7 +358,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if baseURL == "" { if baseURL == "" {
baseURL = "https://api.openai.com" baseURL = "https://api.openai.com"
} }
apiURL = strings.TrimSuffix(baseURL, "/") + "/responses" normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
} else { } else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
} }
...@@ -513,10 +542,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou ...@@ -513,10 +542,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
// Use streamGenerateContent for real-time feedback // Use streamGenerateContent for real-time feedback
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
strings.TrimRight(baseURL, "/"), modelID) strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload)) req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
if err != nil { if err != nil {
...@@ -548,7 +581,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun ...@@ -548,7 +581,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
if strings.TrimSpace(baseURL) == "" { if strings.TrimSpace(baseURL) == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(normalizedBaseURL, "/"), modelID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
if err != nil { if err != nil {
...@@ -577,7 +614,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT ...@@ -577,7 +614,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
} }
wrappedBytes, _ := json.Marshal(wrapped) wrappedBytes, _ := json.Marshal(wrapped)
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, err
}
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", normalizedBaseURL)
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes)) req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
if err != nil { if err != nil {
......
...@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S ...@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// VerifyTurnstile 验证Turnstile token // VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error { func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
if required {
if s.settingService == nil {
log.Println("[Auth] Turnstile required but settings service is not configured")
return ErrTurnstileNotConfigured
}
enabled := s.settingService.IsTurnstileEnabled(ctx)
secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
if !enabled || !secretConfigured {
log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
return ErrTurnstileNotConfigured
}
}
if s.turnstileService == nil { if s.turnstileService == nil {
if required {
log.Println("[Auth] Turnstile required but service not configured")
return ErrTurnstileNotConfigured
}
return nil // 服务未配置则跳过验证 return nil // 服务未配置则跳过验证
} }
if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
log.Println("[Auth] Turnstile enabled but secret key not configured")
}
return s.turnstileService.VerifyToken(ctx, token, remoteIP) return s.turnstileService.VerifyToken(ctx, token, remoteIP)
} }
......
...@@ -17,6 +17,7 @@ import ( ...@@ -17,6 +17,7 @@ import (
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义 // 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var ( var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired") ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
) )
// subscriptionCacheData 订阅缓存数据结构(内部使用) // subscriptionCacheData 订阅缓存数据结构(内部使用)
...@@ -76,6 +77,7 @@ type BillingCacheService struct { ...@@ -76,6 +77,7 @@ type BillingCacheService struct {
userRepo UserRepository userRepo UserRepository
subRepo UserSubscriptionRepository subRepo UserSubscriptionRepository
cfg *config.Config cfg *config.Config
circuitBreaker *billingCircuitBreaker
cacheWriteChan chan cacheWriteTask cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup cacheWriteWg sync.WaitGroup
...@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo ...@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
subRepo: subRepo, subRepo: subRepo,
cfg: cfg, cfg: cfg,
} }
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
svc.startCacheWriteWorkers() svc.startCacheWriteWorkers()
return svc return svc
} }
...@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user ...@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
if s.cfg.RunMode == config.RunModeSimple { if s.cfg.RunMode == config.RunModeSimple {
return nil return nil
} }
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
return ErrBillingServiceUnavailable
}
// 判断计费模式 // 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
...@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user ...@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error { func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
balance, err := s.GetUserBalance(ctx, userID) balance, err := s.GetUserBalance(ctx, userID)
if err != nil { if err != nil {
// 缓存/数据库错误,允许通过(降级处理) if s.circuitBreaker != nil {
log.Printf("Warning: get user balance failed, allowing request: %v", err) s.circuitBreaker.OnFailure(err)
return nil }
log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
} }
if balance <= 0 { if balance <= 0 {
...@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, ...@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
// 获取订阅缓存数据 // 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID) subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil { if err != nil {
// 缓存/数据库错误,降级使用传入的subscription进行检查 if s.circuitBreaker != nil {
log.Printf("Warning: get subscription cache failed, using fallback: %v", err) s.circuitBreaker.OnFailure(err)
return s.checkSubscriptionLimitsFallback(subscription, group) }
log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
} }
// 检查订阅状态 // 检查订阅状态
...@@ -513,6 +529,137 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, ...@@ -513,6 +529,137 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
return nil return nil
} }
type billingCircuitBreakerState int
const (
billingCircuitClosed billingCircuitBreakerState = iota
billingCircuitOpen
billingCircuitHalfOpen
)
type billingCircuitBreaker struct {
mu sync.Mutex
state billingCircuitBreakerState
failures int
openedAt time.Time
failureThreshold int
resetTimeout time.Duration
halfOpenRequests int
halfOpenRemaining int
}
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
if !cfg.Enabled {
return nil
}
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
if resetTimeout <= 0 {
resetTimeout = 30 * time.Second
}
halfOpen := cfg.HalfOpenRequests
if halfOpen <= 0 {
halfOpen = 1
}
threshold := cfg.FailureThreshold
if threshold <= 0 {
threshold = 5
}
return &billingCircuitBreaker{
state: billingCircuitClosed,
failureThreshold: threshold,
resetTimeout: resetTimeout,
halfOpenRequests: halfOpen,
}
}
func (b *billingCircuitBreaker) Allow() bool {
b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitClosed:
return true
case billingCircuitOpen:
if time.Since(b.openedAt) < b.resetTimeout {
return false
}
b.state = billingCircuitHalfOpen
b.halfOpenRemaining = b.halfOpenRequests
log.Printf("ALERT: billing circuit breaker entering half-open state")
fallthrough
case billingCircuitHalfOpen:
if b.halfOpenRemaining <= 0 {
return false
}
b.halfOpenRemaining--
return true
default:
return false
}
}
func (b *billingCircuitBreaker) OnFailure(err error) {
if b == nil {
return
}
b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitOpen:
return
case billingCircuitHalfOpen:
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
return
default:
b.failures++
if b.failures >= b.failureThreshold {
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
}
}
}
func (b *billingCircuitBreaker) OnSuccess() {
if b == nil {
return
}
b.mu.Lock()
defer b.mu.Unlock()
previousState := b.state
previousFailures := b.failures
b.state = billingCircuitClosed
b.failures = 0
b.halfOpenRemaining = 0
// 只有状态真正发生变化时才记录日志
if previousState != billingCircuitClosed {
log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
} else if previousFailures > 0 {
log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
}
}
func circuitStateString(state billingCircuitBreakerState) string {
switch state {
case billingCircuitClosed:
return "closed"
case billingCircuitOpen:
return "open"
case billingCircuitHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// checkSubscriptionLimitsFallback 降级检查订阅限额 // checkSubscriptionLimitsFallback 降级检查订阅限额
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error { func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
if subscription == nil { if subscription == nil {
......
...@@ -8,12 +8,13 @@ import ( ...@@ -8,12 +8,13 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
) )
type CRSSyncService struct { type CRSSyncService struct {
...@@ -22,6 +23,7 @@ type CRSSyncService struct { ...@@ -22,6 +23,7 @@ type CRSSyncService struct {
oauthService *OAuthService oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService openaiOAuthService *OpenAIOAuthService
geminiOAuthService *GeminiOAuthService geminiOAuthService *GeminiOAuthService
cfg *config.Config
} }
func NewCRSSyncService( func NewCRSSyncService(
...@@ -30,6 +32,7 @@ func NewCRSSyncService( ...@@ -30,6 +32,7 @@ func NewCRSSyncService(
oauthService *OAuthService, oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService, openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService, geminiOAuthService *GeminiOAuthService,
cfg *config.Config,
) *CRSSyncService { ) *CRSSyncService {
return &CRSSyncService{ return &CRSSyncService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -37,6 +40,7 @@ func NewCRSSyncService( ...@@ -37,6 +40,7 @@ func NewCRSSyncService(
oauthService: oauthService, oauthService: oauthService,
openaiOAuthService: openaiOAuthService, openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService, geminiOAuthService: geminiOAuthService,
cfg: cfg,
} }
} }
...@@ -187,7 +191,10 @@ type crsGeminiAPIKeyAccount struct { ...@@ -187,7 +191,10 @@ type crsGeminiAPIKeyAccount struct {
} }
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) { func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
baseURL, err := normalizeBaseURL(input.BaseURL) if s.cfg == nil {
return nil, errors.New("config is not available")
}
baseURL, err := normalizeBaseURL(input.BaseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -1055,17 +1062,18 @@ func mapCRSStatus(isActive bool, status string) string { ...@@ -1055,17 +1062,18 @@ func mapCRSStatus(isActive bool, status string) string {
return "active" return "active"
} }
func normalizeBaseURL(raw string) (string, error) { func normalizeBaseURL(raw string, allowlist []string, allowPrivate bool) (string, error) {
trimmed := strings.TrimSpace(raw) // 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
if trimmed == "" { requireAllowlist := len(allowlist) > 0
return "", errors.New("base_url is required") normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
} AllowedHosts: allowlist,
u, err := url.Parse(trimmed) RequireAllowlist: requireAllowlist,
if err != nil || u.Scheme == "" || u.Host == "" { AllowPrivate: allowPrivate,
return "", fmt.Errorf("invalid base_url: %s", trimmed) })
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
} }
u.Path = strings.TrimRight(u.Path, "/") return normalized, nil
return strings.TrimRight(u.String(), "/"), nil
} }
// cleanBaseURL removes trailing suffix from base_url in credentials // cleanBaseURL removes trailing suffix from base_url in credentials
......
...@@ -19,6 +19,8 @@ import ( ...@@ -19,6 +19,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -724,7 +726,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -724,7 +726,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages" if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages"
}
} }
// OAuth账号:应用统一指纹 // OAuth账号:应用统一指纹
...@@ -1107,12 +1115,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h ...@@ -1107,12 +1115,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
body = s.replaceModelInResponseBody(body, mappedModel, originalModel) body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
} }
// 透传响应头 responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 写入响应 // 写入响应
c.Data(resp.StatusCode, "application/json", body) c.Data(resp.StatusCode, "application/json", body)
...@@ -1352,7 +1355,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -1352,7 +1355,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
targetURL := claudeAPICountTokensURL targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens" if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages/count_tokens"
}
} }
// OAuth 账号:应用统一指纹和重写 userID // OAuth 账号:应用统一指纹和重写 userID
...@@ -1424,3 +1433,15 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m ...@@ -1424,3 +1433,15 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
}, },
}) })
} }
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
...@@ -18,9 +18,12 @@ import ( ...@@ -18,9 +18,12 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct { ...@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
rateLimitService *RateLimitService rateLimitService *RateLimitService
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService antigravityGatewayService *AntigravityGatewayService
cfg *config.Config
} }
func NewGeminiMessagesCompatService( func NewGeminiMessagesCompatService(
...@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService( ...@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
rateLimitService *RateLimitService, rateLimitService *RateLimitService,
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
antigravityGatewayService *AntigravityGatewayService, antigravityGatewayService *AntigravityGatewayService,
cfg *config.Config,
) *GeminiMessagesCompatService { ) *GeminiMessagesCompatService {
return &GeminiMessagesCompatService{ return &GeminiMessagesCompatService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService( ...@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
rateLimitService: rateLimitService, rateLimitService: rateLimitService,
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService, antigravityGatewayService: antigravityGatewayService,
cfg: cfg,
} }
} }
...@@ -209,6 +215,18 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit ...@@ -209,6 +215,18 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
return s.antigravityGatewayService return s.antigravityGatewayService
} }
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户 // HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) { func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
var accounts []Account var accounts []Account
...@@ -360,16 +378,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -360,16 +378,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, "", errors.New("gemini api_key not configured") return nil, "", errors.New("gemini api_key not configured")
} }
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
action := "generateContent" action := "generateContent"
if req.Stream { if req.Stream {
action = "streamGenerateContent" action = "streamGenerateContent"
} }
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action) fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if req.Stream { if req.Stream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -406,7 +428,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -406,7 +428,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token) // 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" { if projectID != "" {
// Mode 1: Code Assist API // Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action) baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -432,12 +458,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -432,12 +458,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return upstreamReq, "x-request-id", nil return upstreamReq, "x-request-id", nil
} else { } else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, action) fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -622,12 +652,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -622,12 +652,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, "", errors.New("gemini api_key not configured") return nil, "", errors.New("gemini api_key not configured")
} }
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, upstreamAction) fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -659,7 +693,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -659,7 +693,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token) // 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if projectID != "" && !forceAIStudio { if projectID != "" && !forceAIStudio {
// Mode 1: Code Assist API // Mode 1: Code Assist API
fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, upstreamAction) baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), upstreamAction)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -685,12 +723,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -685,12 +723,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return upstreamReq, "x-request-id", nil return upstreamReq, "x-request-id", nil
} else { } else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
}
fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction) fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, upstreamAction)
if useUpstreamStream { if useUpstreamStream {
fullURL += "?alt=sse" fullURL += "?alt=sse"
} }
...@@ -1608,6 +1650,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co ...@@ -1608,6 +1650,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
_ = json.Unmarshal(respBody, &parsed) _ = json.Unmarshal(respBody, &parsed)
} }
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
if contentType == "" { if contentType == "" {
contentType = "application/json" contentType = "application/json"
...@@ -1720,11 +1764,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac ...@@ -1720,11 +1764,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return nil, errors.New("invalid path") return nil, errors.New("invalid path")
} }
baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" { if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL baseURL = geminicli.AIStudioBaseURL
} }
fullURL := strings.TrimRight(baseURL, "/") + path normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
fullURL := strings.TrimRight(normalizedBaseURL, "/") + path
var proxyURL string var proxyURL string
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
...@@ -1763,9 +1811,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac ...@@ -1763,9 +1811,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
if wwwAuthenticate != "" {
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
}
return &UpstreamHTTPResult{ return &UpstreamHTTPResult{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
Headers: resp.Header.Clone(), Headers: filteredHeaders,
Body: body, Body: body,
}, nil }, nil
} }
......
...@@ -18,6 +18,8 @@ import ( ...@@ -18,6 +18,8 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -370,10 +372,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. ...@@ -370,10 +372,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeApiKey: case AccountTypeApiKey:
// API Key accounts use Platform API or custom base URL // API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL() baseURL := account.GetOpenAIBaseURL()
if baseURL != "" { if baseURL == "" {
targetURL = baseURL + "/responses"
} else {
targetURL = openaiPlatformAPIURL targetURL = openaiPlatformAPIURL
} else {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/responses"
} }
default: default:
targetURL = openaiPlatformAPIURL targetURL = openaiPlatformAPIURL
...@@ -645,18 +651,25 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r ...@@ -645,18 +651,25 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
body = s.replaceModelInResponseBody(body, mappedModel, originalModel) body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
} }
// Pass through headers responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
c.Data(resp.StatusCode, "application/json", body) c.Data(resp.StatusCode, "application/json", body)
return usage, nil return usage, nil
} }
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil { if err := json.Unmarshal(body, &resp); err != nil {
......
...@@ -16,6 +16,7 @@ import ( ...@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
) )
var ( var (
...@@ -211,16 +212,35 @@ func (s *PricingService) syncWithRemote() error { ...@@ -211,16 +212,35 @@ func (s *PricingService) syncWithRemote() error {
// downloadPricingData 从远程下载价格数据 // downloadPricingData 从远程下载价格数据
func (s *PricingService) downloadPricingData() error { func (s *PricingService) downloadPricingData() error {
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL) remoteURL, err := s.validatePricingURL(s.cfg.Pricing.RemoteURL)
if err != nil {
return err
}
log.Printf("[Pricing] Downloading from %s", remoteURL)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL) var expectedHash string
if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" {
expectedHash, err = s.fetchRemoteHash()
if err != nil {
return fmt.Errorf("fetch remote hash: %w", err)
}
}
body, err := s.remoteClient.FetchPricingJSON(ctx, remoteURL)
if err != nil { if err != nil {
return fmt.Errorf("download failed: %w", err) return fmt.Errorf("download failed: %w", err)
} }
if expectedHash != "" {
actualHash := sha256.Sum256(body)
if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) {
return fmt.Errorf("pricing hash mismatch")
}
}
// 解析JSON数据(使用灵活的解析方式) // 解析JSON数据(使用灵活的解析方式)
data, err := s.parsePricingData(body) data, err := s.parsePricingData(body)
if err != nil { if err != nil {
...@@ -373,10 +393,31 @@ func (s *PricingService) useFallbackPricing() error { ...@@ -373,10 +393,31 @@ func (s *PricingService) useFallbackPricing() error {
// fetchRemoteHash 从远程获取哈希值 // fetchRemoteHash 从远程获取哈希值
func (s *PricingService) fetchRemoteHash() (string, error) { func (s *PricingService) fetchRemoteHash() (string, error) {
hashURL, err := s.validatePricingURL(s.cfg.Pricing.HashURL)
if err != nil {
return "", err
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL) hash, err := s.remoteClient.FetchHashText(ctx, hashURL)
if err != nil {
return "", err
}
return strings.TrimSpace(hash), nil
}
func (s *PricingService) validatePricingURL(raw string) (string, error) {
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.PricingHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid pricing url: %w", err)
}
return normalized, nil
} }
// computeFileHash 计算文件哈希 // computeFileHash 计算文件哈希
......
...@@ -215,8 +215,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -215,8 +215,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
SmtpFrom: settings[SettingKeySmtpFrom], SmtpFrom: settings[SettingKeySmtpFrom],
SmtpFromName: settings[SettingKeySmtpFromName], SmtpFromName: settings[SettingKeySmtpFromName],
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true", SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
SmtpPasswordConfigured: settings[SettingKeySmtpPassword] != "",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo], SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
...@@ -245,10 +247,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -245,10 +247,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.DefaultBalance = s.cfg.Default.UserBalance result.DefaultBalance = s.cfg.Default.UserBalance
} }
// 敏感信息直接返回,方便测试连接时使用
result.SmtpPassword = settings[SettingKeySmtpPassword]
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
return result return result
} }
......
...@@ -8,6 +8,7 @@ type SystemSettings struct { ...@@ -8,6 +8,7 @@ type SystemSettings struct {
SmtpPort int SmtpPort int
SmtpUsername string SmtpUsername string
SmtpPassword string SmtpPassword string
SmtpPasswordConfigured bool
SmtpFrom string SmtpFrom string
SmtpFromName string SmtpFromName string
SmtpUseTLS bool SmtpUseTLS bool
...@@ -15,6 +16,7 @@ type SystemSettings struct { ...@@ -15,6 +16,7 @@ type SystemSettings struct {
TurnstileEnabled bool TurnstileEnabled bool
TurnstileSiteKey string TurnstileSiteKey string
TurnstileSecretKey string TurnstileSecretKey string
TurnstileSecretKeyConfigured bool
SiteName string SiteName string
SiteLogo string SiteLogo string
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"log" "log"
"os" "os"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/infrastructure" "github.com/Wei-Shaw/sub2api/internal/infrastructure"
...@@ -196,11 +197,17 @@ func Install(cfg *SetupConfig) error { ...@@ -196,11 +197,17 @@ func Install(cfg *SetupConfig) error {
// Generate JWT secret if not provided // Generate JWT secret if not provided
if cfg.JWT.Secret == "" { if cfg.JWT.Secret == "" {
if strings.EqualFold(cfg.Server.Mode, "release") {
return fmt.Errorf("jwt secret is required in release mode")
}
secret, err := generateSecret(32) secret, err := generateSecret(32)
if err != nil { if err != nil {
return fmt.Errorf("failed to generate jwt secret: %w", err) return fmt.Errorf("failed to generate jwt secret: %w", err)
} }
cfg.JWT.Secret = secret cfg.JWT.Secret = secret
log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
} else if strings.EqualFold(cfg.Server.Mode, "release") && len(cfg.JWT.Secret) < 32 {
return fmt.Errorf("jwt secret must be at least 32 characters in release mode")
} }
// Test connections // Test connections
...@@ -474,12 +481,17 @@ func AutoSetupFromEnv() error { ...@@ -474,12 +481,17 @@ func AutoSetupFromEnv() error {
// Generate JWT secret if not provided // Generate JWT secret if not provided
if cfg.JWT.Secret == "" { if cfg.JWT.Secret == "" {
if strings.EqualFold(cfg.Server.Mode, "release") {
return fmt.Errorf("jwt secret is required in release mode")
}
secret, err := generateSecret(32) secret, err := generateSecret(32)
if err != nil { if err != nil {
return fmt.Errorf("failed to generate jwt secret: %w", err) return fmt.Errorf("failed to generate jwt secret: %w", err)
} }
cfg.JWT.Secret = secret cfg.JWT.Secret = secret
log.Println("Generated JWT secret automatically") log.Println("Warning: JWT secret auto-generated for non-release mode. Do not use in production.")
} else if strings.EqualFold(cfg.Server.Mode, "release") && len(cfg.JWT.Secret) < 32 {
return fmt.Errorf("jwt secret must be at least 32 characters in release mode")
} }
// Generate admin password if not provided // Generate admin password if not provided
...@@ -489,8 +501,8 @@ func AutoSetupFromEnv() error { ...@@ -489,8 +501,8 @@ func AutoSetupFromEnv() error {
return fmt.Errorf("failed to generate admin password: %w", err) return fmt.Errorf("failed to generate admin password: %w", err)
} }
cfg.Admin.Password = password cfg.Admin.Password = password
log.Printf("Generated admin password: %s", cfg.Admin.Password) fmt.Printf("Generated admin password (one-time): %s\n", cfg.Admin.Password)
log.Println("IMPORTANT: Save this password! It will not be shown again.") fmt.Println("IMPORTANT: Save this password! It will not be shown again.")
} }
// Test database connection // Test database connection
......
package logredact
import (
"encoding/json"
"strings"
)
// maxRedactDepth 限制递归深度以防止栈溢出
const maxRedactDepth = 32
var defaultSensitiveKeys = map[string]struct{}{
"authorization_code": {},
"code": {},
"code_verifier": {},
"access_token": {},
"refresh_token": {},
"id_token": {},
"client_secret": {},
"password": {},
}
func RedactMap(input map[string]any, extraKeys ...string) map[string]any {
if input == nil {
return map[string]any{}
}
keys := buildKeySet(extraKeys)
redacted, ok := redactValueWithDepth(input, keys, 0).(map[string]any)
if !ok {
return map[string]any{}
}
return redacted
}
func RedactJSON(raw []byte, extraKeys ...string) string {
if len(raw) == 0 {
return ""
}
var value any
if err := json.Unmarshal(raw, &value); err != nil {
return "<non-json payload redacted>"
}
keys := buildKeySet(extraKeys)
redacted := redactValueWithDepth(value, keys, 0)
encoded, err := json.Marshal(redacted)
if err != nil {
return "<redacted>"
}
return string(encoded)
}
func buildKeySet(extraKeys []string) map[string]struct{} {
keys := make(map[string]struct{}, len(defaultSensitiveKeys)+len(extraKeys))
for k := range defaultSensitiveKeys {
keys[k] = struct{}{}
}
for _, key := range extraKeys {
normalized := normalizeKey(key)
if normalized == "" {
continue
}
keys[normalized] = struct{}{}
}
return keys
}
func redactValueWithDepth(value any, keys map[string]struct{}, depth int) any {
if depth > maxRedactDepth {
return "<depth limit exceeded>"
}
switch v := value.(type) {
case map[string]any:
out := make(map[string]any, len(v))
for k, val := range v {
if isSensitiveKey(k, keys) {
out[k] = "***"
continue
}
out[k] = redactValueWithDepth(val, keys, depth+1)
}
return out
case []any:
out := make([]any, len(v))
for i, item := range v {
out[i] = redactValueWithDepth(item, keys, depth+1)
}
return out
default:
return value
}
}
func isSensitiveKey(key string, keys map[string]struct{}) bool {
_, ok := keys[normalizeKey(key)]
return ok
}
func normalizeKey(key string) string {
return strings.ToLower(strings.TrimSpace(key))
}
package responseheaders
import (
"net/http"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// defaultAllowed 定义允许透传的响应头白名单
// 注意:以下头部由 Go HTTP 包自动处理,不应手动设置:
// - content-length: 由 ResponseWriter 根据实际写入数据自动设置
// - transfer-encoding: 由 HTTP 库根据需要自动添加/移除
// - connection: 由 HTTP 库管理连接复用
var defaultAllowed = map[string]struct{}{
"content-type": {},
"content-encoding": {},
"content-language": {},
"cache-control": {},
"etag": {},
"last-modified": {},
"expires": {},
"vary": {},
"date": {},
"x-request-id": {},
"x-ratelimit-limit-requests": {},
"x-ratelimit-limit-tokens": {},
"x-ratelimit-remaining-requests": {},
"x-ratelimit-remaining-tokens": {},
"x-ratelimit-reset-requests": {},
"x-ratelimit-reset-tokens": {},
"retry-after": {},
"location": {},
}
// hopByHopHeaders 是跳过的 hop-by-hop 头部,这些头部由 HTTP 库自动处理
var hopByHopHeaders = map[string]struct{}{
"content-length": {},
"transfer-encoding": {},
"connection": {},
}
func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header {
allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed))
for key := range defaultAllowed {
allowed[key] = struct{}{}
}
for _, key := range cfg.AdditionalAllowed {
normalized := strings.ToLower(strings.TrimSpace(key))
if normalized == "" {
continue
}
allowed[normalized] = struct{}{}
}
forceRemove := make(map[string]struct{}, len(cfg.ForceRemove))
for _, key := range cfg.ForceRemove {
normalized := strings.ToLower(strings.TrimSpace(key))
if normalized == "" {
continue
}
forceRemove[normalized] = struct{}{}
}
filtered := make(http.Header, len(src))
for key, values := range src {
lower := strings.ToLower(key)
if _, blocked := forceRemove[lower]; blocked {
continue
}
if _, ok := allowed[lower]; !ok {
continue
}
// 跳过 hop-by-hop 头部,这些由 HTTP 库自动处理
if _, isHopByHop := hopByHopHeaders[lower]; isHopByHop {
continue
}
for _, value := range values {
filtered.Add(key, value)
}
}
return filtered
}
func WriteFilteredHeaders(dst http.Header, src http.Header, cfg config.ResponseHeaderConfig) {
filtered := FilterHeaders(src, cfg)
for key, values := range filtered {
for _, value := range values {
dst.Add(key, value)
}
}
}
package urlvalidator
import (
"context"
"errors"
"fmt"
"net"
"net/url"
"strings"
"time"
)
type ValidationOptions struct {
AllowedHosts []string
RequireAllowlist bool
AllowPrivate bool
}
func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("url is required")
}
parsed, err := url.Parse(trimmed)
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
return "", fmt.Errorf("invalid url: %s", trimmed)
}
if !strings.EqualFold(parsed.Scheme, "https") {
return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host == "" {
return "", errors.New("invalid host")
}
if !opts.AllowPrivate && isBlockedHost(host) {
return "", fmt.Errorf("host is not allowed: %s", host)
}
allowlist := normalizeAllowlist(opts.AllowedHosts)
if opts.RequireAllowlist && len(allowlist) == 0 {
return "", errors.New("allowlist is not configured")
}
if len(allowlist) > 0 && !isAllowedHost(host, allowlist) {
return "", fmt.Errorf("host is not allowed: %s", host)
}
parsed.Path = strings.TrimRight(parsed.Path, "/")
parsed.RawPath = ""
return strings.TrimRight(parsed.String(), "/"), nil
}
// ValidateResolvedIP 验证 DNS 解析后的 IP 地址是否安全
// 用于防止 DNS Rebinding 攻击:在实际 HTTP 请求时调用此函数验证解析后的 IP
func ValidateResolvedIP(host string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
if err != nil {
return fmt.Errorf("dns resolution failed: %w", err)
}
for _, ip := range ips {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
return fmt.Errorf("resolved ip %s is not allowed", ip.String())
}
}
return nil
}
func normalizeAllowlist(values []string) []string {
if len(values) == 0 {
return nil
}
normalized := make([]string, 0, len(values))
for _, v := range values {
entry := strings.ToLower(strings.TrimSpace(v))
if entry == "" {
continue
}
if host, _, err := net.SplitHostPort(entry); err == nil {
entry = host
}
normalized = append(normalized, entry)
}
return normalized
}
func isAllowedHost(host string, allowlist []string) bool {
for _, entry := range allowlist {
if entry == "" {
continue
}
if strings.HasPrefix(entry, "*.") {
suffix := strings.TrimPrefix(entry, "*.")
if host == suffix || strings.HasSuffix(host, "."+suffix) {
return true
}
continue
}
if host == entry {
return true
}
}
return false
}
func isBlockedHost(host string) bool {
if host == "localhost" || strings.HasSuffix(host, ".localhost") {
return true
}
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
return true
}
}
return false
}
...@@ -12,6 +12,8 @@ server: ...@@ -12,6 +12,8 @@ server:
port: 8080 port: 8080
# Mode: "debug" for development, "release" for production # Mode: "debug" for development, "release" for production
mode: "release" mode: "release"
# Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies.
trusted_proxies: []
# ============================================================================= # =============================================================================
# Run Mode Configuration # Run Mode Configuration
...@@ -21,6 +23,48 @@ server: ...@@ -21,6 +23,48 @@ server:
# - simple: Hides SaaS features and skips billing/balance checks # - simple: Hides SaaS features and skips billing/balance checks
run_mode: "standard" run_mode: "standard"
# =============================================================================
# CORS Configuration
# =============================================================================
cors:
# Allowed origins list. Leave empty to disable cross-origin requests.
allowed_origins: []
# Allow credentials (cookies/authorization headers). Cannot be used with "*".
allow_credentials: true
# =============================================================================
# Security Configuration
# =============================================================================
security:
url_allowlist:
# Allowed upstream hosts for API proxying
upstream_hosts:
- "api.openai.com"
- "api.anthropic.com"
- "generativelanguage.googleapis.com"
- "cloudcode-pa.googleapis.com"
- "*.openai.azure.com"
# Allowed hosts for pricing data download
pricing_hosts:
- "raw.githubusercontent.com"
# Allowed hosts for CRS sync (required when using CRS sync)
crs_hosts: []
# Allow localhost/private IPs for upstream/pricing/CRS (use only in trusted networks)
allow_private_hosts: false
response_headers:
# Extra allowed response headers from upstream
additional_allowed: []
# Force-remove response headers from upstream
force_remove: []
csp:
# Enable Content-Security-Policy header
enabled: true
# Default CSP policy (override if you host assets on other domains)
policy: "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
proxy_probe:
# Allow skipping TLS verification for proxy probe (debug only)
insecure_skip_verify: false
# ============================================================================= # =============================================================================
# 网关配置 # 网关配置
# ============================================================================= # =============================================================================
...@@ -77,7 +121,7 @@ jwt: ...@@ -77,7 +121,7 @@ jwt:
# IMPORTANT: Change this to a random string in production! # IMPORTANT: Change this to a random string in production!
# Generate with: openssl rand -hex 32 # Generate with: openssl rand -hex 32
secret: "change-this-to-a-secure-random-string" secret: "change-this-to-a-secure-random-string"
# Token expiration time in hours # Token expiration time in hours (max 24)
expire_hour: 24 expire_hour: 24
# ============================================================================= # =============================================================================
...@@ -122,6 +166,23 @@ pricing: ...@@ -122,6 +166,23 @@ pricing:
# Hash check interval in minutes # Hash check interval in minutes
hash_check_interval_minutes: 10 hash_check_interval_minutes: 10
# =============================================================================
# Billing Configuration
# =============================================================================
billing:
circuit_breaker:
enabled: true
failure_threshold: 5
reset_timeout_seconds: 30
half_open_requests: 3
# =============================================================================
# Turnstile Configuration
# =============================================================================
turnstile:
# Require Turnstile in release mode (when enabled, login/register will fail if not configured)
required: false
# ============================================================================= # =============================================================================
# Gemini OAuth (Required for Gemini accounts) # Gemini OAuth (Required for Gemini accounts)
# ============================================================================= # =============================================================================
......
...@@ -26,14 +26,37 @@ export interface SystemSettings { ...@@ -26,14 +26,37 @@ export interface SystemSettings {
smtp_host: string smtp_host: string
smtp_port: number smtp_port: number
smtp_username: string smtp_username: string
smtp_password: string smtp_password_configured: boolean
smtp_from_email: string smtp_from_email: string
smtp_from_name: string smtp_from_name: string
smtp_use_tls: boolean smtp_use_tls: boolean
// Cloudflare Turnstile settings // Cloudflare Turnstile settings
turnstile_enabled: boolean turnstile_enabled: boolean
turnstile_site_key: string turnstile_site_key: string
turnstile_secret_key: string turnstile_secret_key_configured: boolean
}
export interface UpdateSettingsRequest {
registration_enabled?: boolean
email_verify_enabled?: boolean
default_balance?: number
default_concurrency?: number
site_name?: string
site_logo?: string
site_subtitle?: string
api_base_url?: string
contact_info?: string
doc_url?: string
smtp_host?: string
smtp_port?: number
smtp_username?: string
smtp_password?: string
smtp_from_email?: string
smtp_from_name?: string
smtp_use_tls?: boolean
turnstile_enabled?: boolean
turnstile_site_key?: string
turnstile_secret_key?: string
} }
/** /**
...@@ -50,7 +73,7 @@ export async function getSettings(): Promise<SystemSettings> { ...@@ -50,7 +73,7 @@ export async function getSettings(): Promise<SystemSettings> {
* @param settings - Partial settings to update * @param settings - Partial settings to update
* @returns Updated settings * @returns Updated settings
*/ */
export async function updateSettings(settings: Partial<SystemSettings>): Promise<SystemSettings> { export async function updateSettings(settings: UpdateSettingsRequest): Promise<SystemSettings> {
const { data } = await apiClient.put<SystemSettings>('/admin/settings', settings) const { data } = await apiClient.put<SystemSettings>('/admin/settings', settings)
return data return data
} }
......
...@@ -62,8 +62,24 @@ apiClient.interceptors.response.use( ...@@ -62,8 +62,24 @@ apiClient.interceptors.response.use(
// 401: Unauthorized - clear token and redirect to login // 401: Unauthorized - clear token and redirect to login
if (status === 401) { if (status === 401) {
const hasToken = !!localStorage.getItem('auth_token')
const url = error.config?.url || ''
const isAuthEndpoint =
url.includes('/auth/login') || url.includes('/auth/register') || url.includes('/auth/refresh')
const headers = error.config?.headers as Record<string, unknown> | undefined
const authHeader = headers?.Authorization ?? headers?.authorization
const sentAuth =
typeof authHeader === 'string'
? authHeader.trim() !== ''
: Array.isArray(authHeader)
? authHeader.length > 0
: !!authHeader
localStorage.removeItem('auth_token') localStorage.removeItem('auth_token')
localStorage.removeItem('auth_user') localStorage.removeItem('auth_user')
if ((hasToken || sentAuth) && !isAuthEndpoint) {
sessionStorage.setItem('auth_expired', '1')
}
// Only redirect if not already on login page // Only redirect if not already on login page
if (!window.location.pathname.includes('/login')) { if (!window.location.pathname.includes('/login')) {
window.location.href = '/login' window.location.href = '/login'
......
...@@ -136,16 +136,16 @@ ...@@ -136,16 +136,16 @@
<ol <ol
class="list-inside list-decimal space-y-1 text-xs text-amber-700 dark:text-amber-300" class="list-inside list-decimal space-y-1 text-xs text-amber-700 dark:text-amber-300"
> >
<li v-html="t('admin.accounts.oauth.step1')"></li> <li>{{ t('admin.accounts.oauth.step1') }}</li>
<li v-html="t('admin.accounts.oauth.step2')"></li> <li>{{ t('admin.accounts.oauth.step2') }}</li>
<li v-html="t('admin.accounts.oauth.step3')"></li> <li>{{ t('admin.accounts.oauth.step3') }}</li>
<li v-html="t('admin.accounts.oauth.step4')"></li> <li>{{ t('admin.accounts.oauth.step4') }}</li>
<li v-html="t('admin.accounts.oauth.step5')"></li> <li>{{ t('admin.accounts.oauth.step5') }}</li>
<li v-html="t('admin.accounts.oauth.step6')"></li> <li>{{ t('admin.accounts.oauth.step6') }}</li>
</ol> </ol>
<p <p
class="mt-2 text-xs text-amber-600 dark:text-amber-400" class="mt-2 text-xs text-amber-600 dark:text-amber-400"
v-html="t('admin.accounts.oauth.sessionKeyFormat')" v-text="t('admin.accounts.oauth.sessionKeyFormat')"
></p> ></p>
</div> </div>
...@@ -390,7 +390,7 @@ ...@@ -390,7 +390,7 @@
> >
<p <p
class="text-xs text-amber-800 dark:text-amber-300" class="text-xs text-amber-800 dark:text-amber-300"
v-html="oauthImportantNotice" v-text="oauthImportantNotice"
></p> ></p>
</div> </div>
<!-- Proxy Warning (for non-OpenAI) --> <!-- Proxy Warning (for non-OpenAI) -->
...@@ -400,7 +400,7 @@ ...@@ -400,7 +400,7 @@
> >
<p <p
class="text-xs text-yellow-800 dark:text-yellow-300" class="text-xs text-yellow-800 dark:text-yellow-300"
v-html="t('admin.accounts.oauth.proxyWarning')" v-text="t('admin.accounts.oauth.proxyWarning')"
></p> ></p>
</div> </div>
</div> </div>
...@@ -423,7 +423,7 @@ ...@@ -423,7 +423,7 @@
</p> </p>
<p <p
class="mb-3 text-sm text-blue-700 dark:text-blue-300" class="mb-3 text-sm text-blue-700 dark:text-blue-300"
v-html="oauthAuthCodeDesc" v-text="oauthAuthCodeDesc"
></p> ></p>
<div> <div>
<label class="input-label"> <label class="input-label">
......
...@@ -85,7 +85,7 @@ ...@@ -85,7 +85,7 @@
</button> </button>
</div> </div>
<!-- Code Content --> <!-- Code Content -->
<pre class="p-4 text-sm font-mono text-gray-100 overflow-x-auto"><code v-html="file.highlighted"></code></pre> <pre class="p-4 text-sm font-mono text-gray-100 overflow-x-auto"><code v-text="file.content"></code></pre>
</div> </div>
</div> </div>
</div> </div>
...@@ -142,7 +142,6 @@ interface TabConfig { ...@@ -142,7 +142,6 @@ interface TabConfig {
interface FileConfig { interface FileConfig {
path: string path: string
content: string content: string
highlighted: string
hint?: string // Optional hint message for this file hint?: string // Optional hint message for this file
} }
...@@ -227,13 +226,6 @@ const platformNote = computed(() => { ...@@ -227,13 +226,6 @@ const platformNote = computed(() => {
}) })
// Syntax highlighting helpers // Syntax highlighting helpers
const keyword = (text: string) => `<span class="text-purple-400">${text}</span>`
const variable = (text: string) => `<span class="text-cyan-400">${text}</span>`
const string = (text: string) => `<span class="text-green-400">${text}</span>`
const operator = (text: string) => `<span class="text-yellow-400">${text}</span>`
const comment = (text: string) => `<span class="text-gray-500">${text}</span>`
const key = (text: string) => `<span class="text-blue-400">${text}</span>`
// Generate file configs based on platform and active tab // Generate file configs based on platform and active tab
const currentFiles = computed((): FileConfig[] => { const currentFiles = computed((): FileConfig[] => {
const baseUrl = props.baseUrl || window.location.origin const baseUrl = props.baseUrl || window.location.origin
...@@ -249,37 +241,29 @@ const currentFiles = computed((): FileConfig[] => { ...@@ -249,37 +241,29 @@ const currentFiles = computed((): FileConfig[] => {
function generateAnthropicFiles(baseUrl: string, apiKey: string): FileConfig[] { function generateAnthropicFiles(baseUrl: string, apiKey: string): FileConfig[] {
let path: string let path: string
let content: string let content: string
let highlighted: string
switch (activeTab.value) { switch (activeTab.value) {
case 'unix': case 'unix':
path = 'Terminal' path = 'Terminal'
content = `export ANTHROPIC_BASE_URL="${baseUrl}" content = `export ANTHROPIC_BASE_URL="${baseUrl}"
export ANTHROPIC_AUTH_TOKEN="${apiKey}"` export ANTHROPIC_AUTH_TOKEN="${apiKey}"`
highlighted = `${keyword('export')} ${variable('ANTHROPIC_BASE_URL')}${operator('=')}${string(`"${baseUrl}"`)}
${keyword('export')} ${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${string(`"${apiKey}"`)}`
break break
case 'cmd': case 'cmd':
path = 'Command Prompt' path = 'Command Prompt'
content = `set ANTHROPIC_BASE_URL=${baseUrl} content = `set ANTHROPIC_BASE_URL=${baseUrl}
set ANTHROPIC_AUTH_TOKEN=${apiKey}` set ANTHROPIC_AUTH_TOKEN=${apiKey}`
highlighted = `${keyword('set')} ${variable('ANTHROPIC_BASE_URL')}${operator('=')}${baseUrl}
${keyword('set')} ${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${apiKey}`
break break
case 'powershell': case 'powershell':
path = 'PowerShell' path = 'PowerShell'
content = `$env:ANTHROPIC_BASE_URL="${baseUrl}" content = `$env:ANTHROPIC_BASE_URL="${baseUrl}"
$env:ANTHROPIC_AUTH_TOKEN="${apiKey}"` $env:ANTHROPIC_AUTH_TOKEN="${apiKey}"`
highlighted = `${keyword('$env:')}${variable('ANTHROPIC_BASE_URL')}${operator('=')}${string(`"${baseUrl}"`)}
${keyword('$env:')}${variable('ANTHROPIC_AUTH_TOKEN')}${operator('=')}${string(`"${apiKey}"`)}`
break break
default: default:
path = 'Terminal' path = 'Terminal'
content = '' content = ''
highlighted = ''
} }
return [{ path, content, highlighted }] return [{ path, content }]
} }
function generateOpenAIFiles(baseUrl: string, apiKey: string): FileConfig[] { function generateOpenAIFiles(baseUrl: string, apiKey: string): FileConfig[] {
...@@ -301,40 +285,20 @@ base_url = "${baseUrl}" ...@@ -301,40 +285,20 @@ base_url = "${baseUrl}"
wire_api = "responses" wire_api = "responses"
requires_openai_auth = true` requires_openai_auth = true`
const configHighlighted = `${key('model_provider')} ${operator('=')} ${string('"sub2api"')}
${key('model')} ${operator('=')} ${string('"gpt-5.2-codex"')}
${key('model_reasoning_effort')} ${operator('=')} ${string('"high"')}
${key('network_access')} ${operator('=')} ${string('"enabled"')}
${key('disable_response_storage')} ${operator('=')} ${keyword('true')}
${key('windows_wsl_setup_acknowledged')} ${operator('=')} ${keyword('true')}
${key('model_verbosity')} ${operator('=')} ${string('"high"')}
${comment('[model_providers.sub2api]')}
${key('name')} ${operator('=')} ${string('"sub2api"')}
${key('base_url')} ${operator('=')} ${string(`"${baseUrl}"`)}
${key('wire_api')} ${operator('=')} ${string('"responses"')}
${key('requires_openai_auth')} ${operator('=')} ${keyword('true')}`
// auth.json content // auth.json content
const authContent = `{ const authContent = `{
"OPENAI_API_KEY": "${apiKey}" "OPENAI_API_KEY": "${apiKey}"
}` }`
const authHighlighted = `{
${key('"OPENAI_API_KEY"')}: ${string(`"${apiKey}"`)}
}`
return [ return [
{ {
path: `${configDir}/config.toml`, path: `${configDir}/config.toml`,
content: configContent, content: configContent,
highlighted: configHighlighted,
hint: t('keys.useKeyModal.openai.configTomlHint') hint: t('keys.useKeyModal.openai.configTomlHint')
}, },
{ {
path: `${configDir}/auth.json`, path: `${configDir}/auth.json`,
content: authContent, content: authContent
highlighted: authHighlighted
} }
] ]
} }
......
...@@ -63,6 +63,7 @@ ...@@ -63,6 +63,7 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed, onMounted } from 'vue' import { ref, computed, onMounted } from 'vue'
import { getPublicSettings } from '@/api/auth' import { getPublicSettings } from '@/api/auth'
import { sanitizeUrl } from '@/utils/url'
const siteName = ref('Sub2API') const siteName = ref('Sub2API')
const siteLogo = ref('') const siteLogo = ref('')
...@@ -74,7 +75,7 @@ onMounted(async () => { ...@@ -74,7 +75,7 @@ onMounted(async () => {
try { try {
const settings = await getPublicSettings() const settings = await getPublicSettings()
siteName.value = settings.site_name || 'Sub2API' siteName.value = settings.site_name || 'Sub2API'
siteLogo.value = settings.site_logo || '' siteLogo.value = sanitizeUrl(settings.site_logo || '', { allowRelative: true })
siteSubtitle.value = settings.site_subtitle || 'Subscription to API Conversion Platform' siteSubtitle.value = settings.site_subtitle || 'Subscription to API Conversion Platform'
} catch (error) { } catch (error) {
console.error('Failed to load public settings:', error) console.error('Failed to load public settings:', error)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment